-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
221 additions
and
167 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import re | ||
from typing import List, Tuple, Union | ||
import logging | ||
from application.parser.schema.base import Document | ||
from application.utils import get_encoding | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class Chunker: | ||
def __init__( | ||
self, | ||
chunking_strategy: str = "classic_chunk", | ||
max_tokens: int = 2000, | ||
min_tokens: int = 150, | ||
duplicate_headers: bool = False, | ||
): | ||
if chunking_strategy not in ["classic_chunk"]: | ||
raise ValueError(f"Unsupported chunking strategy: {chunking_strategy}") | ||
self.chunking_strategy = chunking_strategy | ||
self.max_tokens = max_tokens | ||
self.min_tokens = min_tokens | ||
self.duplicate_headers = duplicate_headers | ||
self.encoding = get_encoding() | ||
|
||
def separate_header_and_body(self, text: str) -> Tuple[str, str]: | ||
header_pattern = r"^(.*?\n){3}" | ||
match = re.match(header_pattern, text) | ||
if match: | ||
header = match.group(0) | ||
body = text[len(header):] | ||
else: | ||
header, body = "", text # No header, treat entire text as body | ||
return header, body | ||
|
||
def combine_documents(self, doc: Document, next_doc: Document) -> Document: | ||
combined_text = doc.text + " " + next_doc.text | ||
combined_token_count = len(self.encoding.encode(combined_text)) | ||
new_doc = Document( | ||
text=combined_text, | ||
doc_id=doc.doc_id, | ||
embedding=doc.embedding, | ||
extra_info={**(doc.extra_info or {}), "token_count": combined_token_count} | ||
) | ||
return new_doc | ||
|
||
def split_document(self, doc: Document) -> List[Document]: | ||
split_docs = [] | ||
header, body = self.separate_header_and_body(doc.text) | ||
header_tokens = self.encoding.encode(header) if header else [] | ||
body_tokens = self.encoding.encode(body) | ||
|
||
current_position = 0 | ||
part_index = 0 | ||
while current_position < len(body_tokens): | ||
end_position = current_position + self.max_tokens - len(header_tokens) | ||
chunk_tokens = (header_tokens + body_tokens[current_position:end_position] | ||
if self.duplicate_headers or part_index == 0 else body_tokens[current_position:end_position]) | ||
chunk_text = self.encoding.decode(chunk_tokens) | ||
new_doc = Document( | ||
text=chunk_text, | ||
doc_id=f"{doc.doc_id}-{part_index}", | ||
embedding=doc.embedding, | ||
extra_info={**(doc.extra_info or {}), "token_count": len(chunk_tokens)} | ||
) | ||
split_docs.append(new_doc) | ||
current_position = end_position | ||
part_index += 1 | ||
header_tokens = [] | ||
return split_docs | ||
|
||
def classic_chunk(self, documents: List[Document]) -> List[Document]: | ||
processed_docs = [] | ||
i = 0 | ||
while i < len(documents): | ||
doc = documents[i] | ||
tokens = self.encoding.encode(doc.text) | ||
token_count = len(tokens) | ||
|
||
if self.min_tokens <= token_count <= self.max_tokens: | ||
doc.extra_info = doc.extra_info or {} | ||
doc.extra_info["token_count"] = token_count | ||
processed_docs.append(doc) | ||
i += 1 | ||
elif token_count < self.min_tokens: | ||
if i + 1 < len(documents): | ||
next_doc = documents[i + 1] | ||
next_tokens = self.encoding.encode(next_doc.text) | ||
if token_count + len(next_tokens) <= self.max_tokens: | ||
# Combine small documents | ||
combined_doc = self.combine_documents(doc, next_doc) | ||
processed_docs.append(combined_doc) | ||
i += 2 | ||
else: | ||
# Keep the small document as is if adding next_doc would exceed max_tokens | ||
doc.extra_info = doc.extra_info or {} | ||
doc.extra_info["token_count"] = token_count | ||
processed_docs.append(doc) | ||
i += 1 | ||
else: | ||
# No next document to combine with; add the small document as is | ||
doc.extra_info = doc.extra_info or {} | ||
doc.extra_info["token_count"] = token_count | ||
processed_docs.append(doc) | ||
i += 1 | ||
else: | ||
# Split large documents | ||
processed_docs.extend(self.split_document(doc)) | ||
i += 1 | ||
return processed_docs | ||
|
||
def chunk( | ||
self, | ||
documents: List[Document] | ||
) -> List[Document]: | ||
if self.chunking_strategy == "classic_chunk": | ||
return self.classic_chunk(documents) | ||
else: | ||
raise ValueError("Unsupported chunking strategy") | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import os | ||
import logging | ||
from retry import retry | ||
from tqdm import tqdm | ||
from application.core.settings import settings | ||
from application.vectorstore.vector_creator import VectorCreator | ||
|
||
|
||
@retry(tries=10, delay=60) | ||
def add_text_to_store_with_retry(store, doc, source_id): | ||
""" | ||
Add a document's text and metadata to the vector store with retry logic. | ||
Args: | ||
store: The vector store object. | ||
doc: The document to be added. | ||
source_id: Unique identifier for the source. | ||
""" | ||
try: | ||
doc.metadata["source_id"] = str(source_id) | ||
store.add_texts([doc.page_content], metadatas=[doc.metadata]) | ||
except Exception as e: | ||
logging.error(f"Failed to add document with retry: {e}") | ||
raise | ||
|
||
|
||
def embed_and_store_documents(docs, folder_name, source_id, task_status): | ||
""" | ||
Embeds documents and stores them in a vector store. | ||
Args: | ||
docs (list): List of documents to be embedded and stored. | ||
folder_name (str): Directory to save the vector store. | ||
source_id (str): Unique identifier for the source. | ||
task_status: Task state manager for progress updates. | ||
Returns: | ||
None | ||
""" | ||
# Ensure the folder exists | ||
if not os.path.exists(folder_name): | ||
os.makedirs(folder_name) | ||
|
||
# Initialize vector store | ||
if settings.VECTOR_STORE == "faiss": | ||
docs_init = [docs.pop(0)] | ||
store = VectorCreator.create_vectorstore( | ||
settings.VECTOR_STORE, | ||
docs_init=docs_init, | ||
source_id=folder_name, | ||
embeddings_key=os.getenv("EMBEDDINGS_KEY"), | ||
) | ||
else: | ||
store = VectorCreator.create_vectorstore( | ||
settings.VECTOR_STORE, | ||
source_id=source_id, | ||
embeddings_key=os.getenv("EMBEDDINGS_KEY"), | ||
) | ||
store.delete_index() | ||
|
||
total_docs = len(docs) | ||
|
||
# Process and embed documents | ||
for idx, doc in tqdm( | ||
docs, | ||
desc="Embedding 🦖", | ||
unit="docs", | ||
total=total_docs, | ||
bar_format="{l_bar}{bar}| Time Left: {remaining}", | ||
): | ||
try: | ||
# Update task status for progress tracking | ||
progress = int((idx / total_docs) * 100) | ||
task_status.update_state(state="PROGRESS", meta={"current": progress}) | ||
|
||
# Add document to vector store | ||
add_text_to_store_with_retry(store, doc, source_id) | ||
except Exception as e: | ||
logging.error(f"Error embedding document {idx}: {e}") | ||
logging.info(f"Saving progress at document {idx} out of {total_docs}") | ||
store.save_local(folder_name) | ||
break | ||
|
||
# Save the vector store | ||
if settings.VECTOR_STORE == "faiss": | ||
store.save_local(folder_name) | ||
logging.info("Vector store saved successfully.") | ||
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.