Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add pdf support #105

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,5 @@ tiktoken
tqdm~=4.65.0
types-requests==0.1.13
typing-inspect==0.8.0
typing_extensions==4.5.0
typing_extensions==4.5.0
PyPDF2~=3.0.1
19 changes: 18 additions & 1 deletion backend/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from sqlalchemy import create_engine
from PIL import Image
from loguru import logger
from PyPDF2 import PdfReader

from real_agents.adapters.data_model import (
DatabaseDataModel,
DataModel,
ImageDataModel,
TableDataModel,
KaggleDataModel,
DocumentDataModel,
)
from real_agents.data_agent import (
DataSummaryExecutor,
Expand All @@ -32,7 +34,8 @@
DOCUMENT_EXTENSIONS = {"pdf", "doc", "docx", "txt"}
DATABASE_EXTENSIONS = {"sqlite", "db"}
IMAGE_EXTENSIONS = {"jpg", "png", "jpeg"}
ALLOW_EXTENSIONS = TABLE_EXTENSIONS | DOCUMENT_EXTENSIONS | DATABASE_EXTENSIONS | IMAGE_EXTENSIONS
PDF_EXTENSIONS = {"pdf"}
ALLOW_EXTENSIONS = TABLE_EXTENSIONS | DOCUMENT_EXTENSIONS | DATABASE_EXTENSIONS | IMAGE_EXTENSIONS | PDF_EXTENSIONS

LOCAL = "local"
REDIS = "redis"
Expand Down Expand Up @@ -127,6 +130,18 @@ def load_grounding_source(file_path: str) -> Any:
"size": img.size,
"mode": img.mode,
}
elif suffix == ".pdf":
brut_doc = PdfReader(file_path)
grounding_source = {
"plain_text": "".join(f'-PAGE_{str(i)}-{page.extract_text()}' for i, page in enumerate(brut_doc.pages)),
"num_pages": len(brut_doc.pages),
"metadata": {
'author': brut_doc.metadata.author,
'year': brut_doc.metadata.creation_date.year,
'title': brut_doc.metadata.title,
'subject': brut_doc.metadata.subject,
}
}
else:
raise ValueError("File type not allowed to be set as grounding source")
return grounding_source
Expand All @@ -146,6 +161,8 @@ def get_data_model_cls(file_path: str) -> DataModel:
data_model_cls = DatabaseDataModel
elif suffix == ".jpeg" or suffix == ".png" or suffix == ".jpg":
data_model_cls = ImageDataModel
elif suffix == ".pdf":
data_model_cls = DocumentDataModel
else:
raise ValueError("File type not allowed to be set as grounding source")
return data_model_cls
Expand Down
2 changes: 1 addition & 1 deletion frontend/components/Chatbar/components/ChatbarSettings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ export const ChatbarSettings = <T,>({
className="sr-only"
tabIndex={-1}
type="file"
accept=".csv, .tsv, .xslx, .db, .sqlite, .png, .jpg, .jpeg"
accept=".csv, .tsv, .xlsx, .db, .sqlite, .png, .jpg, .jpeg, .pdf"
ref={fileInputRef}
onChange={handleUpload}
/>
Expand Down
1 change: 1 addition & 0 deletions real_agents/adapters/data_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from real_agents.adapters.data_model.kaggle import KaggleDataModel
from real_agents.adapters.data_model.plugin import APIYamlModel, SpecModel
from real_agents.adapters.data_model.table import TableDataModel
from real_agents.adapters.data_model.document import DocumentDataModel

__all__ = [
"DataModel",
Expand Down
17 changes: 17 additions & 0 deletions real_agents/adapters/data_model/document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Any, List

from real_agents.adapters.data_model.base import DataModel


class DocumentDataModel(DataModel):
    """A data model for a document (can contain text, images, tables, other data)."""

    def get_raw_data(self) -> Any:
        """Return the underlying document payload unchanged."""
        return self.raw_data

    def get_llm_side_data(self,
                          max_tokens: int = 2000,
                          chunk_size: int = 1000,
                          chunk_overlap: int = 200
                          ) -> Any:
        """Return a truncated plain-text view of the document for the LLM.

        NOTE(review): despite its name, ``max_tokens`` truncates by
        *characters*, not tokens, and ``chunk_size``/``chunk_overlap`` are
        currently unused — confirm intended behavior before relying on them.
        """
        plain_text = self.raw_data['plain_text']
        return plain_text[:max_tokens]
100 changes: 99 additions & 1 deletion real_agents/data_agent/executors/data_summary_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from langchain import PromptTemplate

from real_agents.adapters.callbacks.executor_streaming import ExecutorStreamingChainHandler
from real_agents.adapters.data_model import DatabaseDataModel, TableDataModel, ImageDataModel
from real_agents.adapters.data_model import (
DatabaseDataModel,
TableDataModel,
ImageDataModel,
DocumentDataModel
)
from real_agents.adapters.llm import LLMChain


Expand Down Expand Up @@ -195,3 +200,96 @@ def _parse_output(self, content: str) -> Tuple[str, str]:
bullet_points.append(f"{bullet_point_id}. " + line[1:].strip().strip('"'))
bullet_point_id += 1
return table_summary, "\n".join(bullet_points)


class DocumentSummaryExecutor(DataSummaryExecutor):
    """Summarize an uploaded document and suggest follow-up processing steps.

    Builds a short basic summary from the document metadata and, optionally,
    an LLM-generated summary with bullet-point suggestions, streaming the
    assembled text token by token through ``stream_handler``.
    """

    # NOTE(review): the placeholder is named ``img_info`` (template adapted
    # from the image executor) but receives document text — the name is baked
    # into the template string, so renaming would change runtime behavior.
    SUMMARY_PROMPT_TEMPLATE = """
{img_info}

Provide a succinct summary of the uploaded file with less than 20 words. Please ensure your summary is a complete sentence and include it within <summary></summary> tags."
Then provide {num_insights} very simple and basic suggestions in natural language about further processing with the data. The final results should be markdown '+' bullet point list, e.g., + The first suggestion."

Begin.
"""
    stream_handler = ExecutorStreamingChainHandler()

    def run(
        self,
        grounding_source: DocumentDataModel,
        llm: BaseLanguageModel,
        use_intelligent_summary: bool = True,
        num_insights: int = 3,
    ) -> str:
        """Build and stream a summary of the document grounding source.

        Args:
            grounding_source: Parsed document; must be a ``DocumentDataModel``.
            llm: Language model used for the intelligent summary.
            use_intelligent_summary: When True, append LLM-generated suggestions.
            num_insights: Number of suggestions to request from the LLM.

        Returns:
            The assembled summary string. (The original annotation claimed
            ``Dict[str, Any]``, but a plain ``str`` is what is returned.)

        Raises:
            ValueError: If ``grounding_source`` is not a ``DocumentDataModel``.
        """
        summary = ""
        if isinstance(grounding_source, DocumentDataModel):
            # Basic summary from the metadata dict gathered at load time.
            # NOTE(review): metadata values may be None for PDFs lacking
            # metadata, which would render literally as "None" here — confirm
            # against load_grounding_source.
            summary += (
                f"Your document **{grounding_source.raw_data['metadata']['title']}** created by "
                f"{grounding_source.raw_data['metadata']['author']} at "
                f"{grounding_source.raw_data['metadata']['year']} year. \n"
            )

            # Intelligent summary
            if use_intelligent_summary:
                intelligent_summary = self._intelligent_summary(
                    grounding_source,
                    num_insights=num_insights,
                    llm=llm,
                )
                _, suggestions = self._parse_output(intelligent_summary)
                summary += "\n" + suggestions

            # Stream the summary word by word to the client handler.
            for stream_token in summary.split(" "):
                self.stream_handler.on_llm_new_token(stream_token)
        else:
            raise ValueError(f"Unsupported data summary for grounding source type: {type(grounding_source)}")
        return summary

    def _intelligent_summary(self, grounding_source: DocumentDataModel, num_insights: int, llm: BaseLanguageModel) -> str:
        """Use LLM to generate data summary."""
        summary_prompt_template = PromptTemplate(
            input_variables=["img_info", "num_insights"],
            template=self.SUMMARY_PROMPT_TEMPLATE,
        )
        method = LLMChain(llm=llm, prompt=summary_prompt_template)
        result = method.run({"img_info": grounding_source.get_llm_side_data(), "num_insights": num_insights})
        return result

    @staticmethod
    def text_summary(llm: BaseLanguageModel, reduce_template: str) -> Tuple[LLMChain, LLMChain]:
        """Build ``(reduce_chain, map_chain)`` LLM chains from *reduce_template*.

        Bug fix: the original called ``PromptTemplate.from_template(reduce_prompt)``,
        passing a ``PromptTemplate`` object where a template *string* is
        expected (a guaranteed runtime error), and was annotated ``-> str``
        despite returning a tuple of chains.
        """
        reduce_prompt = PromptTemplate.from_template(reduce_template)
        # Both chains are built from the same template string; there is no
        # separate map template parameter in this signature.
        map_prompt = PromptTemplate.from_template(reduce_template)
        map_chain = LLMChain(llm=llm, prompt=map_prompt)
        reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)

        return reduce_chain, map_chain

    def _parse_output(self, content: str) -> Tuple[str, str]:
        """Parse the LLM output into (summary, numbered bullet-point list)."""
        from bs4 import BeautifulSoup

        # Using 'html.parser' to parse the content
        soup = BeautifulSoup(content, "html.parser")
        # Pull the text inside the <summary></summary> tags; fall back to ""
        # if the LLM omitted them.
        try:
            table_summary = soup.find("summary").text
        except Exception:
            import traceback

            traceback.print_exc()
            table_summary = ""

        lines = content.split("\n")
        # Initialize an empty list to hold the parsed bullet points
        bullet_points = []
        # Renumber '+' markdown bullets into an ordered "1. / 2. / ..." list.
        bullet_point_id = 1
        for line in lines:
            # If the line starts with '+', it is a bullet point
            if line.startswith("+"):
                # Remove the '+ ' from the start of the line and add it to the list
                bullet_points.append(f"{bullet_point_id}. " + line[1:].strip().strip('"'))
                bullet_point_id += 1
        return table_summary, "\n".join(bullet_points)