-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from weni-ai/feature/text-splitter
add character text splitter
- Loading branch information
Showing
29 changed files
with
2,558 additions
and
1,077 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
[report] | ||
exclude_lines = pass | ||
pragma: no cover | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import os | ||
from celery import Celery | ||
|
||
from typing import Dict | ||
from app.indexer.indexer_file_manager import IndexerFileManager | ||
from app.downloaders.s3 import S3FileDownloader | ||
|
||
from app.handlers.nexus import NexusRESTClient | ||
from app.text_splitters import TextSplitter, character_text_splitter | ||
|
||
|
||
celery = Celery(__name__) | ||
celery.conf.broker_url = os.environ.get( | ||
"CELERY_BROKER_URL", "redis://localhost:6379" | ||
) | ||
celery.conf.result_backend = os.environ.get( | ||
"CELERY_RESULT_BACKEND", "redis://localhost:6379" | ||
) | ||
|
||
|
||
@celery.task(name="index_file") | ||
def index_file_data(content_base: Dict) -> bool: | ||
from app.main import main_app | ||
|
||
file_downloader = S3FileDownloader( | ||
os.environ.get("AWS_STORAGE_ACCESS_KEY"), | ||
os.environ.get("AWS_STORAGE_SECRET_KEY") | ||
) | ||
text_splitter = TextSplitter(character_text_splitter()) | ||
manager = IndexerFileManager(file_downloader, main_app.content_base_indexer, text_splitter) | ||
index_result: bool = manager.index_file_url(content_base) | ||
NexusRESTClient().index_succedded( | ||
task_succeded=index_result, | ||
nexus_task_uuid=content_base.get("task_uuid"), | ||
file_type=content_base.get("extension_file") | ||
) | ||
|
||
return index_result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from abc import ABC, abstractmethod | ||
from app.downloaders.exceptions import FileDownloaderException | ||
|
||
|
||
class IFileDownloader(ABC): | ||
@abstractmethod | ||
def authenticate(self): | ||
pass | ||
|
||
@abstractmethod | ||
def download_file(self): | ||
pass | ||
|
||
@abstractmethod | ||
def download_file_batch(self): | ||
pass | ||
|
||
@abstractmethod | ||
def download_file_bulk(self): | ||
pass | ||
|
||
def download_file(file_downloader, file_name: str) -> None: | ||
handler = file_downloader | ||
try: | ||
handler.download_file(file_name) | ||
except Exception as err: | ||
raise FileDownloaderException(err) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
class FileDownloaderException(BaseException): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from app.downloaders.s3.file_downloader import S3FileDownloader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import os | ||
import boto3 | ||
from app.downloaders import IFileDownloader | ||
from fastapi.logger import logger | ||
from urllib.parse import urlparse | ||
from typing import Tuple, List | ||
|
||
|
||
class S3FileDownloader(IFileDownloader): | ||
|
||
def __init__(self, | ||
access_key: str, | ||
secret_key: str, | ||
bucket_name: str = os.environ.get("AWS_STORAGE_BUCKET_NAME"), | ||
) -> None: | ||
self.bucket_name = bucket_name | ||
self.access_key = access_key | ||
self.secret_key = secret_key | ||
self.client = self.authenticate() | ||
|
||
def authenticate(self): | ||
return boto3.client( | ||
"s3", | ||
aws_access_key_id=self.access_key, | ||
aws_secret_access_key=self.secret_key | ||
) | ||
|
||
def download_file(self, file_name): | ||
bucket = self.bucket_name | ||
key = file_name | ||
local_path = f"app/files/{file_name}" | ||
|
||
self.client.download_file(bucket, key, local_path) | ||
|
||
def download_file_batch(self): | ||
raise NotImplementedError | ||
|
||
def download_file_bulk(self): | ||
raise NotImplementedError | ||
|
||
|
||
def get_s3_bucket_and_file_name(file_url: str)-> Tuple[str, ...]: | ||
result = urlparse(file_url) | ||
bucket_name = result.netloc.split('.s3')[0] | ||
file_name = result.path.strip('/') | ||
return bucket_name, file_name |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from pydantic.v1 import root_validator | ||
import time | ||
from langchain.embeddings import SagemakerEndpointEmbeddings | ||
from typing import Dict, List | ||
|
||
|
||
class SagemakerEndpointEmbeddingsKeys(SagemakerEndpointEmbeddings): | ||
aws_key: str = "" | ||
aws_secret: str = "" | ||
|
||
@root_validator(skip_on_failure=True) | ||
def validate_environment(cls, values: Dict) -> Dict: | ||
import boto3 | ||
|
||
session = boto3.Session( | ||
aws_access_key_id=values["aws_key"], | ||
aws_secret_access_key=values["aws_secret"] | ||
) | ||
|
||
values["client"] = session.client( | ||
"sagemaker-runtime", region_name=values["region_name"] | ||
) | ||
|
||
return values | ||
|
||
def embed_documents( | ||
self, texts: List[str], chunk_size: int = 32 | ||
) -> List[List[float]]: | ||
"""Compute doc embeddings using a SageMaker Inference Endpoint. | ||
Args: | ||
texts: The list of texts to embed. | ||
chunk_size: The chunk size defines how many input texts will | ||
be grouped together as request. If None, will use the | ||
chunk size specified by the class. | ||
Returns: | ||
List of embeddings, one for each text. | ||
""" | ||
results = [] | ||
_chunk_size = len(texts) if chunk_size > len(texts) else chunk_size | ||
for i in range(0, len(texts), _chunk_size): | ||
response = self._embedding_func(texts[i:i + _chunk_size]) | ||
results.extend(response) | ||
return results | ||
|
||
def _embedding_func(self, texts: List[str]) -> List[List[float]]: | ||
"""Call out to SageMaker Inference embedding endpoint.""" | ||
# replace newlines, which can negatively affect performance. | ||
texts = list(map(lambda x: x.replace("\n", " "), texts)) | ||
_model_kwargs = self.model_kwargs or {} | ||
_endpoint_kwargs = self.endpoint_kwargs or {} | ||
|
||
body = self.content_handler.transform_input(texts, _model_kwargs) | ||
content_type = self.content_handler.content_type | ||
accepts = self.content_handler.accepts | ||
|
||
# send request | ||
while True: | ||
try: | ||
response = self.client.invoke_endpoint( | ||
EndpointName=self.endpoint_name, | ||
Body=body, | ||
ContentType=content_type, | ||
Accept=accepts, | ||
**_endpoint_kwargs, | ||
) | ||
return self.content_handler.transform_output(response["Body"]) | ||
except Exception as e: | ||
print( | ||
f"Error raised by inference endpoint: {e}\nBody: . Trying again in 5 seconds." | ||
) | ||
time.sleep(60 * 5) | ||
|
||
raise ValueError( | ||
f"Error raised by inference endpoint: {e}. Trying again in 5 seconds." | ||
) | ||
return self.content_handler.transform_output(response["Body"]) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import os | ||
from fastapi import HTTPException | ||
|
||
|
||
def token_verification(token: str): | ||
if os.environ.get("SENTENX_TOKEN") == token: | ||
return | ||
raise HTTPException(status_code=401, detail=[{"msg": str("Unauthorized")}]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from fastapi import APIRouter, Header | ||
from pydantic import BaseModel | ||
|
||
from app.handlers import IDocumentHandler | ||
from app.indexer import IDocumentIndexer | ||
|
||
from app.celery import index_file_data | ||
from typing import List | ||
from typing import Annotated | ||
from app.handlers.authorizations import token_verification | ||
|
||
|
||
class ContentBaseIndexRequest(BaseModel): | ||
file: str | ||
filename: str | ||
file_uuid: str | ||
extension_file: str | ||
task_uuid: str | ||
content_base: str | ||
|
||
|
||
class ContentBaseIndexResponse(BaseModel): | ||
file: str | ||
filename: str | ||
task_uuid: str | ||
|
||
|
||
class ContentBaseSearchRequest(BaseModel): | ||
search: str | ||
filter: dict[str, str] = None | ||
threshold: float = 1.5 | ||
|
||
|
||
class ContentBaseSearchResponse(BaseModel): | ||
response: List[str] | ||
|
||
|
||
class ContentBaseDeleteRequest(BaseModel): | ||
filename: str | ||
content_base: str | ||
file_uuid: str | ||
|
||
|
||
class ContentBaseDeleteResponse(BaseModel): | ||
deleted: bool | ||
|
||
|
||
class ContentBaseHandler(IDocumentHandler): | ||
def __init__(self, content_base_indexer: IDocumentIndexer): | ||
self.content_base_indexer = content_base_indexer | ||
self.router = APIRouter() | ||
self.router.add_api_route( | ||
"/content_base/index", endpoint=self.index, methods=["PUT"] | ||
) | ||
self.router.add_api_route( | ||
"/content_base/search", endpoint=self.search, methods=["POST"] | ||
) | ||
self.router.add_api_route( | ||
"/content_base/delete", endpoint=self.delete, methods=["DELETE"] | ||
) | ||
|
||
def index(self, request: ContentBaseIndexRequest, Authorization: Annotated[str | None, Header()] = None): | ||
token_verification(Authorization) | ||
content_base = request.__dict__ | ||
task = index_file_data.delay(content_base) | ||
|
||
return ContentBaseIndexResponse( | ||
file=request.file, | ||
filename=request.filename, | ||
task_uuid=task.id, | ||
) | ||
|
||
def batch_index(self): | ||
raise NotImplementedError | ||
|
||
def delete(self, request: ContentBaseDeleteRequest, Authorization: Annotated[str | None, Header()] = None): | ||
token_verification(Authorization) | ||
self.content_base_indexer.delete( | ||
request.content_base, | ||
request.filename, | ||
request.file_uuid, | ||
) | ||
return ContentBaseDeleteResponse(deleted=True) | ||
|
||
def delete_batch(self): | ||
raise NotImplementedError | ||
|
||
def search(self, request: ContentBaseSearchRequest, Authorization: Annotated[str | None, Header()] = None): | ||
token_verification(Authorization) | ||
response = self.content_base_indexer.search( | ||
search=request.search.lower(), | ||
threshold=request.threshold, | ||
filter=request.filter | ||
) | ||
return ContentBaseSearchResponse(response=response) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
|
||
import os | ||
import json | ||
import requests | ||
|
||
|
||
class NexusRESTClient: | ||
token = os.environ.get("NEXUS_AI_TOKEN") | ||
base_url = os.environ.get("NEXUS_AI_URL") | ||
|
||
def __init__(self) -> None: | ||
self.headers = { | ||
'Authorization': self.token, | ||
'Content-Type': "application/json" | ||
} | ||
|
||
def index_succedded(self, task_succeded: bool, nexus_task_uuid: str, file_type: str) -> None: | ||
endpoint = f'{self.base_url}/api/v1/content-base-file' | ||
data = { | ||
"status": int(task_succeded), | ||
"task_uuid": nexus_task_uuid, | ||
"file_type": "text" if file_type == "txt" else "file", | ||
} | ||
response = requests.patch(url=endpoint, data=json.dumps(data), headers=self.headers) | ||
response.raise_for_status() |
Oops, something went wrong.