-
Notifications
You must be signed in to change notification settings - Fork 0
/
aig_prompt_protection.py
65 lines (55 loc) · 1.76 KB
/
aig_prompt_protection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from __future__ import annotations
import sys
import click
from openai import OpenAI
from pangea import PangeaConfig
from pangea.services import AIGuard
@click.command()
@click.option("--model", default="gpt-4o-mini", show_default=True, required=True, help="OpenAI model.")
@click.option(
    "--ai-guard-token",
    envvar="PANGEA_AI_GUARD_TOKEN",
    required=True,
    help="Pangea AI Guard API token. May also be set via the `PANGEA_AI_GUARD_TOKEN` environment variable.",
)
@click.option(
    "--pangea-domain",
    envvar="PANGEA_DOMAIN",
    default="aws.us.pangea.cloud",
    show_default=True,
    required=True,
    help="Pangea API domain. May also be set via the `PANGEA_DOMAIN` environment variable.",
)
@click.option(
    "--openai-api-key",
    envvar="OPENAI_API_KEY",
    required=True,
    help="OpenAI API key. May also be set via the `OPENAI_API_KEY` environment variable.",
)
@click.argument("prompt")
def main(
    *,
    prompt: str,
    model: str,
    ai_guard_token: str,
    pangea_domain: str,
    openai_api_key: str,
) -> None:
    """Guard PROMPT with Pangea AI Guard, then stream an OpenAI chat completion.

    The prompt is first passed through AI Guard; if the service redacts it,
    the redacted text is what gets sent to the model. The completion is
    streamed to stdout as it arrives.

    Raises:
        click.ClickException: if AI Guard returns no result for the prompt.
    """
    config = PangeaConfig(domain=pangea_domain)
    ai_guard = AIGuard(token=ai_guard_token, config=config)

    # Guard the prompt before sending it to the LLM.
    guarded = ai_guard.guard_text(prompt)
    # Was `assert guarded.result`: asserts are stripped under `python -O` and
    # surface as an AssertionError rather than a clean CLI error. Fail
    # explicitly instead.
    if guarded.result is None:
        raise click.ClickException("AI Guard returned no result for the prompt.")
    # Fall back to the original prompt when AI Guard performed no redaction.
    redacted_prompt = guarded.result.redacted_prompt or prompt

    # Generate chat completions, streaming tokens to stdout as they arrive.
    stream = OpenAI(api_key=openai_api_key).chat.completions.create(
        messages=[{"role": "user", "content": redacted_prompt}], model=model, stream=True
    )
    for chunk in stream:
        for choice in chunk.choices:
            # `delta.content` may be None (e.g. role-only or final chunks).
            sys.stdout.write(choice.delta.content or "")
        # One flush per chunk keeps output live without flushing per choice.
        sys.stdout.flush()
    sys.stdout.write("\n")


if __name__ == "__main__":
    main()