diff --git a/lambda/models/domain_objects.py b/lambda/models/domain_objects.py
index 8aa8c048..830ce2d0 100644
--- a/lambda/models/domain_objects.py
+++ b/lambda/models/domain_objects.py
@@ -14,10 +14,12 @@
 """Domain objects for interacting with the model endpoints."""
 
+import time
+import uuid
 from enum import Enum
-from typing import Annotated, Dict, List, Optional, Union
+from typing import Annotated, Any, Dict, List, Optional, Union
 
-from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt
+from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt
 from pydantic.functional_validators import AfterValidator, field_validator, model_validator
 from typing_extensions import Self
 from utilities.validators import validate_all_fields_defined, validate_any_fields_defined, validate_instance_type
@@ -292,3 +294,32 @@ class DeleteModelResponse(ApiResponseBase):
     """Response object when deleting a model."""
 
     pass
+
+
+class IngestionType(Enum):
+    AUTO = "auto"
+    MANUAL = "manual"
+
+
+class RagDocument(BaseModel):
+    """Rag Document Entity for storing in DynamoDB."""
+
+    pk: Optional[str] = None
+    document_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    repository_id: str
+    collection_id: str
+    document_name: str
+    source: str
+    sub_docs: List[str] = Field(default_factory=lambda: [])
+    ingestion_type: IngestionType = Field(default_factory=lambda: IngestionType.MANUAL)
+    upload_date: int = Field(default_factory=lambda: int(time.time()))
+
+    model_config = ConfigDict(use_enum_values=True, validate_default=True)
+
+    def __init__(self, **data: Any) -> None:
+        super().__init__(**data)
+        self.pk = self.createPartitionKey(self.repository_id, self.collection_id)
+
+    @staticmethod
+    def createPartitionKey(repository_id: str, collection_id: str) -> str:
+        return f"{repository_id}#{collection_id}"
diff --git a/lambda/repository/lambda_functions.py b/lambda/repository/lambda_functions.py
index 016e80dd..58e7549b 100644
--- a/lambda/repository/lambda_functions.py
+++ b/lambda/repository/lambda_functions.py
@@ -22,6 +22,9 @@
 import requests
 from botocore.config import Config
 from lisapy.langchain import LisaOpenAIEmbeddings
+from models.domain_objects import IngestionType, RagDocument
+from repository.rag_document_repo import RagDocumentRepository
+from utilities.validation import ValidationError
 from utilities.common_functions import api_wrapper, get_cert_path, get_id_token, retry_config
 from utilities.exceptions import HTTPException
 from utilities.file_processing import process_record
@@ -29,13 +31,14 @@
 from utilities.vector_store import find_repository_by_id, get_registered_repositories, get_vector_store_client
 
 logger = logging.getLogger(__name__)
+region_name = os.environ["AWS_REGION"]
 session = boto3.Session()
-ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config)
-secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config)
-iam_client = boto3.client("iam", region_name=os.environ["AWS_REGION"], config=retry_config)
+ssm_client = boto3.client("ssm", region_name, config=retry_config)
+secrets_client = boto3.client("secretsmanager", region_name, config=retry_config)
+iam_client = boto3.client("iam", region_name, config=retry_config)
 s3 = session.client(
     "s3",
-    region_name=os.environ["AWS_REGION"],
+    region_name,
     config=Config(
         retries={
             "max_attempts": 3,
@@ -45,6 +48,8 @@
     ),
 )
 lisa_api_endpoint = ""
+registered_repositories: List[Dict[str, Any]] = []
+doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"])
 
 
 def _get_embeddings(model_name: str, id_token: str) -> LisaOpenAIEmbeddings:
@@ -183,7 +188,7 @@ def _get_embeddings_pipeline(model_name: str) -> Any:
 def list_all(event: dict, context: dict) -> List[Dict[str, Any]]:
     """Return info on all available repositories.
 
-    Currently there is not support for dynamic repositories so only a single OpenSearch repository
+    Currently, there is no support for dynamic repositories so only a single OpenSearch repository
     is returned.
     """
@@ -212,9 +217,22 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]:
     """Return documents matching the query.
 
     Conducts similarity search against the vector store returning the top K
-    documents based on the specified query. 'topK' can be set as an optional
-    querystring parameter, if it is not specified the top 3 documents will be
-    returned.
+    documents based on the specified query.
+
+    Args:
+        event (dict): The Lambda event object containing:
+            - queryStringParameters.modelName: Name of the embedding model
+            - queryStringParameters.query: Search query text
+            - queryStringParameters.repositoryType: Type of repository
+            - queryStringParameters.topK (optional): Number of results to return (default: 3)
+        context (dict): The Lambda context object
+
+    Returns:
+        Dict[str, Any]: A dictionary containing:
+            - docs: List of matching documents with their content and metadata
+
+    Raises:
+        ValidationError: If required parameters are missing or invalid
     """
     query_string_params = event["queryStringParameters"]
     model_name = query_string_params["modelName"]
@@ -240,7 +258,7 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]:
     return doc_return
 
 
-def ensure_repository_access(event, repository):
+def ensure_repository_access(event: dict[str, Any], repository: dict[str, Any]) -> None:
     "Ensures a user has access to the repository or else raises an HTTPException"
     user_groups = json.loads(event["requestContext"]["authorizer"]["groups"]) or []
     if not user_has_group(user_groups, repository["allowedGroups"]):
@@ -248,56 +266,158 @@
 
 
 @api_wrapper
-def purge_document(event: dict, context: dict) -> Dict[str, Any]:
-    """Purge all records related to the specified document from the RAG repository."""
-    user_id = event["requestContext"]["authorizer"]["username"]
-    repository_id = event["pathParameters"]["repositoryId"]
-    document_id = event["pathParameters"]["sessionId"]
+def delete_document(event: dict, context: dict) -> Dict[str, Any]:
+    """Purge all records related to the specified document from the RAG repository. If a documentId is supplied, a
+    single document will be removed. If a documentName is supplied, all documents with that name will be removed.
 
-    logger.info(
-        f"Purging records associated with document {document_id} "
-        f"(requesting user: {user_id}), repository: {repository_id}"
-    )
+    Args:
+        event (dict): The Lambda event object containing:
+            - pathParameters.repositoryId: The repository id of the VectorStore
+            - queryStringParameters.collectionId: The collection identifier
+            - queryStringParameters.repositoryType: Type of repository of the VectorStore
+            - queryStringParameters.documentId (optional): Id of the document to purge
+            - queryStringParameters.documentName (optional): Name of the document(s) to purge
+        context (dict): The Lambda context object
+
+    Returns:
+        Dict[str, Any]: A dictionary containing:
+            - documentName (str): Name of the purged document
+            - removedDocuments (int): Number of documents removed
+            - removedDocumentChunks (int): Number of chunks removed from the VectorStore
+
+    Raises:
+        ValidationError: If neither or both of documentId and documentName are supplied
+        ValueError: If no document is found in the repository collection
+    """
+    path_params = event.get("pathParameters", {})
+    repository_id = path_params.get("repositoryId")
+
+    query_string_params = event["queryStringParameters"]
+    collection_id = query_string_params["collectionId"]
+    document_id = query_string_params.get("documentId")
+    document_name = query_string_params.get("documentName")
 
-    return {"documentId": document_id, "recordsPurged": 0}
+    if not document_id and not document_name:
+        raise ValidationError("Either documentId or documentName must be specified")
+    if document_id and document_name:
+        raise ValidationError("Only one of documentId or documentName must be specified")
+
+    docs = []
+    if document_id:
+        docs = [doc_repo.find_by_id(document_id)]
+    elif document_name:
+        docs = doc_repo.find_by_name(repository_id, collection_id, document_name)
+
+    if not docs:
+        raise ValueError(f"No documents found in repository collection {repository_id}:{collection_id}")
+
+    # Grab all sub document ids related to the parent document(s)
+    subdoc_ids = [sub_doc for doc in docs for sub_doc in doc.get("sub_docs", [])]
+
+    id_token = get_id_token(event)
+    embeddings = _get_embeddings(model_name=collection_id, id_token=id_token)
+    vs = get_vector_store_client(repository_id=repository_id, index=collection_id, embeddings=embeddings)
+
+    vs.delete(ids=subdoc_ids)
+
+    doc_repo.batch_delete(docs)
+
+    return {
+        "documentName": docs[0].get("document_name"),
+        "removedDocuments": len(docs),
+        "removedDocumentChunks": len(subdoc_ids),
+    }
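The delete flow above is exposed as `DELETE repository/{repositoryId}/document` (wired up in `lib/rag/api/repository.ts` below). A minimal client-side sketch; the base URL, token, and ids are placeholders, and the exact auth header depends on the deployment:

```python
import requests

api_base = "https://example.execute-api.us-east-1.amazonaws.com/prod/v2"  # placeholder
id_token = "<oidc-id-token>"  # placeholder

# Exactly one of documentId / documentName may be supplied (the handler enforces this).
resp = requests.delete(
    f"{api_base}/repository/my-repo/document",
    params={"collectionId": "my-embedding-model", "documentId": "1234"},
    headers={"Authorization": f"Bearer {id_token}"},
)
resp.raise_for_status()
print(resp.json())  # {"documentName": ..., "removedDocuments": ..., "removedDocumentChunks": ...}
```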
 
 
 @api_wrapper
 def ingest_documents(event: dict, context: dict) -> dict:
-    """Ingest a set of documents into the specified repository."""
+    """Ingest documents into the RAG repository.
+
+    Args:
+        event (dict): The Lambda event object containing:
+            - body.embeddingModel.modelName: Document collection id
+            - body.keys: List of s3 keys to ingest
+            - pathParameters.repositoryId: Repository id (VectorStore)
+            - queryStringParameters.repositoryType: Repository type (VectorStore)
+            - queryStringParameters.chunkSize (optional): Size of text chunks
+            - queryStringParameters.chunkOverlap (optional): Overlap between chunks
+        context (dict): The Lambda context object
+
+    Returns:
+        dict: A dictionary containing:
+            - ids (list): List of generated document chunk IDs
+            - count (int): Total number of document chunks ingested
+
+    Raises:
+        ValidationError: If required parameters are missing or invalid
+    """
     body = json.loads(event["body"])
     embedding_model = body["embeddingModel"]
     model_name = embedding_model["modelName"]
 
+    path_params = event.get("pathParameters", {})
+    repository_id = path_params.get("repositoryId")
+
     query_string_params = event["queryStringParameters"]
     chunk_size = int(query_string_params["chunkSize"]) if "chunkSize" in query_string_params else None
     chunk_overlap = int(query_string_params["chunkOverlap"]) if "chunkOverlap" in query_string_params else None
-    repository_id = event["pathParameters"]["repositoryId"]
 
     logger.info(f"using repository {repository_id}")
 
     repository = find_repository_by_id(repository_id)
     ensure_repository_access(event, repository)
 
-    docs = process_record(s3_keys=body["keys"], chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+    keys = body["keys"]
+    docs = process_record(s3_keys=keys, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
 
-    texts = []  # list of strings
-    metadatas = []  # list of dicts
+    all_ids = []
+    id_token = get_id_token(event)
+    embeddings = _get_embeddings(model_name=model_name, id_token=id_token)
+    vs = get_vector_store_client(repository_id, index=model_name, embeddings=embeddings)
 
+    # Batch document ingestion one parent document at a time
     for doc_list in docs:
+        document_name = doc_list[0].metadata.get("name")
+        doc_source = doc_list[0].metadata.get("source")
+        # Reset per parent document so earlier chunks are not re-ingested
+        texts = []  # list of strings
+        metadatas = []  # list of dicts
         for doc in doc_list:
             texts.append(doc.page_content)
             metadatas.append(doc.metadata)
+        # Ingest document into vector store
+        ids = vs.add_texts(texts=texts, metadatas=metadatas)
 
-    id_token = get_id_token(event)
-    embeddings = _get_embeddings(model_name=model_name, id_token=id_token)
-    vs = get_vector_store_client(repository_id, index=model_name, embeddings=embeddings)
-    ids = vs.add_texts(texts=texts, metadatas=metadatas)
-    return {"ids": ids, "count": len(ids)}
+        # Add document to RagDocTable
+        doc_entity = RagDocument(
+            repository_id=repository_id,
+            collection_id=model_name,
+            document_name=document_name,
+            source=doc_source,
+            sub_docs=ids,
+            ingestion_type=IngestionType.MANUAL,
+        )
+        doc_repo.save(doc_entity)
+
+        all_ids.extend(ids)
+
+    return {"ids": all_ids, "count": len(all_ids)}
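The handler above persists one `RagDocument` per parent document. A quick sketch of the entity's derived fields (all values here are made up):

```python
from models.domain_objects import IngestionType, RagDocument

doc = RagDocument(
    repository_id="my-repo",
    collection_id="my-embedding-model",
    document_name="handbook.pdf",
    source="s3://rag-ingest-bucket/handbook.pdf",
    sub_docs=["chunk-id-1", "chunk-id-2"],
    ingestion_type=IngestionType.MANUAL,
)
# pk is set in __init__ via createPartitionKey; document_id and upload_date default.
assert doc.pk == "my-repo#my-embedding-model"
item = doc.model_dump()  # ingestion_type serialized as "manual" (use_enum_values=True)
```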
 
 
 @api_wrapper
 def presigned_url(event: dict, context: dict) -> dict:
-    """Generate a pre-signed URL for uploading files to the RAG ingest bucket."""
+    """Generate a pre-signed URL for uploading files to the RAG ingest bucket.
+
+    Args:
+        event (dict): The Lambda event object containing:
+            - body: The key for the file
+            - requestContext.authorizer.username: The authenticated username
+        context (dict): The Lambda context object
+
+    Returns:
+        dict: A dictionary containing:
+            - response: The presigned URL response object with upload fields and URL
+
+    Notes:
+        - URL expires in 3600 seconds (1 hour)
+        - Maximum file size is 52428800 bytes (50MB)
+    """
     response = ""
     key = event["body"]
 
@@ -325,3 +445,31 @@ def get_groups(event: Any) -> List[str]:
     groups: List[str] = json.loads(event["requestContext"]["authorizer"]["groups"])
     return groups
+
+
+@api_wrapper
+def list_docs(event: dict, context: dict) -> List[RagDocument]:
+    """List all documents for a given repository/collection.
+
+    Args:
+        event (dict): The Lambda event object containing:
+            - pathParameters.repositoryId: The repository id to list documents for
+            - queryStringParameters.collectionId: The collection id to list documents for
+        context (dict): The Lambda context object
+
+    Returns:
+        List[RagDocument]: A list of RagDocument objects representing all documents
+            in the specified collection
+    """
+    path_params = event.get("pathParameters", {})
+    repository_id = path_params.get("repositoryId")
+
+    query_string_params = event.get("queryStringParameters", {})
+    collection_id = query_string_params.get("collectionId")
+
+    docs: List[RagDocument] = doc_repo.list_all(repository_id, collection_id)
+    return docs
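For completeness, a sketch of consuming the `presigned_url` response above. The response shape (`url` plus form `fields`) and the wrapping under a `response` key are assumed from the docstring; endpoint path from `repository.ts` below:

```python
import requests

api_base = "https://example.execute-api.us-east-1.amazonaws.com/prod/v2"  # placeholder
id_token = "<oidc-id-token>"  # placeholder

# Request a presigned POST for a file key; the handler reads the raw body as the key.
presigned = requests.post(
    f"{api_base}/repository/presigned-url",
    data="uploads/handbook.pdf",
    headers={"Authorization": f"Bearer {id_token}"},
).json()["response"]

# Upload directly to S3 with the returned url and form fields (assumed presigned-POST shape).
with open("handbook.pdf", "rb") as f:
    upload = requests.post(presigned["url"], data=presigned["fields"], files={"file": ("handbook.pdf", f)})
upload.raise_for_status()
```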
diff --git a/lambda/repository/pipeline_ingest_documents.py b/lambda/repository/pipeline_ingest_documents.py
index 45ae5a5d..919b0e09 100644
--- a/lambda/repository/pipeline_ingest_documents.py
+++ b/lambda/repository/pipeline_ingest_documents.py
@@ -19,17 +19,20 @@
 from typing import Any, Dict, List
 
 import boto3
+from repository.lambda_functions import RagDocumentRepository
 from utilities.common_functions import retry_config
 from utilities.file_processing import process_record
 from utilities.validation import validate_chunk_params, validate_model_name, validate_repository_type, ValidationError
 from utilities.vector_store import get_vector_store_client
 
-from .lambda_functions import _get_embeddings_pipeline
+from .lambda_functions import _get_embeddings_pipeline, IngestionType, RagDocument
 
 logger = logging.getLogger(__name__)
 session = boto3.Session()
 ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config)
+doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"])
+
 
 def batch_texts(texts: List[str], metadatas: List[Dict], batch_size: int = 500) -> list[tuple[list[str], list[dict]]]:
     """
@@ -110,6 +113,7 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
     # Prepare texts and metadata
     texts = []
     metadatas = []
+
     for doc_list in docs:
         for doc in doc_list:
             texts.append(doc.page_content)
@@ -147,6 +151,17 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
 
     logger.info(f"Successfully processed {len(all_ids)} chunks from {s3_key} for repository {repository_id}")
 
+    # Store RagDocument entry in Document Table
+    doc_entity = RagDocument(
+        repository_id=repository_id,
+        collection_id=embedding_model,
+        document_name=key,
+        source=docs[0][0].metadata.get("source"),
+        sub_docs=all_ids,
+        ingestion_type=IngestionType.AUTO,
+    )
+    doc_repo.save(doc_entity)
+
     return {
         "message": f"Successfully processed document {s3_key}",
         "repository_id": repository_id,
diff --git a/lambda/repository/rag_document_repo.py b/lambda/repository/rag_document_repo.py
new file mode 100644
index 00000000..1e87cc0a
--- /dev/null
+++ b/lambda/repository/rag_document_repo.py
@@ -0,0 +1,192 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List
+
+import boto3
+from boto3.dynamodb.conditions import Attr, Key
+from botocore.exceptions import ClientError
+from models.domain_objects import RagDocument
+
+logger = logging.getLogger(__name__)
+
+
+class RagDocumentRepository:
+    """RAG Document repository for DynamoDB."""
+
+    def __init__(self, table_name: str):
+        self.dynamodb = boto3.resource("dynamodb")
+        self.table = self.dynamodb.Table(table_name)
+
+    def delete(self, pk: str, document_id: str) -> None:
+        """Delete a document using partition key and sort key.
+
+        Args:
+            pk: Partition key value
+            document_id: Sort key value
+
+        Raises:
+            ClientError: If deletion fails
+        """
+        try:
+            self.table.delete_item(Key={"pk": pk, "document_id": document_id})
+        except ClientError as e:
+            logger.error(f"Error deleting document: {e.response['Error']['Message']}")
+            raise
+
+    def batch_delete(self, items: List[Dict[str, str]]) -> None:
+        """Delete multiple documents in a batch.
+
+        Args:
+            items: List of dictionaries containing pk and document_id pairs
+
+        Raises:
+            ClientError: If batch deletion fails
+        """
+        try:
+            with self.table.batch_writer() as batch:
+                for item in items:
+                    batch.delete_item(Key={"pk": item["pk"], "document_id": item["document_id"]})
+        except ClientError as e:
+            logger.error(f"Error in batch deletion: {e.response['Error']['Message']}")
+            raise
+
+    def save(self, document: RagDocument) -> Dict:
+        """Save a document to DynamoDB.
+
+        Args:
+            document: RagDocument to persist
+
+        Returns:
+            Dict containing the response from DynamoDB
+
+        Raises:
+            ClientError: If save operation fails
+        """
+        try:
+            response = self.table.put_item(Item=document.model_dump())
+            return response
+        except ClientError as e:
+            logger.error(f"Error saving document: {e.response['Error']['Message']}")
+            raise
+
+    def batch_save(self, documents: List[RagDocument]) -> None:
+        """Save multiple documents in a batch.
+
+        Args:
+            documents: List of RagDocuments to persist
+
+        Raises:
+            ClientError: If batch save operation fails
+        """
+        try:
+            with self.table.batch_writer() as batch:
+                for doc in documents:
+                    batch.put_item(Item=doc.model_dump())
+        except ClientError as e:
+            logger.error(f"Error in batch save: {e.response['Error']['Message']}")
+            raise
+
+    def find_by_id(self, document_id: str) -> RagDocument:
+        """Query a single document using the document_index GSI.
+
+        Args:
+            document_id: Document ID to query
+
+        Returns:
+            The matching document
+
+        Raises:
+            KeyError: If no document is found for the given id
+            ValueError: If multiple documents are found for the given id
+            ClientError: If query operation fails
+        """
+        try:
+            response = self.table.query(
+                IndexName="document_index",
+                KeyConditionExpression="document_id = :document_id",
+                ExpressionAttributeValues={":document_id": document_id},
+            )
+            docs = response.get("Items")
+            if not docs:
+                raise KeyError(f"Document not found for document_id {document_id}")
+            if len(docs) > 1:
+                raise ValueError(f"Multiple items found for document_id {document_id}")
+
+            logger.info(docs[0])
+
+            return docs[0]
+        except ClientError as e:
+            logger.error(f"Error querying document: {e.response['Error']['Message']}")
+            raise
+
+    def find_by_name(self, repository_id: str, collection_id: str, document_name: str) -> list[RagDocument]:
+        """Get a list of documents from the RagDocTable by name.
+
+        Args:
+            repository_id (str): The repository id to list documents for
+            collection_id (str): The collection id to list documents for
+            document_name (str): The name of the documents to retrieve
+
+        Returns:
+            list[RagDocument]: A list of document objects matching the specified name
+        """
+        pk = RagDocument.createPartitionKey(repository_id, collection_id)
+        response = self.table.query(
+            KeyConditionExpression=Key("pk").eq(pk), FilterExpression=Attr("document_name").eq(document_name)
+        )
+        docs: list[RagDocument] = response["Items"]
+
+        # Handle paginated Dynamo results
+        while "LastEvaluatedKey" in response:
+            response = self.table.query(
+                KeyConditionExpression=Key("pk").eq(pk),
+                FilterExpression=Attr("document_name").eq(document_name),
+                ExclusiveStartKey=response["LastEvaluatedKey"],
+            )
+            docs.extend(response["Items"])
+
+        return docs
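The query-and-paginate loop in `find_by_name` recurs in `list_all` below; if a third copy appears, the `LastEvaluatedKey` handling could be folded into one helper. A sketch, not part of the patch:

```python
from typing import Any, Dict, List


def query_all_pages(table: Any, **query_kwargs: Any) -> List[Dict[str, Any]]:
    """Run table.query repeatedly, following LastEvaluatedKey until exhausted."""
    items: List[Dict[str, Any]] = []
    while True:
        response = table.query(**query_kwargs)
        items.extend(response["Items"])
        if "LastEvaluatedKey" not in response:
            return items
        query_kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"]
```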
+
+    def list_all(self, repository_id: str, collection_id: str) -> List[RagDocument]:
+        """List all documents in a collection.
+
+        Args:
+            repository_id: Repository ID
+            collection_id: Collection ID
+
+        Returns:
+            List of documents
+        """
+        pk = RagDocument.createPartitionKey(repository_id, collection_id)
+        response = self.table.query(
+            KeyConditionExpression=Key("pk").eq(pk),
+        )
+        docs: List[RagDocument] = response["Items"]
+
+        # Handle paginated Dynamo results
+        while "LastEvaluatedKey" in response:
+            response = self.table.query(
+                KeyConditionExpression=Key("pk").eq(pk),
+                ExclusiveStartKey=response["LastEvaluatedKey"],
+            )
+            docs.extend(response["Items"])
+
+        return docs
diff --git a/lambda/repository/state_machine/pipeline_ingest_documents.py b/lambda/repository/state_machine/pipeline_ingest_documents.py
index ec681144..ee1cd6f4 100644
--- a/lambda/repository/state_machine/pipeline_ingest_documents.py
+++ b/lambda/repository/state_machine/pipeline_ingest_documents.py
@@ -18,7 +18,11 @@
 import boto3
 from models.document_processor import DocumentProcessor
+from models.domain_objects import IngestionType, RagDocument
 from models.vectorstore import VectorStore
+from repository.lambda_functions import RagDocumentRepository
+
+doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"])
 
 
 def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
@@ -42,6 +46,7 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
     chunk_overlap = int(os.environ["CHUNK_OVERLAP"])
     embedding_model = os.environ["EMBEDDING_MODEL"]
     collection_name = os.environ["COLLECTION_NAME"]
+    repository_id = os.environ["REPOSITORY_ID"]
 
     # Initialize document processor and vectorstore
     doc_processor = DocumentProcessor()
@@ -54,9 +59,21 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
     # Chunk document
     chunks = doc_processor.chunk_text(text=content, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+    source = f"s3://{bucket}/{key}"
 
     # Store chunks in vectorstore
-    vectorstore.add_texts(texts=chunks, metadata={"source": f"s3://{bucket}/{key}"})
+    ids = vectorstore.add_texts(texts=chunks, metadata={"source": source})
+
+    # Store in DocTable
+    doc_entity = RagDocument(
+        repository_id=repository_id,
+        collection_id=collection_name,
+        document_name=key,
+        source=source,
+        sub_docs=ids,
+        ingestion_type=IngestionType.AUTO,
+    )
+    doc_repo.save(doc_entity)
 
     return {
         "statusCode": 200,
diff --git a/lambda/utilities/file_processing.py b/lambda/utilities/file_processing.py
index f82640c2..98c5776b 100644
--- a/lambda/utilities/file_processing.py
+++ b/lambda/utilities/file_processing.py
@@ -33,8 +33,8 @@
 s3 = session.client("s3")
 
 
-def _get_metadata(s3_uri: str) -> dict:
-    return {"source": s3_uri}
+def _get_metadata(s3_uri: str, name: str) -> dict:
+    return {"source": s3_uri, "name": name}
 
 
 def _get_s3_uri(bucket: str, key: str) -> str:
@@ -153,7 +153,11 @@ def process_record(
             raise e
         s3_uri = _get_s3_uri(bucket=bucket, key=key)
         extracted_text = _extract_text_by_content_type(content_type=content_type, s3_object=s3_object)
-        docs = [Document(page_content=extracted_text, metadata=_get_metadata(s3_uri=s3_uri))]
-        chunks.append(_generate_chunks(docs, chunk_size=chunk_size, chunk_overlap=chunk_overlap))
+        docs = [Document(page_content=extracted_text, metadata=_get_metadata(s3_uri=s3_uri, name=key))]
+        doc_chunks = _generate_chunks(docs, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+        # Update part number of doc metadata
+        for i, doc in enumerate(doc_chunks):
+            doc.metadata["part"] = i + 1
+        chunks.append(doc_chunks)
 
     return chunks
diff --git a/lambda/utilities/vector_store.py b/lambda/utilities/vector_store.py
index cd830747..f5d4b21e 100644
--- a/lambda/utilities/vector_store.py
+++ b/lambda/utilities/vector_store.py
@@ -16,7 +16,7 @@
 import json
 import logging
 import os
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
 import boto3
 import create_env_variables  # noqa: F401
@@ -36,7 +36,7 @@
 registered_repositories: List[Dict[str, Any]] = []
 
 
-def get_registered_repositories() -> List[Any]:
+def get_registered_repositories() -> List[dict]:
     """Get a list of all registered RAG repositories."""
     global registered_repositories
     if not registered_repositories:
@@ -46,12 +46,16 @@
     return registered_repositories
 
 
-def find_repository_by_id(repository_id: str) -> Optional[Dict[str, Any]]:
+def find_repository_by_id(repository_id: str) -> Dict[str, Any]:
     """Find a RAG repository by id."""
-    return next(
+    repository = next(
         (repository for repository in get_registered_repositories() if repository["repositoryId"] == repository_id),
         None,
     )
+    if repository is None:
+        raise ValueError(f"Repository with ID '{repository_id}' not found")
+
+    return repository
 
 
 def get_vector_store_client(repository_id: str, index: str, embeddings: Embeddings) -> VectorStore:
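Before moving to the CDK side, a sketch of a moto-backed round-trip test for the new repository class. The key schema mirrors the `docTable` definition in `lib/rag/index.ts` below; moto, the region, and the table name are test-only assumptions:

```python
import os

import boto3
from moto import mock_aws

os.environ.setdefault("AWS_DEFAULT_REGION", "us-east-1")

from models.domain_objects import RagDocument  # noqa: E402
from repository.rag_document_repo import RagDocumentRepository  # noqa: E402


@mock_aws
def test_save_and_find_round_trip() -> None:
    # Mirror the CDK table: pk (HASH) + document_id (RANGE), GSI on document_id.
    boto3.resource("dynamodb").create_table(
        TableName="rag-doc-test",
        KeySchema=[
            {"AttributeName": "pk", "KeyType": "HASH"},
            {"AttributeName": "document_id", "KeyType": "RANGE"},
        ],
        AttributeDefinitions=[
            {"AttributeName": "pk", "AttributeType": "S"},
            {"AttributeName": "document_id", "AttributeType": "S"},
        ],
        GlobalSecondaryIndexes=[
            {
                "IndexName": "document_index",
                "KeySchema": [{"AttributeName": "document_id", "KeyType": "HASH"}],
                "Projection": {"ProjectionType": "ALL"},
            }
        ],
        BillingMode="PAY_PER_REQUEST",
    )

    repo = RagDocumentRepository("rag-doc-test")
    doc = RagDocument(
        repository_id="repo1",
        collection_id="coll1",
        document_name="a.txt",
        source="s3://bucket/a.txt",
        sub_docs=["chunk-1"],
    )
    repo.save(doc)

    assert repo.find_by_id(doc.document_id)["document_name"] == "a.txt"
    assert len(repo.list_all("repo1", "coll1")) == 1
```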
diff --git a/lib/rag/api/repository.ts b/lib/rag/api/repository.ts
index 73ead2e9..e0e05534 100644
--- a/lib/rag/api/repository.ts
+++ b/lib/rag/api/repository.ts
@@ -31,11 +31,11 @@ import { Vpc } from '../../networking/vpc';
  * @property {IAuthorizer} authorizer - APIGW authorizer
  * @property {Record<string, string>} baseEnvironment - Default environment properties applied to all
  * lambdas
- * @property {LayerVersion[]} commonLayers - Lambda layers for all Lambdas.
+ * @property {ILayerVersion[]} commonLayers - Lambda layers for all Lambdas.
  * @property {IRole} lambdaExecutionRole - Execution role for lambdas
- * @property {IRestApi} restAPI - REST APIGW for UI and Lambdas
+ * @property {string} restApiId - REST APIGW id for UI and Lambdas
  * @property {ISecurityGroup[]} securityGroups - Security groups for Lambdas
- * @property {IVpc} vpc - Stack VPC
+ * @property {Vpc} vpc - Stack VPC
  */
 type RepositoryApiProps = {
     authorizer: IAuthorizer;
@@ -84,11 +84,21 @@ export class RepositoryApi extends Construct {
             },
         },
         {
-            name: 'purge_document',
+            name: 'presigned_url',
             resource: 'repository',
-            description: 'Purges all records associated with a document from the repository',
-            path: 'repository/{repositoryId}/{documentId}',
-            method: 'DELETE',
+            description: 'Generates a presigned url for uploading files to RAG',
+            path: 'repository/presigned-url',
+            method: 'POST',
+            environment: {
+                ...baseEnvironment,
+            },
+        },
+        {
+            name: 'similarity_search',
+            resource: 'repository',
+            description: 'Run a similarity search against the specified repository using the specified query',
+            path: 'repository/{repositoryId}/similaritySearch',
+            method: 'GET',
             environment: {
                 ...baseEnvironment,
             },
@@ -105,20 +115,20 @@ export class RepositoryApi extends Construct {
             },
         },
         {
-            name: 'presigned_url',
+            name: 'delete_document',
             resource: 'repository',
-            description: 'Generates a presigned url for uploading files to RAG',
-            path: 'repository/presigned-url',
-            method: 'POST',
+            description: 'Deletes all records associated with a document from the repository',
+            path: 'repository/{repositoryId}/document',
+            method: 'DELETE',
             environment: {
                 ...baseEnvironment,
             },
         },
         {
-            name: 'similarity_search',
+            name: 'list_docs',
             resource: 'repository',
-            description: 'Run a similarity search against the specified repository using the specified query',
-            path: 'repository/{repositoryId}/similaritySearch',
+            description: 'List all docs for a repository',
+            path: 'repository/{repositoryId}/document',
             method: 'GET',
             environment: {
                 ...baseEnvironment,
diff --git a/lib/rag/index.ts b/lib/rag/index.ts
index abdea9e1..316bb824 100644
--- a/lib/rag/index.ts
+++ b/lib/rag/index.ts
@@ -24,11 +24,12 @@
 import { CfnOutput, RemovalPolicy, Stack, StackProps } from 'aws-cdk-lib';
 import { IAuthorizer } from 'aws-cdk-lib/aws-apigateway';
 import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2';
 import { AnyPrincipal, CfnServiceLinkedRole, Effect, PolicyStatement, Role } from 'aws-cdk-lib/aws-iam';
-import { Code, LayerVersion, ILayerVersion } from 'aws-cdk-lib/aws-lambda';
+import { Code, ILayerVersion, LayerVersion } from 'aws-cdk-lib/aws-lambda';
 import { Domain, EngineVersion, IDomain } from 'aws-cdk-lib/aws-opensearchservice';
 import { Credentials, DatabaseInstance, DatabaseInstanceEngine } from 'aws-cdk-lib/aws-rds';
 import { Bucket, HttpMethods } from 'aws-cdk-lib/aws-s3';
 import { ISecret, Secret } from 'aws-cdk-lib/aws-secretsmanager';
+import { AttributeType, Table } from 'aws-cdk-lib/aws-dynamodb';
 import { StringParameter } from 'aws-cdk-lib/aws-ssm';
 import { Construct } from 'constructs';
 
@@ -88,6 +89,7 @@ export class LisaRagStack extends Stack {
         const bucket = new Bucket(this, createCdkId(['LISA', 'RAG', config.deploymentName, config.deploymentStage]), {
             removalPolicy: config.removalPolicy,
             autoDeleteObjects: config.removalPolicy === RemovalPolicy.DESTROY,
+            enforceSSL: true,
             cors: [
                 {
                     allowedMethods: [HttpMethods.GET, HttpMethods.POST],
@@ -98,6 +100,25 @@
             ],
         });
 
+        const docTable = new Table(this, createCdkId([config.deploymentName, 'RagDocumentTable']), {
+            partitionKey: {
+                name: 'pk', // Composite of repo/collection ids
+                type: AttributeType.STRING,
+            },
+            sortKey: {
+                name: 'document_id',
+                type: AttributeType.STRING,
+            },
+            deletionProtection: true,
+        });
+        docTable.addGlobalSecondaryIndex({
+            indexName: 'document_index',
+            partitionKey: {
+                name: 'document_id',
+                type: AttributeType.STRING,
+            },
+        });
+
         const baseEnvironment: Record<string, string> = {
             REGISTERED_MODELS_PS_NAME: modelsPs.parameterName,
             BUCKET_NAME: bucket.bucketName,
@@ -105,6 +126,7 @@
             CHUNK_OVERLAP: config.ragFileProcessingConfig!.chunkOverlap.toString(),
             LISA_API_URL_PS_NAME: endpointUrl.parameterName,
             REST_API_VERSION: 'v2',
+            RAG_DOCUMENT_TABLE: docTable.tableName,
         };
 
         // Add REST API SSL Cert ARN if it exists to be used to verify SSL calls to REST API
@@ -333,7 +355,8 @@ export class LisaRagStack extends Stack {
                         repositoryId: ragConfig.repositoryId,
                         type: ragConfig.type,
                         layers: [commonLambdaLayer, ragLambdaLayer.layer, sdkLayer],
-                        registeredRepositoriesParamName
+                        registeredRepositoriesParamName,
+                        ragDocumentTable: docTable
                     });
                     console.log(`[DEBUG] Successfully created pipeline ${index}`);
                 } catch (error) {
@@ -369,6 +392,7 @@ export class LisaRagStack extends Stack {
         ragRepositoriesParam.grantRead(lambdaRole);
         modelsPs.grantRead(lambdaRole);
         endpointUrl.grantRead(lambdaRole);
+        docTable.grantReadWriteData(lambdaRole);
     }
 
     /**
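The `document_index` GSI added above is what backs `RagDocumentRepository.find_by_id`; the equivalent raw query, with a placeholder table name and id:

```python
import boto3
from boto3.dynamodb.conditions import Key

table = boto3.resource("dynamodb").Table("my-rag-document-table")  # placeholder name
response = table.query(
    IndexName="document_index",
    KeyConditionExpression=Key("document_id").eq("0c8e1234-aaaa-bbbb-cccc-000000000000"),
)
items = response["Items"]  # at most one item expected per document_id
```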
diff --git a/lib/rag/state_machine/ingest-pipeline.ts b/lib/rag/state_machine/ingest-pipeline.ts
index 4ea14ec8..2d94fef4 100644
--- a/lib/rag/state_machine/ingest-pipeline.ts
+++ b/lib/rag/state_machine/ingest-pipeline.ts
@@ -39,6 +39,7 @@ import { RagRepositoryType } from '../../schema';
 import * as kms from 'aws-cdk-lib/aws-kms';
 import * as cdk from 'aws-cdk-lib';
 import { getDefaultRuntime } from '../../api-base/utils';
+import { Table } from 'aws-cdk-lib/aws-dynamodb';
 
 type PipelineConfig = {
     chunkOverlap: number;
@@ -66,6 +67,7 @@ type IngestPipelineStateMachineProps = BaseProps & {
     type: RagRepositoryType;
     layers?: ILayerVersion[];
     registeredRepositoriesParamName: string;
+    ragDocumentTable: Table;
 };
 
 /**
@@ -77,7 +79,7 @@ export class IngestPipelineStateMachine extends Construct {
     constructor (scope: Construct, id: string, props: IngestPipelineStateMachineProps) {
         super(scope, id);
 
-        const {config, vpc, pipelineConfig, rdsConfig, repositoryId, type, layers, registeredRepositoriesParamName} = props;
+        const {config, vpc, pipelineConfig, rdsConfig, repositoryId, type, layers, registeredRepositoriesParamName, ragDocumentTable} = props;
 
         // Create KMS key for environment variable encryption
         const kmsKey = new kms.Key(this, 'EnvironmentEncryptionKey', {
@@ -98,6 +100,7 @@ export class IngestPipelineStateMachine extends Construct {
             RDS_CONNECTION_INFO_PS_NAME: `${config.deploymentPrefix}/LisaServeRagPGVectorConnectionInfo`,
             OPENSEARCH_ENDPOINT_PS_NAME: `${config.deploymentPrefix}/lisaServeRagRepositoryEndpoint`,
             LISA_API_URL_PS_NAME: `${config.deploymentPrefix}/lisaServeRestApiUri`,
+            RAG_DOCUMENT_TABLE: ragDocumentTable.tableName,
             LOG_LEVEL: config.logLevel,
             REGISTERED_REPOSITORIES_PS_NAME: registeredRepositoriesParamName,
             REGISTERED_REPOSITORIES_PS_PREFIX: `${config.deploymentPrefix}/LisaServeRagConnectionInfo/`,
@@ -120,9 +123,23 @@ export class IngestPipelineStateMachine extends Construct {
                 `arn:${cdk.Aws.PARTITION}:s3:::${pipelineConfig.s3Bucket}/*`
             ]
         });
+
+        // Allow DynamoDB Read/Write to RAG Document Table
+        const dynamoPolicyStatement = new PolicyStatement({
+            effect: Effect.ALLOW,
+            actions: [
+                'dynamodb:BatchGetItem',
+                'dynamodb:GetItem',
+                'dynamodb:Query',
+                'dynamodb:Scan',
+                'dynamodb:BatchWriteItem',
+                'dynamodb:PutItem',
+                'dynamodb:UpdateItem',
+            ],
+            resources: [ragDocumentTable.tableArn, `${ragDocumentTable.tableArn}/index/*`]
+        });
 
         // Create array of policy statements
-        const policyStatements = [s3PolicyStatement];
+        const policyStatements = [s3PolicyStatement, dynamoPolicyStatement];
 
         // Create IAM certificate policy if certificate ARN is provided
         let certPolicyStatement;