Simplify API Token usage (#31)
petermuller authored Jun 15, 2024
1 parent 18d4e4e commit e9572b7
Showing 5 changed files with 35 additions and 33 deletions.
27 changes: 14 additions & 13 deletions README.md
@@ -347,7 +347,10 @@ pytest lisa-sdk/tests --url <rest-url-from-cdk-output> --verify <path-to-server.
The LISA Serve ALB can be used for programmatic access outside the example Chat application.
An example use case would be for allowing LISA to serve LLM requests that originate from the [Continue VSCode Plugin](https://www.continue.dev/).
To facilitate communication directly with the LISA Serve ALB, a user with sufficient DynamoDB PutItem permissions may add
-API keys to the APITokenTable, and once created, a user may make requests by including the `Api-Key` header with that token.
+API keys to the APITokenTable, and once created, a user may make requests by including the `Authorization: Bearer ${token}`
+header or the `Api-Key: ${token}` header with that token. If using any OpenAI-compatible library, the `api_key` field
+will use the `Authorization: Bearer ${token}` format automatically, so there is no need to include additional headers
+when using those libraries.
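The added paragraph above describes two interchangeable header formats. As a minimal sketch (the token value below is a placeholder, not a real key), the two styles carry the same credential:

```python
# Two equivalent ways to present a LISA API token (placeholder value, not a real key).
token = "my-token"

# OpenAI-style: what OpenAI-compatible clients send automatically from their api_key field.
bearer_headers = {"Authorization": f"Bearer {token}"}

# Azure-style: an explicit Api-Key header, usable from e.g. the Continue plugin.
api_key_headers = {"Api-Key": token}
```

Either set of headers can be attached to a request against the LISA Serve ALB; the server accepts both.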

### Adding a Token

@@ -375,7 +378,7 @@ aws --region $AWS_REGION dynamodb put-item --table-name LISAApiTokenTable \
}'
```

-Once the token is inserted into the DynamoDB Table, a user may use the token in the `Api-Key` request header like
+Once the token is inserted into the DynamoDB Table, a user may use the token in the `Authorization` request header like
in the following snippet.

```bash
@@ -384,7 +387,7 @@ token_string="YOUR_STRING_HERE"
curl ${lisa_serve_rest_url}/v2/serve/models \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
--H 'Api-Key: '${token_string}
+-H "Authorization: Bearer ${token_string}"
```

### Updating a Token
@@ -537,17 +540,17 @@ routes as long as your underlying models can also respond to them.
By supporting the OpenAI spec, we can more easily allow users to integrate their collection of models into their LLM applications and workflows. In LISA, users can authenticate
using their OpenID Connect Identity Provider, or with an API token created through the DynamoDB token workflow as described [here](#programmatic-api-tokens). Once the token
is retrieved, users can use that in direct requests to the LISA Serve REST API. If using the IdP, users must set the 'Authorization' header, otherwise if using the API token,
-users must set the 'Api-Key' header. After that, requests to `https://${lisa_serve_alb}/v2/serve` will handle the OpenAI API calls. As an example, the following call can list all
-models that LISA is aware of, assuming usage of the API token.
+users can set either the 'Api-Key' header or the 'Authorization' header. After that, requests to `https://${lisa_serve_alb}/v2/serve` will handle the OpenAI API calls. As an example, the following call can list all
+models that LISA is aware of, assuming usage of the API token. If you are using a self-signed cert, you must also provide the `--cacert $path` option to specify a CA bundle to trust for SSL verification.

```shell
-curl -s -H 'Api-Key: your-api-token' -X GET https://${lisa_serve_alb}/v2/serve/models
+curl -s -H 'Api-Key: your-token' -X GET https://${lisa_serve_alb}/v2/serve/models
```

If using the IdP, the request would look like the following:

```shell
-curl -s -H 'Authorization: Bearer your-bearer-token' -X GET https://${lisa_serve_alb}/v2/serve/models
+curl -s -H 'Authorization: Bearer your-token' -X GET https://${lisa_serve_alb}/v2/serve/models
```

When using a library that requests an OpenAI-compatible base_url, you can provide `https://${lisa_serve_alb}/v2/serve` here. All of the OpenAI routes will
@@ -589,12 +592,11 @@ client = OpenAI(
client.models.list()
```

-To use the models being served by LISA, the client needs four changes:
+To use the models being served by LISA, the client needs only a few changes:

1. Specify the `base_url` as the LISA Serve ALB, using the /v2/serve route at the end, similar to the apiBase in the [Continue example](#continue-jetbrains-and-vs-code-plugin)
-2. Change the api_key to be any string. This will be ignored by LISA, but for the OpenAI library to not fail, it needs to be defined.
-3. Add the `default_headers` option, setting the header for "Api-Key" to a valid token value, defined in DynamoDB from the [token creation](#programmatic-api-tokens) steps
-4. If using a self-signed cert, you must provide a certificate path for validating SSL. If you're using an ACM or public cert, then this may be omitted.
+2. Add the API key that you generated from the [token generation steps](#programmatic-api-tokens) as your `api_key` field.
+3. If using a self-signed cert, you must provide a certificate path for validating SSL. If you're using an ACM or public cert, then this may be omitted.
1. We provide a convenience function in the `lisa-sdk` for generating a cert path from an IAM certificate ARN if one is provided in the `RESTAPI_SSL_CERT_ARN` environment variable.

The Code block will now look like this and you can continue to use the library without any other modifications.
@@ -610,9 +612,8 @@ iam_client = boto3.client("iam")
cert_path = get_cert_path(iam_client)

client = OpenAI(
-    api_key="ignored",  # LISA ignores this field, but it must be defined  # pragma: allowlist-secret not a real key
+    api_key="my_key",  # pragma: allowlist-secret not a real key
    base_url="https://<lisa_serve_alb>/v2/serve",
-    default_headers={"Api-Key": "my_api_token"},  # pragma: allowlist-secret not a real key
    http_client=DefaultHttpxClient(verify=cert_path),  # needed for self-signed certs on your ALB, can be omitted otherwise
)
client.models.list()
3 changes: 1 addition & 2 deletions lambda/repository/lambda_functions.py
@@ -53,11 +53,10 @@ def _get_embeddings(model_name: str, id_token: str) -> LisaOpenAIEmbeddings:
    lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"])
    lisa_api_endpoint = lisa_api_param_response["Parameter"]["Value"]

-    headers = {"Authorization": f"Bearer {id_token}"}
    base_url = f"{lisa_api_endpoint}/{os.environ['REST_API_VERSION']}/serve"

    embedding = LisaOpenAIEmbeddings(
-        lisa_openai_api_base=base_url, model=model_name, headers=headers, verify=get_cert_path(iam_client)
+        lisa_openai_api_base=base_url, model=model_name, api_token=id_token, verify=get_cert_path(iam_client)
    )
    return embedding

25 changes: 14 additions & 11 deletions lib/serve/rest-api/src/auth.py
@@ -29,7 +29,10 @@
from starlette.status import HTTP_401_UNAUTHORIZED

# The following are field names, not passwords or tokens
-API_KEY_HEADER_NAME = "Api-Key"  # pragma: allowlist secret
+API_KEY_HEADER_NAMES = [
+    "Authorization",  # OpenAI Bearer token format, collides with IdP, but that's okay for this use case
+    "Api-Key",  # pragma: allowlist secret  # Azure key format, can be used with Continue IDE plugin
+]
TOKEN_EXPIRATION_NAME = "tokenExpiration" # nosec B105
TOKEN_TABLE_NAME = "TOKEN_TABLE_NAME" # nosec B105

@@ -134,13 +137,13 @@ def _get_token_info(self, token: str) -> Any:

    def is_valid_api_token(self, headers: Dict[str, str]) -> bool:
        """Return if API Token from request headers is valid if found."""
-        is_valid = False
-        token = headers.get(API_KEY_HEADER_NAME, None)
-        if token:
-            token_info = self._get_token_info(token)
-            if token_info:
-                token_expiration = int(token_info.get(TOKEN_EXPIRATION_NAME, datetime.max.timestamp()))
-                current_time = int(datetime.now().timestamp())
-                if current_time < token_expiration:  # token has not expired yet
-                    is_valid = True
-        return is_valid
+        for header_name in API_KEY_HEADER_NAMES:
+            token = headers.get(header_name, "").removeprefix("Bearer").strip()
+            if token:
+                token_info = self._get_token_info(token)
+                if token_info:
+                    token_expiration = int(token_info.get(TOKEN_EXPIRATION_NAME, datetime.max.timestamp()))
+                    current_time = int(datetime.now().timestamp())
+                    if current_time < token_expiration:  # token has not expired yet
+                        return True
+        return False
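The new loop above strips an optional `Bearer ` prefix before the token-table lookup. The sketch below reimplements just that extraction step in isolation (the function name and the plain-`dict` interface are illustrative, not the actual LISA code):

```python
# Illustrative reimplementation of the header scan added in auth.py (Python 3.9+ for removeprefix).
API_KEY_HEADER_NAMES = ["Authorization", "Api-Key"]

def extract_token(headers: dict) -> str:
    """Return the first token found in the accepted headers, minus any 'Bearer' prefix."""
    for header_name in API_KEY_HEADER_NAMES:
        token = headers.get(header_name, "").removeprefix("Bearer").strip()
        if token:
            return token
    return ""
```

Both `{"Authorization": "Bearer abc123"}` and `{"Api-Key": "abc123"}` yield the same token, which the real code then checks against the DynamoDB table and its expiration timestamp.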
4 changes: 2 additions & 2 deletions lisa-sdk/LISA_v2_demo.ipynb
@@ -91,9 +91,8 @@
"\n",
"# initialize OpenAI client\n",
"client = openai.OpenAI(\n",
" api_key=\"ignored\", # LISA ignores this field, but it must be defined # pragma: allowlist-secret not a real key\n",
" api_key=api_token,\n",
" base_url=lisa_serve_base_url,\n",
" default_headers={\"Api-Key\": api_token},\n",
" http_client=openai.DefaultHttpxClient(verify=cert_path), # needed for self-signed certs on your ALB, can be omitted otherwise\n",
")\n",
"\n",
@@ -266,6 +265,7 @@
"embeddings_response = client.embeddings.create(\n",
" model=embedding_model,\n",
" input=\"Hello, world!\",\n",
" encoding_format=\"float\",\n",
")\n",
"vector = embeddings_response.data[0].embedding\n",
"\n",
9 changes: 4 additions & 5 deletions lisa-sdk/lisapy/langchain.py
@@ -13,7 +13,7 @@
# limitations under the License.

"""Langchain adapter."""
-from typing import Any, cast, Dict, Iterator, List, Mapping, Optional, Union
+from typing import Any, cast, Iterator, List, Mapping, Optional, Union

from httpx import AsyncClient as HttpAsyncClient
from httpx import Client as HttpClient
@@ -108,8 +108,8 @@ class LisaOpenAIEmbeddings(BaseModel, Embeddings):
    model: str
    """Model name for Embeddings API."""

-    headers: Dict[str, str]
-    """Headers to add to model request."""
+    api_token: str
+    """API Token for communicating with LISA Serve. This can be a custom API token or the IdP Bearer token."""

    verify: Union[bool, str]
    """Cert path or option for verifying SSL"""
@@ -127,14 +127,13 @@ def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.embedding_model = OpenAIEmbeddings(
            openai_api_base=self.lisa_openai_api_base,
-            openai_api_key="ignored",  # pragma: allowlist secret
+            openai_api_key=self.api_token,
            model=self.model,
            model_kwargs={
                "encoding_format": "float",  # keep values as floats because base64 is not widely supported
            },
            http_async_client=HttpAsyncClient(verify=self.verify),
            http_client=HttpClient(verify=self.verify),
-            default_headers=self.headers,
        )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
