Merge pull request #6 from weni-ai/feature/text-splitter
add character text splitter
AlisoSouza authored Jan 31, 2024
2 parents ec1c24b + 7d7f75e commit 18c782b
Showing 29 changed files with 2,558 additions and 1,077 deletions.
2 changes: 2 additions & 0 deletions .coveragerc
@@ -1,2 +1,4 @@
[report]
exclude_lines = pass
pragma: no cover
raise NotImplementedError
1 change: 1 addition & 0 deletions Dockerfile
@@ -12,6 +12,7 @@ RUN poetry config virtualenvs.create false && \
COPY . .

EXPOSE 8000
EXPOSE 9200

COPY entrypoint.sh /entrypoint.sh

38 changes: 38 additions & 0 deletions app/celery.py
@@ -0,0 +1,38 @@
import os
from celery import Celery

from typing import Dict
from app.indexer.indexer_file_manager import IndexerFileManager
from app.downloaders.s3 import S3FileDownloader

from app.handlers.nexus import NexusRESTClient
from app.text_splitters import TextSplitter, character_text_splitter


celery = Celery(__name__)
celery.conf.broker_url = os.environ.get(
"CELERY_BROKER_URL", "redis://localhost:6379"
)
celery.conf.result_backend = os.environ.get(
"CELERY_RESULT_BACKEND", "redis://localhost:6379"
)


@celery.task(name="index_file")
def index_file_data(content_base: Dict) -> bool:
from app.main import main_app

file_downloader = S3FileDownloader(
os.environ.get("AWS_STORAGE_ACCESS_KEY"),
os.environ.get("AWS_STORAGE_SECRET_KEY")
)
text_splitter = TextSplitter(character_text_splitter())
manager = IndexerFileManager(file_downloader, main_app.content_base_indexer, text_splitter)
index_result: bool = manager.index_file_url(content_base)
NexusRESTClient().index_succedded(
task_succeded=index_result,
nexus_task_uuid=content_base.get("task_uuid"),
file_type=content_base.get("extension_file")
)

return index_result
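
For orientation, a minimal sketch of how this task might be enqueued from application code; the payload keys mirror the ContentBaseIndexRequest model added later in this diff, and all values are hypothetical.

from app.celery import index_file_data

# Hypothetical payload; keys follow ContentBaseIndexRequest below.
content_base = {
    "file": "https://my-bucket.s3.amazonaws.com/uploads/manual.txt",
    "filename": "manual.txt",
    "file_uuid": "hypothetical-file-uuid",
    "extension_file": "txt",
    "task_uuid": "hypothetical-task-uuid",
    "content_base": "my-content-base",
}

# .delay() puts the job on the Redis broker; a worker indexes the file
# and reports the outcome back to Nexus via NexusRESTClient.
result = index_file_data.delay(content_base)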
16 changes: 16 additions & 0 deletions app/config.py
@@ -1,4 +1,7 @@
import os
from dotenv import load_dotenv

load_dotenv()


class AppConfig:
@@ -24,5 +27,18 @@ def __init__(self):
"HUGGINGFACE_API_TOKEN", "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"
),
}
self.sagemaker_aws = {
"endpoint_name": os.environ.get(
"SAGEMAKER_ENDPOINT_NAME",
"huggingface-pytorch-inference-2023-10-25-14-25-59-713",
),
"region_name": os.environ.get("SAGEMAKER_REGION_NAME", "us-east-1"),
"aws_key": os.environ.get("SAGE_MAKER_AWS_KEY"),
"aws_secret": os.environ.get("SAGE_MAKER_AWS_SECRET"),
}

self.content_base_index_name = os.environ.get(
"INDEX_CONTENTBASES_NAME", "content_bases"
)
self.sentry_dsn = os.environ.get("SENTRY_DSN", "")
self.es_timeout = os.environ.get("ELASTICSEARCH_TIMEOUT", "30")
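
As a usage sketch (environment values are placeholders): load_dotenv() runs at import time, the settings themselves are read when AppConfig is instantiated, and es_timeout stays a string, so cast it at the call site.

import os

os.environ["SAGEMAKER_REGION_NAME"] = "us-east-1"  # placeholder
os.environ["ELASTICSEARCH_TIMEOUT"] = "30"         # placeholder

from app.config import AppConfig

config = AppConfig()
assert config.sagemaker_aws["region_name"] == "us-east-1"
timeout = int(config.es_timeout)  # stored as a string; cast before use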
27 changes: 27 additions & 0 deletions app/downloaders/__init__.py
@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from app.downloaders.exceptions import FileDownloaderException


class IFileDownloader(ABC):
@abstractmethod
def authenticate(self):
pass

@abstractmethod
def download_file(self):
pass

@abstractmethod
def download_file_batch(self):
pass

@abstractmethod
def download_file_bulk(self):
pass

def download_file(file_downloader: IFileDownloader, file_name: str) -> None:
handler = file_downloader
try:
handler.download_file(file_name)
except Exception as err:
raise FileDownloaderException(err)
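
A minimal sketch of the intended usage, assuming download_file is the module-level wrapper shown above: any concrete IFileDownloader can be handed to it, and failures surface as FileDownloaderException. The LocalFileDownloader class here is hypothetical.

from app.downloaders import IFileDownloader, download_file
from app.downloaders.exceptions import FileDownloaderException


class LocalFileDownloader(IFileDownloader):
    """Hypothetical downloader that reads from the local disk."""

    def authenticate(self):
        return None  # nothing to authenticate locally

    def download_file(self, file_name: str) -> None:
        with open(file_name, "rb") as f:
            f.read()

    def download_file_batch(self):
        raise NotImplementedError

    def download_file_bulk(self):
        raise NotImplementedError


try:
    download_file(LocalFileDownloader(), "missing.txt")
except FileDownloaderException as err:
    print(f"download failed: {err}")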
2 changes: 2 additions & 0 deletions app/downloaders/exceptions.py
@@ -0,0 +1,2 @@
class FileDownloaderException(Exception):
pass
1 change: 1 addition & 0 deletions app/downloaders/s3/__init__.py
@@ -0,0 +1 @@
from app.downloaders.s3.file_downloader import S3FileDownloader
46 changes: 46 additions & 0 deletions app/downloaders/s3/file_downloader.py
@@ -0,0 +1,46 @@
import os
import boto3
from app.downloaders import IFileDownloader
from fastapi.logger import logger
from urllib.parse import urlparse
from typing import Tuple, List


class S3FileDownloader(IFileDownloader):

def __init__(self,
access_key: str,
secret_key: str,
bucket_name: str = os.environ.get("AWS_STORAGE_BUCKET_NAME"),
) -> None:
self.bucket_name = bucket_name
self.access_key = access_key
self.secret_key = secret_key
self.client = self.authenticate()

def authenticate(self):
return boto3.client(
"s3",
aws_access_key_id=self.access_key,
aws_secret_access_key=self.secret_key
)

def download_file(self, file_name):
bucket = self.bucket_name
key = file_name
local_path = f"app/files/{file_name}"

self.client.download_file(bucket, key, local_path)

def download_file_batch(self):
raise NotImplementedError

def download_file_bulk(self):
raise NotImplementedError


def get_s3_bucket_and_file_name(file_url: str) -> Tuple[str, str]:
result = urlparse(file_url)
bucket_name = result.netloc.split('.s3')[0]
file_name = result.path.strip('/')
return bucket_name, file_name
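
A quick illustration of the URL helper with a hypothetical virtual-hosted-style S3 URL:

from app.downloaders.s3.file_downloader import get_s3_bucket_and_file_name

url = "https://my-bucket.s3.amazonaws.com/uploads/manual.pdf"  # hypothetical
bucket, key = get_s3_bucket_and_file_name(url)
assert bucket == "my-bucket"
assert key == "uploads/manual.pdf"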
Empty file added app/embedders/__init__.py
Empty file.
79 changes: 79 additions & 0 deletions app/embedders/embedders.py
@@ -0,0 +1,79 @@
from pydantic.v1 import root_validator
import time
from langchain.embeddings import SagemakerEndpointEmbeddings
from typing import Dict, List


class SagemakerEndpointEmbeddingsKeys(SagemakerEndpointEmbeddings):
aws_key: str = ""
aws_secret: str = ""

@root_validator(skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
import boto3

session = boto3.Session(
aws_access_key_id=values["aws_key"],
aws_secret_access_key=values["aws_secret"]
)

values["client"] = session.client(
"sagemaker-runtime", region_name=values["region_name"]
)

return values

def embed_documents(
self, texts: List[str], chunk_size: int = 32
) -> List[List[float]]:
"""Compute doc embeddings using a SageMaker Inference Endpoint.
Args:
texts: The list of texts to embed.
chunk_size: The chunk size defines how many input texts will
be grouped together as request. If None, will use the
chunk size specified by the class.
Returns:
List of embeddings, one for each text.
"""
results = []
_chunk_size = len(texts) if chunk_size > len(texts) else chunk_size
for i in range(0, len(texts), _chunk_size):
response = self._embedding_func(texts[i:i + _chunk_size])
results.extend(response)
return results

def _embedding_func(self, texts: List[str]) -> List[List[float]]:
"""Call out to SageMaker Inference embedding endpoint."""
# replace newlines, which can negatively affect performance.
texts = list(map(lambda x: x.replace("\n", " "), texts))
_model_kwargs = self.model_kwargs or {}
_endpoint_kwargs = self.endpoint_kwargs or {}

body = self.content_handler.transform_input(texts, _model_kwargs)
content_type = self.content_handler.content_type
accepts = self.content_handler.accepts

# send request
while True:
try:
response = self.client.invoke_endpoint(
EndpointName=self.endpoint_name,
Body=body,
ContentType=content_type,
Accept=accepts,
**_endpoint_kwargs,
)
return self.content_handler.transform_output(response["Body"])
except Exception as e:
print(
f"Error raised by inference endpoint: {e}\nBody: . Trying again in 5 seconds."
)
time.sleep(60 * 5)

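
A sketch of how this subclass might be wired up, assuming a JSON request/response contract on the endpoint; the content handler and all credential values below are illustrative, not the project's actual configuration.

import json
from typing import Dict, List

from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler

from app.embedders.embedders import SagemakerEndpointEmbeddingsKeys


class JSONContentHandler(EmbeddingsContentHandler):
    """Hypothetical contract: {"inputs": [...]} in, {"vectors": [...]} out."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
        return json.dumps({"inputs": inputs, **model_kwargs}).encode("utf-8")

    def transform_output(self, output) -> List[List[float]]:
        return json.loads(output.read().decode("utf-8"))["vectors"]


embeddings = SagemakerEndpointEmbeddingsKeys(
    endpoint_name="my-sagemaker-endpoint",  # placeholders; the real values
    region_name="us-east-1",                # come from AppConfig.sagemaker_aws
    aws_key="placeholder-key",
    aws_secret="placeholder-secret",
    content_handler=JSONContentHandler(),
)
vectors = embeddings.embed_documents(["first document", "second document"])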
Empty file added app/files/__init__.py
Empty file.
8 changes: 8 additions & 0 deletions app/handlers/authorizations.py
@@ -0,0 +1,8 @@
import os
from fastapi import HTTPException


def token_verification(token: str):
if os.environ.get("SENTENX_TOKEN") == token:
return
raise HTTPException(status_code=401, detail=[{"msg": str("Unauthorized")}])
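
The guard simply compares the incoming header against the SENTENX_TOKEN environment variable, read at call time; a small sketch with a placeholder token:

import os
from fastapi import HTTPException

os.environ["SENTENX_TOKEN"] = "placeholder-token"

from app.handlers.authorizations import token_verification

token_verification("placeholder-token")  # returns silently
try:
    token_verification("wrong-token")
except HTTPException as exc:
    assert exc.status_code == 401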
95 changes: 95 additions & 0 deletions app/handlers/content_bases.py
@@ -0,0 +1,95 @@
from fastapi import APIRouter, Header
from pydantic import BaseModel

from app.handlers import IDocumentHandler
from app.indexer import IDocumentIndexer

from app.celery import index_file_data
from typing import List
from typing import Annotated
from app.handlers.authorizations import token_verification


class ContentBaseIndexRequest(BaseModel):
file: str
filename: str
file_uuid: str
extension_file: str
task_uuid: str
content_base: str


class ContentBaseIndexResponse(BaseModel):
file: str
filename: str
task_uuid: str


class ContentBaseSearchRequest(BaseModel):
search: str
filter: dict[str, str] | None = None
threshold: float = 1.5


class ContentBaseSearchResponse(BaseModel):
response: List[str]


class ContentBaseDeleteRequest(BaseModel):
filename: str
content_base: str
file_uuid: str


class ContentBaseDeleteResponse(BaseModel):
deleted: bool


class ContentBaseHandler(IDocumentHandler):
def __init__(self, content_base_indexer: IDocumentIndexer):
self.content_base_indexer = content_base_indexer
self.router = APIRouter()
self.router.add_api_route(
"/content_base/index", endpoint=self.index, methods=["PUT"]
)
self.router.add_api_route(
"/content_base/search", endpoint=self.search, methods=["POST"]
)
self.router.add_api_route(
"/content_base/delete", endpoint=self.delete, methods=["DELETE"]
)

def index(self, request: ContentBaseIndexRequest, Authorization: Annotated[str | None, Header()] = None):
token_verification(Authorization)
content_base = request.__dict__
task = index_file_data.delay(content_base)

return ContentBaseIndexResponse(
file=request.file,
filename=request.filename,
task_uuid=task.id,
)

def batch_index(self):
raise NotImplementedError

def delete(self, request: ContentBaseDeleteRequest, Authorization: Annotated[str | None, Header()] = None):
token_verification(Authorization)
self.content_base_indexer.delete(
request.content_base,
request.filename,
request.file_uuid,
)
return ContentBaseDeleteResponse(deleted=True)

def delete_batch(self):
raise NotImplementedError

def search(self, request: ContentBaseSearchRequest, Authorization: Annotated[str | None, Header()] = None):
token_verification(Authorization)
response = self.content_base_indexer.search(
search=request.search.lower(),
threshold=request.threshold,
filter=request.filter
)
return ContentBaseSearchResponse(response=response)
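
For reference, a sketch of calling the three routes with the requests library; host, token, and payload values are placeholders.

import requests

BASE = "http://localhost:8000"                    # placeholder host
HEADERS = {"Authorization": "placeholder-token"}  # checked by token_verification

# PUT /content_base/index: enqueues the Celery task and returns its id.
requests.put(f"{BASE}/content_base/index", headers=HEADERS, json={
    "file": "https://my-bucket.s3.amazonaws.com/uploads/manual.txt",
    "filename": "manual.txt",
    "file_uuid": "hypothetical-file-uuid",
    "extension_file": "txt",
    "task_uuid": "hypothetical-task-uuid",
    "content_base": "my-content-base",
})

# POST /content_base/search: the handler lowercases the query first.
requests.post(f"{BASE}/content_base/search", headers=HEADERS, json={
    "search": "How do I reset my password?",
    "threshold": 1.5,
})

# DELETE /content_base/delete: removes a previously indexed file.
requests.delete(f"{BASE}/content_base/delete", headers=HEADERS, json={
    "filename": "manual.txt",
    "content_base": "my-content-base",
    "file_uuid": "hypothetical-file-uuid",
})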
25 changes: 25 additions & 0 deletions app/handlers/nexus.py
@@ -0,0 +1,25 @@

import os
import json
import requests


class NexusRESTClient:
token = os.environ.get("NEXUS_AI_TOKEN")
base_url = os.environ.get("NEXUS_AI_URL")

def __init__(self) -> None:
self.headers = {
'Authorization': self.token,
'Content-Type': "application/json"
}

def index_succedded(self, task_succeded: bool, nexus_task_uuid: str, file_type: str) -> None:
endpoint = f'{self.base_url}/api/v1/content-base-file'
data = {
"status": int(task_succeded),
"task_uuid": nexus_task_uuid,
"file_type": "text" if file_type == "txt" else "file",
}
response = requests.patch(url=endpoint, data=json.dumps(data), headers=self.headers)
response.raise_for_status()
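
A sketch of the callback in isolation. The env values are placeholders and must be set before the module is imported, because the class reads them at definition time.

import os

os.environ["NEXUS_AI_TOKEN"] = "Bearer placeholder"
os.environ["NEXUS_AI_URL"] = "http://localhost:9000"

from app.handlers.nexus import NexusRESTClient

# Sends PATCH {base_url}/api/v1/content-base-file with the JSON body
# {"status": 1, "task_uuid": "...", "file_type": "text"}.
NexusRESTClient().index_succedded(
    task_succeded=True,  # serialized as status=1
    nexus_task_uuid="hypothetical-task-uuid",
    file_type="txt",  # "txt" maps to "text"; any other extension maps to "file"
)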