diff --git a/byaldi/RAGModel.py b/byaldi/RAGModel.py
index 0b87bd8..89c232a 100644
--- a/byaldi/RAGModel.py
+++ b/byaldi/RAGModel.py
@@ -1,11 +1,18 @@
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from PIL import Image
 
 from byaldi.colpali import ColPaliModel
 from byaldi.objects import Result
 
+# Optional LangChain integration: the name stays bound (to None) when the
+# extra is missing so as_langchain_retriever() can fail with a clear message.
+try:
+    from byaldi.integrations import ByaldiLangChainRetriever
+except ImportError:
+    ByaldiLangChainRetriever = None
+
 
 class RAGMultiModalModel:
     """
@@ -50,7 +57,10 @@ def from_pretrained(
         """
         instance = cls()
         instance.model = ColPaliModel.from_pretrained(
-            pretrained_model_name_or_path, index_root=index_root, device=device, verbose=verbose
+            pretrained_model_name_or_path,
+            index_root=index_root,
+            device=device,
+            verbose=verbose,
         )
         return instance
 
@@ -165,3 +175,22 @@ def search(
 
     def get_doc_ids_to_file_names(self):
         return self.model.get_doc_ids_to_file_names()
+
+    def as_langchain_retriever(self, **kwargs: Any):
+        """Wrap this model in a LangChain retriever.
+
+        Args:
+            **kwargs: Keyword arguments forwarded to ``search`` on every query.
+
+        Returns:
+            A ``ByaldiLangChainRetriever`` bound to this model.
+
+        Raises:
+            ImportError: If the optional LangChain dependency is not installed.
+        """
+        if ByaldiLangChainRetriever is None:
+            raise ImportError(
+                "LangChain support is not available; install `langchain-core` "
+                "(the `langchain` extra) to use as_langchain_retriever()."
+            )
+        return ByaldiLangChainRetriever(model=self, kwargs=kwargs)
diff --git a/byaldi/integrations/__init__.py b/byaldi/integrations/__init__.py
new file mode 100644
index 0000000..5841288
--- /dev/null
+++ b/byaldi/integrations/__init__.py
@@ -0,0 +1,12 @@
+"""Optional third-party integrations; each is exported only if importable."""
+
+__all__ = []
+
+try:
+    from byaldi.integrations._langchain import ByaldiLangChainRetriever
+except ImportError:
+    # langchain-core is an optional extra; without it the retriever is simply
+    # not exported from this package.
+    pass
+else:
+    __all__.append("ByaldiLangChainRetriever")
diff --git a/byaldi/integrations/_langchain.py b/byaldi/integrations/_langchain.py
new file mode 100644
index 0000000..f07b0be
--- /dev/null
+++ b/byaldi/integrations/_langchain.py
@@ -0,0 +1,35 @@
+"""LangChain retriever adapter for byaldi models."""
+
+from typing import Any, List
+
+from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
+from langchain_core.retrievers import BaseRetriever
+
+from byaldi.objects import Result
+
+
+class ByaldiLangChainRetriever(BaseRetriever):
+    """Retriever that delegates every query to ``model.search``.
+
+    Attributes:
+        model: Object exposing ``search(query, **kwargs)``.
+        kwargs: Extra keyword arguments passed to ``search`` on each query.
+    """
+
+    model: Any
+    kwargs: dict = {}
+
+    def _get_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,  # noqa
+    ) -> List[Result]:
+        """Return the results of ``self.model.search`` for ``query``.
+
+        ``run_manager`` is accepted to match the retriever hook signature but
+        is not used.
+        """
+        # NOTE(review): results are returned as byaldi ``Result`` objects, not
+        # LangChain ``Document``s -- confirm downstream chains accept that.
+        return self.model.search(query, **self.kwargs)
diff --git a/pyproject.toml b/pyproject.toml
index df19ff4..287157f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,6 +41,7 @@ dependencies = [
 [project.optional-dependencies]
 dev = ["pytest>=7.4.0", "ruff>=0.1.9"]
 server = ["uvicorn", "fastapi"]
+langchain = ["langchain-core"]
 
 [project.urls]
 "Homepage" = "https://github.com/answerdotai/byaldi"