Feature/rag list api #215

Merged: 7 commits, Dec 19, 2024

Changes from all commits
35 changes: 33 additions & 2 deletions lambda/models/domain_objects.py
@@ -14,10 +14,12 @@

"""Domain objects for interacting with the model endpoints."""

import time
import uuid
from enum import Enum
-from typing import Annotated, Dict, List, Optional, Union
from typing import Annotated, Any, Dict, List, Optional, Union

-from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt
from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt
from pydantic.functional_validators import AfterValidator, field_validator, model_validator
from typing_extensions import Self
from utilities.validators import validate_all_fields_defined, validate_any_fields_defined, validate_instance_type
@@ -292,3 +294,32 @@ class DeleteModelResponse(ApiResponseBase):
"""Response object when deleting a model."""

pass


class IngestionType(Enum):
    """How a document entered the store: automatically via a pipeline or manually via the API."""

    AUTO = "auto"
    MANUAL = "manual"


class RagDocument(BaseModel):
"""Rag Document Entity for storing in DynamoDB."""

pk: Optional[str] = None
document_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
repository_id: str
collection_id: str
document_name: str
source: str
sub_docs: List[str] = Field(default_factory=lambda: [])
ingestion_type: IngestionType = Field(default_factory=lambda: IngestionType.MANUAL)
upload_date: int = Field(default_factory=lambda: int(time.time()))

model_config = ConfigDict(use_enum_values=True, validate_default=True)

def __init__(self, **data: Any) -> None:
super().__init__(**data)
self.pk = self.createPartitionKey(self.repository_id, self.collection_id)

@staticmethod
def createPartitionKey(repository_id: str, collection_id: str) -> str:
return f"{repository_id}#{collection_id}"
204 changes: 176 additions & 28 deletions lambda/repository/lambda_functions.py
@@ -22,20 +22,23 @@
import requests
from botocore.config import Config
from lisapy.langchain import LisaOpenAIEmbeddings
from models.domain_objects import IngestionType, RagDocument
from repository.rag_document_repo import RagDocumentRepository
from utilities.common_functions import api_wrapper, get_cert_path, get_id_token, retry_config
from utilities.exceptions import HTTPException
from utilities.file_processing import process_record
from utilities.validation import validate_model_name, ValidationError
from utilities.vector_store import find_repository_by_id, get_registered_repositories, get_vector_store_client

logger = logging.getLogger(__name__)
region_name = os.environ["AWS_REGION"]
session = boto3.Session()
ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config)
secrets_client = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config)
iam_client = boto3.client("iam", region_name=os.environ["AWS_REGION"], config=retry_config)
ssm_client = boto3.client("ssm", region_name, config=retry_config)
secrets_client = boto3.client("secretsmanager", region_name, config=retry_config)
iam_client = boto3.client("iam", region_name, config=retry_config)
s3 = session.client(
"s3",
region_name=os.environ["AWS_REGION"],
region_name,
config=Config(
retries={
"max_attempts": 3,
@@ -45,6 +48,8 @@
),
)
lisa_api_endpoint = ""
registered_repositories: List[Dict[str, Any]] = []
doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"])


def _get_embeddings(model_name: str, id_token: str) -> LisaOpenAIEmbeddings:
@@ -183,7 +188,7 @@ def _get_embeddings_pipeline(model_name: str) -> Any:
def list_all(event: dict, context: dict) -> List[Dict[str, Any]]:
"""Return info on all available repositories.

-    Currently there is not support for dynamic repositories so only a single OpenSearch repository
Currently, there is no support for dynamic repositories so only a single OpenSearch repository
is returned.
"""

@@ -212,9 +217,22 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]:
"""Return documents matching the query.

Conducts similarity search against the vector store returning the top K
-    documents based on the specified query. 'topK' can be set as an optional
-    querystring parameter, if it is not specified the top 3 documents will be
-    returned.
documents based on the specified query.

Args:
event (dict): The Lambda event object containing:
- queryStringParameters.modelName: Name of the embedding model
- queryStringParameters.query: Search query text
- queryStringParameters.repositoryType: Type of repository
- queryStringParameters.topK (optional): Number of results to return (default: 3)
context (dict): The Lambda context object

Returns:
Dict[str, Any]: A dictionary containing:
- docs: List of matching documents with their content and metadata

Raises:
ValidationError: If required parameters are missing or invalid
"""
query_string_params = event["queryStringParameters"]
model_name = query_string_params["modelName"]
@@ -240,64 +258,166 @@ def similarity_search(event: dict, context: dict) -> Dict[str, Any]:
return doc_return
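For reference, a hypothetical event for this handler, matching the docstring above (all values are made up):

event = {
    "queryStringParameters": {
        "modelName": "my-embedding-model",
        "query": "What is the vacation policy?",
        "repositoryType": "opensearch",
        "topK": "3",  # optional; defaults to 3
    }
}
# similarity_search(event, context={}) returns {"docs": [...]}.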


-def ensure_repository_access(event, repository):
def ensure_repository_access(event: dict[str, Any], repository: dict[str, Any]) -> None:
"Ensures a user has access to the repository or else raises an HTTPException"
user_groups = json.loads(event["requestContext"]["authorizer"]["groups"]) or []
if not user_has_group(user_groups, repository["allowedGroups"]):
raise HTTPException(status_code=403, message="User does not have permission to access this repository")


@api_wrapper
-def purge_document(event: dict, context: dict) -> Dict[str, Any]:
-    """Purge all records related to the specified document from the RAG repository."""
-    user_id = event["requestContext"]["authorizer"]["username"]
-    repository_id = event["pathParameters"]["repositoryId"]
-    document_id = event["pathParameters"]["sessionId"]
def delete_document(event: dict, context: dict) -> Dict[str, Any]:
"""Purge all records related to the specified document from the RAG repository. If a documentId is supplied, a
single document will be removed. If a documentName is supplied, all documents with that name will be removed

-    logger.info(
-        f"Purging records associated with document {document_id} "
-        f"(requesting user: {user_id}), repository: {repository_id}"
-    )
Args:
event (dict): The Lambda event object containing:
            - pathParameters.repositoryId: The repository id of the VectorStore
            - queryStringParameters.collectionId: The collection identifier
            - queryStringParameters.repositoryType: The repository type of the VectorStore
            - queryStringParameters.documentId (optional): Id of the document to purge
            - queryStringParameters.documentName (optional): Name of the document to purge
context (dict): The Lambda context object

Returns:
        Dict[str, Any]: A dictionary containing:
            - documentName (str): Name of the removed document
            - removedDocuments (int): Number of documents removed
            - removedDocumentChunks (int): Number of chunks removed from the VectorStore

    Raises:
        ValidationError: If neither or both of documentId and documentName are supplied
        ValueError: If no matching documents are found in the repository collection
"""
path_params = event.get("pathParameters", {})
repository_id = path_params.get("repositoryId")

query_string_params = event["queryStringParameters"]
collection_id = query_string_params["collectionId"]
document_id = query_string_params.get("documentId")
document_name = query_string_params.get("documentName")

return {"documentId": document_id, "recordsPurged": 0}
if not document_id and not document_name:
raise ValidationError("Either documentId or documentName must be specified")
    if document_id and document_name:
        raise ValidationError("Only one of documentId or documentName can be specified")

docs = []
if document_id:
docs = [doc_repo.find_by_id(document_id)]
elif document_name:
docs = doc_repo.find_by_name(repository_id, collection_id, document_name)

if not docs:
raise ValueError(f"No documents found in repository collection {repository_id}:{collection_id}")

# Grab all sub document ids related to the parent document(s)
subdoc_ids = [sub_doc for doc in docs for sub_doc in doc.get("sub_docs", [])]

id_token = get_id_token(event)
embeddings = _get_embeddings(model_name=collection_id, id_token=id_token)
vs = get_vector_store_client(repository_id=repository_id, index=collection_id, embeddings=embeddings)

vs.delete(ids=subdoc_ids)

doc_repo.batch_delete(docs)

return {
"documentName": docs[0].get("document_name"),
"removedDocuments": len(docs),
"removedDocumentChunks": len(subdoc_ids),
}
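For illustration, a hypothetical event that deletes by name (exactly one of documentId / documentName is accepted; the auth fields consumed by get_id_token are elided):

event = {
    "pathParameters": {"repositoryId": "my-repo"},
    "queryStringParameters": {
        "collectionId": "my-embedding-model",
        "documentName": "handbook.pdf",  # or "documentId": "<uuid>" instead
    },
}
# delete_document(event, context={}) removes the matching documents' chunks
# from the vector store and their entries from the RagDocument table.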


@api_wrapper
def ingest_documents(event: dict, context: dict) -> dict:
"""Ingest a set of documents into the specified repository."""
"""Ingest documents into the RAG repository.

Args:
event (dict): The Lambda event object containing:
            - body.embeddingModel.modelName: Name of the embedding model (also used as the collection id)
            - body.keys: List of S3 keys to ingest
- pathParameters.repositoryId: Repository id (VectorStore)
- queryStringParameters.repositoryType: Repository type (VectorStore)
- queryStringParameters.chunkSize (optional): Size of text chunks
- queryStringParameters.chunkOverlap (optional): Overlap between chunks
context (dict): The Lambda context object

Returns:
dict: A dictionary containing:
            - ids (list): List of generated chunk IDs across all ingested documents
            - count (int): Total number of chunks ingested

Raises:
ValidationError: If required parameters are missing or invalid
"""
body = json.loads(event["body"])
embedding_model = body["embeddingModel"]
model_name = embedding_model["modelName"]

path_params = event.get("pathParameters", {})
repository_id = path_params.get("repositoryId")

query_string_params = event["queryStringParameters"]
chunk_size = int(query_string_params["chunkSize"]) if "chunkSize" in query_string_params else None
chunk_overlap = int(query_string_params["chunkOverlap"]) if "chunkOverlap" in query_string_params else None
repository_id = event["pathParameters"]["repositoryId"]
logger.info(f"using repository {repository_id}")

repository = find_repository_by_id(repository_id)
ensure_repository_access(event, repository)

docs = process_record(s3_keys=body["keys"], chunk_size=chunk_size, chunk_overlap=chunk_overlap)
keys = body["keys"]
docs = process_record(s3_keys=keys, chunk_size=chunk_size, chunk_overlap=chunk_overlap)

texts = [] # list of strings
metadatas = [] # list of dicts
all_ids = []
id_token = get_id_token(event)
embeddings = _get_embeddings(model_name=model_name, id_token=id_token)
vs = get_vector_store_client(repository_id, index=model_name, embeddings=embeddings)

# Batch document ingestion one parent document at a time
for doc_list in docs:
document_name = doc_list[0].metadata.get("name")
doc_source = doc_list[0].metadata.get("source")
for doc in doc_list:
texts.append(doc.page_content)
metadatas.append(doc.metadata)
# Ingest document into vector store
ids = vs.add_texts(texts=texts, metadatas=metadatas)

-    id_token = get_id_token(event)
-    embeddings = _get_embeddings(model_name=model_name, id_token=id_token)
-    vs = get_vector_store_client(repository_id, index=model_name, embeddings=embeddings)
-    ids = vs.add_texts(texts=texts, metadatas=metadatas)
-    return {"ids": ids, "count": len(ids)}
# Add document to RagDocTable
doc_entity = RagDocument(
repository_id=repository_id,
collection_id=model_name,
document_name=document_name,
source=doc_source,
sub_docs=ids,
ingestion_type=IngestionType.MANUAL,
)
doc_repo.save(doc_entity)

all_ids.extend(ids)

return {"ids": all_ids, "count": len(all_ids)}


@api_wrapper
def presigned_url(event: dict, context: dict) -> dict:
"""Generate a pre-signed URL for uploading files to the RAG ingest bucket."""
"""Generate a pre-signed URL for uploading files to the RAG ingest bucket.

Args:
event (dict): The Lambda event object containing:
- body: The key for the file
- requestContext.authorizer.username: The authenticated username
context (dict): The Lambda context object

Returns:
dict: A dictionary containing:
- response: The presigned URL response object with upload fields and URL

Notes:
- URL expires in 3600 seconds (1 hour)
- Maximum file size is 52428800 bytes (50MB)
"""
response = ""
key = event["body"]

Expand Down Expand Up @@ -325,3 +445,31 @@ def presigned_url(event: dict, context: dict) -> dict:
def get_groups(event: Any) -> List[str]:
groups: List[str] = json.loads(event["requestContext"]["authorizer"]["groups"])
return groups


@api_wrapper
def list_docs(event: dict, context: dict) -> List[RagDocument]:
"""List all documents for a given repository/collection.

Args:
event (dict): The Lambda event object containing query parameters
- pathParameters.repositoryId: The repository id to list documents for
- queryStringParameters.collectionId: The collection id to list documents for
context (dict): The Lambda context object

Returns:
list[RagDocument]: A list of RagDocument objects representing all documents
in the specified collection

Raises:
KeyError: If collectionId is not provided in queryStringParameters
"""

path_params = event.get("pathParameters", {})
repository_id = path_params.get("repositoryId")

query_string_params = event.get("queryStringParameters", {})
collection_id = query_string_params.get("collectionId")

docs: List[RagDocument] = doc_repo.list_all(repository_id, collection_id)
return docs
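And a minimal sketch of exercising the new list handler directly (in deployment, api_wrapper shapes the HTTP response; ids are illustrative):

event = {
    "pathParameters": {"repositoryId": "my-repo"},
    "queryStringParameters": {"collectionId": "my-embedding-model"},
}
docs = list_docs(event, context={})  # -> List[RagDocument] for the pair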
17 changes: 16 additions & 1 deletion lambda/repository/pipeline_ingest_documents.py
@@ -19,17 +19,20 @@
from typing import Any, Dict, List

import boto3
from repository.lambda_functions import RagDocumentRepository
from utilities.common_functions import retry_config
from utilities.file_processing import process_record
from utilities.validation import validate_chunk_params, validate_model_name, validate_repository_type, ValidationError
from utilities.vector_store import get_vector_store_client

-from .lambda_functions import _get_embeddings_pipeline
from .lambda_functions import _get_embeddings_pipeline, IngestionType, RagDocument

logger = logging.getLogger(__name__)
session = boto3.Session()
ssm_client = boto3.client("ssm", region_name=os.environ["AWS_REGION"], config=retry_config)

doc_repo = RagDocumentRepository(os.environ["RAG_DOCUMENT_TABLE"])


def batch_texts(texts: List[str], metadatas: List[Dict], batch_size: int = 500) -> list[tuple[list[str], list[dict]]]:
"""
@@ -110,6 +113,7 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
# Prepare texts and metadata
texts = []
metadatas = []

for doc_list in docs:
for doc in doc_list:
texts.append(doc.page_content)
@@ -147,6 +151,17 @@ def handle_pipeline_ingest_documents(event: Dict[str, Any], context: Any) -> Dict[str, Any]:

logger.info(f"Successfully processed {len(all_ids)} chunks from {s3_key} for repository {repository_id}")

# Store RagDocument entry in Document Table
doc_entity = RagDocument(
repository_id=repository_id,
collection_id=embedding_model,
document_name=key,
source=docs[0][0].metadata.get("source"),
sub_docs=all_ids,
ingestion_type=IngestionType.AUTO,
)
doc_repo.save(doc_entity)

return {
"message": f"Successfully processed document {s3_key}",
"repository_id": repository_id,
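Note: RagDocumentRepository itself is not shown in this diff (it lives in lambda/repository/rag_document_repo.py). Based on the call sites above (save, find_by_id, find_by_name, list_all, batch_delete), its shape is assumed to be roughly the following sketch; the key design and signatures here are guesses, not the PR's actual implementation.

from typing import Any, Dict, List

import boto3
from boto3.dynamodb.conditions import Key
from models.domain_objects import RagDocument


class RagDocumentRepository:
    """Hypothetical sketch of the DynamoDB-backed RagDocument store."""

    def __init__(self, table_name: str) -> None:
        self.table = boto3.resource("dynamodb").Table(table_name)

    def save(self, doc: RagDocument) -> None:
        # Persist the pydantic entity as a plain DynamoDB item.
        self.table.put_item(Item=doc.model_dump())

    def list_all(self, repository_id: str, collection_id: str) -> List[Dict[str, Any]]:
        # One partition per repository/collection pair, as produced by
        # RagDocument.createPartitionKey. (Pagination omitted for brevity.)
        pk = RagDocument.createPartitionKey(repository_id, collection_id)
        return self.table.query(KeyConditionExpression=Key("pk").eq(pk)).get("Items", [])

    def find_by_name(self, repository_id: str, collection_id: str, document_name: str) -> List[Dict[str, Any]]:
        # Filter the collection's documents by name client-side (assumption).
        return [d for d in self.list_all(repository_id, collection_id) if d.get("document_name") == document_name]

    def find_by_id(self, document_id: str) -> Dict[str, Any]:
        # Would likely require a GSI on document_id in the real table (assumption).
        raise NotImplementedError

    def batch_delete(self, docs: List[Dict[str, Any]]) -> None:
        with self.table.batch_writer() as batch:
            for doc in docs:
                batch.delete_item(Key={"pk": doc["pk"], "document_id": doc["document_id"]})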