Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add language detection and filtering #38

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"version": "0.2.0",
"description": "Prevent GPT prompt attacks for Node.js & TypeScript",
"main": "./dist/index.js",
"scripts": {
"scripts": {
"test": "jest",
"build": "tsc --build",
"clean": "tsc --build --clean",
Expand Down Expand Up @@ -43,5 +43,8 @@
"jest": "^29.4.1",
"ts-jest": "^29.0.5",
"typescript": "^4.9.5"
},
"dependencies": {
"lande": "^1.0.10"
}
}
56 changes: 43 additions & 13 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,76 +1,106 @@
#!/usr/bin/env ts-node
import {
promptContainsDenyListItems,
countPromptTokens,
encodePromptOutput,
promptContainsKnownAttack
promptContainsKnownAttack,
promptContainsLanguages,
promptContainsDenyListItems,
encodePromptOutput
} from './utils';

enum FAILURE_REASON {
DENY_LIST = 'CONTAINS_DENY_LIST_ITEM',
MAX_TOKEN_THRESHOLD = 'EXCEEDS_MAX_TOKEN_THRESHOLD',
KNOWN_ATTACK = 'CONTAINS_KNOWN_ATTACK'
KNOWN_ATTACK = 'CONTAINS_KNOWN_ATTACK',
LANGUAGE_VALIDATION = 'FAILED_LANGUAGE_VALIDATION'
}

// Options a caller may pass to the PromptGuard constructor. Every field is optional;
// the constructor spreads these over the default policy, so any field supplied here
// replaces the corresponding default.
// NOTE(review): PromptGuardPolicy has a `disableAttackMitigation` field that is not
// exposed here, so typed callers cannot set it — confirm whether that is intended.
type UserPolicyOptions = {
maxTokens?: number; // limit compared against countPromptTokens(prompt)
denyList?: string[]; // phrases handed to promptContainsDenyListItems
ignoreDefaultDenyList?: boolean; // NOTE(review): no matching PromptGuardPolicy field and not read in the visible code — confirm it is still used
allowedLanguages?: string[]; // language codes (e.g. 'eng'); prompt is meant to fail when none of these is detected
deniedLanguages?: string[]; // language codes (e.g. 'eng'); prompt fails when any of these is detected
encodeOutput?: boolean; // when true, the passing prompt is run through encodePromptOutput
};

interface PromptGuardPolicy {
maxTokens: number; // 1 token is ~4 characters in english
denyList: string[]; // this should be a fuzzy match
denyList: string[]; // this should use a fuzzy match but doesn't currently
disableAttackMitigation: boolean;
allowedLanguages: string[];
deniedLanguages: string[];
encodeOutput: boolean; // uses byte pair encoding to turn text into a series of integers
}

type PromptOutput = {
pass: boolean; // false if processing fails validation rules (max tokens, deny list, allow list)
pass: boolean; // false if processing fails validation rules
output: string | number[]; // provide the processed prompt or failure reason
};

export class PromptGuard {
promptGuardPolicy: PromptGuardPolicy;
policy: PromptGuardPolicy;

constructor(userPolicyOptions: UserPolicyOptions = {}) {
const defaultPromptGuardPolicy: PromptGuardPolicy = {
maxTokens: 4096,
denyList: [''],
disableAttackMitigation: false,
allowedLanguages: [''],
deniedLanguages: [''],
encodeOutput: false
};

// TODO validate the languages against the list of ISO 639-3 supported languages
// TODO validate that the allowed and denied language lists don't contain the same languages

// merge the user policy with the default policy to create the policy
this.promptGuardPolicy = {
this.policy = {
...defaultPromptGuardPolicy,
...userPolicyOptions
};
}

async process(prompt: string): Promise<PromptOutput> {
// processing order
// normalize -> quote -> escape -> check tokens -> check cache -> check for known attacks -> check allow list -> check deny list -> encode output
// check tokens -> check allowed languages -> check denied languages ->
// check for known attacks -> check deny list -> encode output

// check the prompt token count
if (countPromptTokens(prompt) > this.promptGuardPolicy.maxTokens)
if (countPromptTokens(prompt) > this.policy.maxTokens)
return { pass: false, output: FAILURE_REASON.MAX_TOKEN_THRESHOLD };

// check for the presence of allowed languages
// the prompt must be longer than 10 characters to reasonably expect to detect the language
if (prompt.length > 10) {
const allowedLanguages = this.policy.allowedLanguages;
const deniedLanguages = this.policy.deniedLanguages;

if (allowedLanguages[0] !== '') {
  // Await first, then negate. Writing `await !promise` negates the Promise object
  // itself (always false), which would make this allow-list check impossible to fail.
  if (!(await promptContainsLanguages(prompt, allowedLanguages)))
    return { pass: false, output: FAILURE_REASON.LANGUAGE_VALIDATION };
}
if (deniedLanguages[0] !== '') {
if (await promptContainsLanguages(prompt, deniedLanguages))
return { pass: false, output: FAILURE_REASON.LANGUAGE_VALIDATION };
}
}

// check for the presence of denied languages

// check prompt against known prompt attacks
if (!this.promptGuardPolicy.disableAttackMitigation) {
if (!this.policy.disableAttackMitigation) {
if (await promptContainsKnownAttack(prompt))
return { pass: false, output: FAILURE_REASON.KNOWN_ATTACK };
}

// check prompt against the user defined deny list
if (
await promptContainsDenyListItems(prompt, this.promptGuardPolicy.denyList)
await promptContainsDenyListItems(prompt, this.policy.denyList)
)
return { pass: false, output: FAILURE_REASON.DENY_LIST };

// encode the prompt output if encodeOutput is set by the user
if (this.promptGuardPolicy.encodeOutput)
if (this.policy.encodeOutput)
prompt = encodePromptOutput(prompt);

return { pass: true, output: prompt };
Expand Down
20 changes: 20 additions & 0 deletions src/tests/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,23 @@ test("Checks that the prompt output is encoded when encodeOutput is true", async
output: [5661, 318, 616, 7427, 6152, 0],
});
});

// An English prompt is expected to pass untouched when 'eng' is the only allowed language.
test("Checks the prompt language allow list", async () => {
  const englishPrompt = "this is my awesome prompt!";
  const guard = new PromptGuard({ allowedLanguages: ['eng'] });

  const result = await guard.process(englishPrompt);

  expect(result).toStrictEqual({ pass: true, output: englishPrompt });
});

// An English prompt is expected to be rejected when 'eng' is on the denied-language list.
test("Checks the prompt language denied list", async () => {
  const englishPrompt = "this is my awesome prompt!";
  const guard = new PromptGuard({ deniedLanguages: ['eng'] });

  const result = await guard.process(englishPrompt);

  expect(result).toStrictEqual({
    pass: false,
    output: 'FAILED_LANGUAGE_VALIDATION',
  });
});
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ type DenyList = string[];

const attackDenyList: DenyList = [
"ignore above",
"gnore the above",
"ignore the above",
"ignore previous instructions",
"ignore the previous instructions",
"ignore above instructions",
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
9 changes: 6 additions & 3 deletions src/utils/index.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import { containsDenyListItems } from "./denylist";
import { containsDenyListItems } from './deny_list';
export const promptContainsDenyListItems = containsDenyListItems;

import { containsKnownAttack } from "./attackmitigation";
import { containsKnownAttack } from './attack_mitigation';
export const promptContainsKnownAttack = containsKnownAttack;

const encoder = require("./encoder");
import { containsLanguages } from './language_detection';
export const promptContainsLanguages = containsLanguages;

const encoder = require('./gpt_encoder');
export const countPromptTokens = encoder.countTokens;
export const encodePromptOutput = encoder.encode;
32 changes: 32 additions & 0 deletions src/utils/language_detection/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import lande from "lande";

/** Detection results as `[language code, probability]` pairs, as consumed from lande. */
type LandeOutput = Array<[string, number]>;

/**
 * Reports whether any confidently detected language of `prompt` appears in `languages`.
 *
 * A language counts as detected only when its probability is strictly greater than
 * `minProbability`.
 *
 * @param prompt - Text whose language should be detected.
 * @param languages - Language codes to look for, in the same code format the detector
 *   emits (the project's tests use codes such as 'eng').
 * @param minProbability - Confidence cutoff; defaults to 0.8, the value previously
 *   hard-coded here. May need tuning.
 * @returns `true` if at least one detected language is in `languages`, otherwise `false`.
 */
export async function containsLanguages(
  prompt: string,
  languages: string[],
  minProbability = 0.8,
): Promise<boolean> {
  // lande returns a list of detected languages and their probabilities,
  // sorted with the most probable first.
  const ranked: LandeOutput = lande(prompt);

  for (const [code, probability] of ranked) {
    // Because the list is sorted, nothing after the first entry at or below the
    // cutoff can qualify, so stop scanning there.
    if (probability <= minProbability) break;
    if (languages.includes(code)) return true;
  }

  return false;
}

// export async function validateLanguageList(list: string[]): Promise<boolean> {
// //foo
// return true;
// }