Commit: RAG ETL pipeline
estohlmann authored Nov 21, 2024
2 parents 271b136 + a94576e commit 1cb1b28
Showing 34 changed files with 1,344 additions and 144 deletions.
8 changes: 3 additions & 5 deletions .pre-commit-config.yaml
@@ -55,7 +55,7 @@ repos:
         name: isort (python)
 
   - repo: https://github.com/ambv/black
-    rev: '23.10.1'
+    rev: '24.10.0'
     hooks:
       - id: black
 
@@ -66,21 +66,19 @@
         args: [--exit-non-zero-on-fix]
 
   - repo: https://github.com/pycqa/flake8
-    rev: '6.1.0'
+    rev: '7.1.1'
     hooks:
       - id: flake8
         additional_dependencies:
           - flake8-docstrings
-          - flake8-broken-line
           - flake8-bugbear
           - flake8-comprehensions
           - flake8-debugger
-          - flake8-string-format
         args:
           - --docstring-convention=numpy
           - --max-line-length=120
          - --extend-immutable-calls=Query,fastapi.Depends,fastapi.params.Depends
-          - --ignore=B008  # Ignore error for function calls in argument defaults
+          - --ignore=B008,E203  # Ignore error for function calls in argument defaults
        exclude: ^(__init__.py$|.*\/__init__.py$)


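Note on the newly ignored E203: suppressing it alongside a Black upgrade is standard practice, since Black formats slices whose bounds are expressions with spaces around the colon, and flake8 reports that as E203 (whitespace before ':'). A minimal illustration (hypothetical snippet, not from this repository):

items = list(range(10))
mid = 5

# Black's preferred style for complex slice bounds; flake8 would flag
# the space before ':' as E203 unless the check is suppressed.
window = items[mid - 1 : mid + 1]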
13 changes: 13 additions & 0 deletions lambda/__init__.py
@@ -0,0 +1,13 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
12 changes: 6 additions & 6 deletions lambda/models/domain_objects.py
@@ -98,7 +98,7 @@ class AutoScalingConfig(BaseModel):
     defaultInstanceWarmup: PositiveInt
     metricConfig: MetricConfig
 
-    @model_validator(mode="after")  # type: ignore
+    @model_validator(mode="after")
     def validate_auto_scaling_config(self) -> Self:
         """Validate autoScalingConfig values."""
         if self.minCapacity > self.maxCapacity:
@@ -115,7 +115,7 @@ class AutoScalingInstanceConfig(BaseModel):
     maxCapacity: Optional[PositiveInt] = None
     desiredCapacity: Optional[PositiveInt] = None
 
-    @model_validator(mode="after")  # type: ignore
+    @model_validator(mode="after")
     def validate_auto_scaling_instance_config(self) -> Self:
         """Validate autoScalingInstanceConfig values."""
         config_fields = [self.minCapacity, self.maxCapacity, self.desiredCapacity]
@@ -155,7 +155,7 @@ class ContainerConfig(BaseModel):
     healthCheckConfig: ContainerHealthCheckConfig
     environment: Optional[Dict[str, str]] = {}
 
-    @field_validator("environment")  # type: ignore
+    @field_validator("environment")
     @classmethod
     def validate_environment(cls, environment: Dict[str, str]) -> Dict[str, str]:
         """Validate that all keys in Dict are not empty."""
@@ -201,7 +201,7 @@ class CreateModelRequest(BaseModel):
     modelUrl: Optional[str] = None
     streaming: Optional[bool] = False
 
-    @model_validator(mode="after")  # type: ignore
+    @model_validator(mode="after")
     def validate_create_model_request(self) -> Self:
         """Validate whole request object."""
         # Validate that an embedding model cannot be set as streaming-enabled
@@ -252,7 +252,7 @@ class UpdateModelRequest(BaseModel):
     modelType: Optional[ModelType] = None
     streaming: Optional[bool] = None
 
-    @model_validator(mode="after")  # type: ignore
+    @model_validator(mode="after")
     def validate_update_model_request(self) -> Self:
         """Validate whole request object."""
         fields = [
@@ -273,7 +273,7 @@ def validate_update_model_request(self) -> Self:
             raise ValueError("Embedding model cannot be set with streaming enabled.")
         return self
 
-    @field_validator("autoScalingInstanceConfig")  # type: ignore
+    @field_validator("autoScalingInstanceConfig")
     @classmethod
     def validate_autoscaling_instance_config(cls, config: AutoScalingInstanceConfig) -> AutoScalingInstanceConfig:
        """Validate that the AutoScaling instance config has at least one positive value."""
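The removed "# type: ignore" comments suggest the upgraded toolchain now type-checks pydantic v2's validator decorators cleanly, so the suppressions were dead weight. As a refresher on what an after-mode validator does, here is a minimal runnable sketch (hypothetical model, assuming pydantic v2; not code from this repository):

from pydantic import BaseModel, PositiveInt, ValidationError, model_validator
from typing_extensions import Self


class ScalingConfig(BaseModel):
    """Hypothetical stand-in for AutoScalingConfig."""

    minCapacity: PositiveInt
    maxCapacity: PositiveInt

    @model_validator(mode="after")
    def validate_bounds(self) -> Self:
        # Runs after per-field validation, so both fields are populated here.
        if self.minCapacity > self.maxCapacity:
            raise ValueError("minCapacity cannot exceed maxCapacity")
        return self


try:
    ScalingConfig(minCapacity=5, maxCapacity=2)
except ValidationError as err:
    print(err)  # pydantic surfaces the ValueError as a ValidationError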
28 changes: 14 additions & 14 deletions lambda/models/lambda_functions.py
@@ -59,32 +59,32 @@
 stepfunctions = boto3.client("stepfunctions", region_name=os.environ["AWS_REGION"], config=retry_config)
 
 
-@app.exception_handler(ModelNotFoundError)  # type: ignore
+@app.exception_handler(ModelNotFoundError)
 async def model_not_found_handler(request: Request, exc: ModelNotFoundError) -> JSONResponse:
     """Handle exception when model cannot be found and translate to a 404 error."""
     return JSONResponse(status_code=404, content={"message": str(exc)})
 
 
-@app.exception_handler(RequestValidationError)  # type: ignore
-async def validation_exception_handler(request: Request, exc: RequestValidationError):
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
     """Handle exception when request fails validation and translate to a 422 error."""
     return JSONResponse(
         status_code=422, content={"detail": jsonable_encoder(exc.errors()), "type": "RequestValidationError"}
     )
 
 
-@app.exception_handler(InvalidStateTransitionError)  # type: ignore
-@app.exception_handler(ModelAlreadyExistsError)  # type: ignore
-@app.exception_handler(ValueError)  # type: ignore
+@app.exception_handler(InvalidStateTransitionError)
+@app.exception_handler(ModelAlreadyExistsError)
+@app.exception_handler(ValueError)
 async def user_error_handler(
     request: Request, exc: Union[InvalidStateTransitionError, ModelAlreadyExistsError, ValueError]
 ) -> JSONResponse:
     """Handle errors when customer requests options that cannot be processed."""
     return JSONResponse(status_code=400, content={"message": str(exc)})
 
 
-@app.post(path="", include_in_schema=False)  # type: ignore
-@app.post(path="/")  # type: ignore
+@app.post(path="", include_in_schema=False)
+@app.post(path="/")
 async def create_model(create_request: CreateModelRequest) -> CreateModelResponse:
     """Endpoint to create a model."""
     create_handler = CreateModelHandler(
@@ -95,8 +95,8 @@ async def create_model(create_request: CreateModelRequest) -> CreateModelResponse:
     return create_handler(create_request=create_request)
 
 
-@app.get(path="", include_in_schema=False)  # type: ignore
-@app.get(path="/")  # type: ignore
+@app.get(path="", include_in_schema=False)
+@app.get(path="/")
 async def list_models() -> ListModelsResponse:
     """Endpoint to list models."""
     list_handler = ListModelsHandler(
@@ -107,7 +107,7 @@ async def list_models() -> ListModelsResponse:
     return list_handler()
 
 
-@app.get(path="/{model_id}")  # type: ignore
+@app.get(path="/{model_id}")
 async def get_model(
     model_id: Annotated[str, Path(title="The unique model ID of the model to get")], request: Request
 ) -> GetModelResponse:
@@ -120,7 +120,7 @@ async def get_model(
     return get_handler(model_id=model_id)
 
 
-@app.put(path="/{model_id}")  # type: ignore
+@app.put(path="/{model_id}")
 async def update_model(
     model_id: Annotated[str, Path(title="The unique model ID of the model to update")],
     update_request: UpdateModelRequest,
@@ -134,7 +134,7 @@ async def update_model(
     return update_handler(model_id=model_id, update_request=update_request)
 
 
-@app.delete(path="/{model_id}")  # type: ignore
+@app.delete(path="/{model_id}")
 async def delete_model(
     model_id: Annotated[str, Path(title="The unique model ID of the model to delete")], request: Request
 ) -> DeleteModelResponse:
@@ -147,7 +147,7 @@ async def delete_model(
     return delete_handler(model_id=model_id)
 
 
-@app.get(path="/metadata/instances")  # type: ignore
+@app.get(path="/metadata/instances")
 async def get_instances() -> list[str]:
     """Endpoint to list available instances in this region."""
     return list(sess.get_service_model("ec2").shape_for("InstanceType").enum)
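Beyond the same decorator cleanup, two patterns in this file are worth noting: each collection route is registered under both path="" and path="/" so requests resolve with or without a trailing slash, and @app.exception_handler maps domain exceptions to HTTP status codes in one place. A minimal sketch of the handler pattern (illustrative names, assuming only FastAPI; not code from this repository):

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()


class ModelNotFoundError(Exception):
    """Illustrative stand-in for the domain error handled above."""


@app.exception_handler(ModelNotFoundError)
async def model_not_found_handler(request: Request, exc: ModelNotFoundError) -> JSONResponse:
    # Any route that raises ModelNotFoundError produces this 404 response.
    return JSONResponse(status_code=404, content={"message": str(exc)})


@app.get("/{model_id}")
async def get_model(model_id: str) -> dict:
    raise ModelNotFoundError(f"Model '{model_id}' was not found")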
126 changes: 122 additions & 4 deletions lambda/repository/lambda_functions.py
@@ -19,17 +19,18 @@
 from typing import Any, Dict, List
 
 import boto3
 import create_env_variables  # noqa: F401
+import requests
 from botocore.config import Config
 from lisapy.langchain import LisaOpenAIEmbeddings
-from lisapy.utils import get_cert_path
-from utilities.common_functions import api_wrapper, get_id_token, retry_config
+from utilities.common_functions import api_wrapper, get_cert_path, get_id_token, retry_config
 from utilities.file_processing import process_record
 from utilities.validation import validate_model_name, ValidationError
 from utilities.vector_store import get_vector_store_client
 
 logger = logging.getLogger(__name__)
 session = boto3.Session()
 ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config)
+secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config)
 iam_client = boto3.client("iam", region_name=os.environ["AWS_REGION"], config=retry_config)
 s3 = session.client(
     "s3",
@@ -54,12 +54,129 @@ def _get_embeddings(model_name: str, id_token: str) -> LisaOpenAIEmbeddings:
     lisa_api_endpoint = lisa_api_param_response["Parameter"]["Value"]
 
     base_url = f"{lisa_api_endpoint}/{os.environ['REST_API_VERSION']}/serve"
+    cert_path = get_cert_path(iam_client)
 
     embedding = LisaOpenAIEmbeddings(
-        lisa_openai_api_base=base_url, model=model_name, api_token=id_token, verify=get_cert_path(iam_client)
+        lisa_openai_api_base=base_url, model=model_name, api_token=id_token, verify=cert_path
     )
     return embedding
 
 
+# Create embeddings client that matches LisaOpenAIEmbeddings interface
+class PipelineEmbeddings:
+    def __init__(self) -> None:
+        try:
+            # Get the management key secret name from SSM Parameter Store
+            secret_name_param = ssm_client.get_parameter(Name=os.environ["MANAGEMENT_KEY_SECRET_NAME_PS"])
+            secret_name = secret_name_param["Parameter"]["Value"]
+
+            # Get the management token from Secrets Manager using the secret name
+            secret_response = secrets_client.get_secret_value(SecretId=secret_name)
+            self.token = secret_response["SecretString"]
+
+            # Get the API endpoint from SSM
+            lisa_api_param_response = ssm_client.get_parameter(Name=os.environ["LISA_API_URL_PS_NAME"])
+            self.base_url = f"{lisa_api_param_response['Parameter']['Value']}/{os.environ['REST_API_VERSION']}/serve"
+
+            # Get certificate path for SSL verification
+            self.cert_path = get_cert_path(iam_client)
+
+            logger.info("Successfully initialized pipeline embeddings")
+        except Exception:
+            logger.error("Failed to initialize pipeline embeddings", exc_info=True)
+            raise
+
+    def embed_documents(self, texts: List[str], model_name: str) -> List[List[float]]:
+        if not texts:
+            raise ValidationError("No texts provided for embedding")
+
+        logger.info(f"Embedding {len(texts)} documents")
+        try:
+            url = f"{self.base_url}/embeddings"
+            request_data = {"input": texts, "model": model_name}
+
+            response = requests.post(
+                url,
+                json=request_data,
+                headers={"Authorization": self.token, "Content-Type": "application/json"},
+                verify=self.cert_path,  # Use proper SSL verification
+                timeout=300,  # 5 minute timeout
+            )
+
+            if response.status_code != 200:
+                logger.error(f"Embedding request failed with status {response.status_code}")
+                logger.error(f"Response content: {response.text}")
+                raise Exception(f"Embedding request failed with status {response.status_code}")
+
+            result = response.json()
+            logger.debug(f"API Response: {result}")  # Log the full response for debugging
+
+            # Handle different response formats
+            embeddings = []
+            if isinstance(result, dict):
+                if "data" in result:
+                    # OpenAI-style format
+                    for item in result["data"]:
+                        if isinstance(item, dict) and "embedding" in item:
+                            embeddings.append(item["embedding"])
+                        else:
+                            embeddings.append(item)  # Assume the item itself is the embedding
+                else:
+                    # Try to find embeddings in the response
+                    for key in ["embeddings", "embedding", "vectors", "vector"]:
+                        if key in result:
+                            embeddings = result[key]
+                            break
+            elif isinstance(result, list):
+                # Direct list format
+                embeddings = result
+
+            if not embeddings:
+                logger.error(f"Could not find embeddings in response: {result}")
+                raise Exception("No embeddings found in API response")
+
+            if len(embeddings) != len(texts):
+                logger.error(f"Mismatch between number of texts ({len(texts)}) and embeddings ({len(embeddings)})")
+                raise Exception("Number of embeddings does not match number of input texts")
+
+            logger.info(f"Successfully embedded {len(texts)} documents")
+            return embeddings
+
+        except requests.Timeout:
+            logger.error("Embedding request timed out")
+            raise Exception("Embedding request timed out after 5 minutes")
+        except requests.RequestException as e:
+            logger.error(f"Request failed: {str(e)}", exc_info=True)
+            raise
+        except Exception as e:
+            logger.error(f"Failed to get embeddings: {str(e)}", exc_info=True)
+            raise
+
+    def embed_query(self, text: str, model_name: str) -> List[float]:
+        if not text or not isinstance(text, str):
+            raise ValidationError("Invalid query text")
+
+        logger.info("Embedding single query text")
+        return self.embed_documents([text], model_name)[0]
+
+
+def _get_embeddings_pipeline(model_name: str) -> Any:
+    """
+    Get embeddings for pipeline requests using management token.
+
+    Args:
+        model_name: Name of the embedding model to use
+
+    Raises:
+        ValidationError: If model name is invalid
+        Exception: If API request fails
+    """
+    logger.info("Starting pipeline embeddings request")
+    validate_model_name(model_name)
+
+    return PipelineEmbeddings()
 
 
 @api_wrapper
 def list_all(event: dict, context: dict) -> List[Dict[str, Any]]:
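How the new pipeline path is meant to be consumed, as a sketch only: it assumes a deployed LISA stack plus the SSM and Secrets Manager values read in __init__, and the model name below is an invented placeholder, so it will not run standalone. Unlike _get_embeddings, which requires a per-user ID token, _get_embeddings_pipeline returns a client authenticated with the management token, letting scheduled RAG ETL jobs embed documents with no user in the loop:

# Illustrative usage of the pipeline embeddings client defined above.
embeddings = _get_embeddings_pipeline(model_name="my-embedding-model")  # placeholder model name

chunks = ["first document chunk", "second document chunk"]
vectors = embeddings.embed_documents(chunks, model_name="my-embedding-model")
assert len(vectors) == len(chunks)  # embed_documents enforces this 1:1 mapping

query_vector = embeddings.embed_query("What is in the first chunk?", model_name="my-embedding-model")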
(Diff truncated here; the remaining 29 of the 34 changed files are not shown.)
