Skip to content

Commit

Permalink
chore: make langchain optional
Browse files Browse the repository at this point in the history
  • Loading branch information
bclavie committed Sep 23, 2024
1 parent bdf05a5 commit cd1ee7c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
18 changes: 13 additions & 5 deletions byaldi/RAGModel.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from langchain_core.retrievers import BaseRetriever
from PIL import Image

from byaldi.colpali import ColPaliModel
from byaldi.integrations import ByaldiLangChainRetriever

from byaldi.objects import Result

# Optional langchain integration
try:
from byaldi.integrations import ByaldiLangChainRetriever
except ImportError:
pass


class RAGMultiModalModel:
"""
Expand Down Expand Up @@ -52,7 +57,10 @@ def from_pretrained(
"""
instance = cls()
instance.model = ColPaliModel.from_pretrained(
pretrained_model_name_or_path, index_root=index_root, device=device, verbose=verbose
pretrained_model_name_or_path,
index_root=index_root,
device=device,
verbose=verbose,
)
return instance

Expand Down Expand Up @@ -168,5 +176,5 @@ def search(
def get_doc_ids_to_file_names(self):
return self.model.get_doc_ids_to_file_names()

def as_langchain_retriever(self, **kwargs: Any) -> BaseRetriever:
return ByaldiLangChainRetriever(model=self, kwargs=kwargs)
def as_langchain_retriever(self, **kwargs: Any):
return ByaldiLangChainRetriever(model=self, kwargs=kwargs)
9 changes: 7 additions & 2 deletions byaldi/integrations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from byaldi.integrations._langchain import ByaldiLangChainRetriever
_all__ = []

__all__ = ["ByaldiLangChainRetriever"]
try:
from byaldi.integrations._langchain import ByaldiLangChainRetriever

_all__.append("ByaldiLangChainRetriever")
except ImportError:
pass
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ dependencies = [
"srsly",
"torch",
"transformers",
"langchain-core",
]

[project.optional-dependencies]
dev = ["pytest>=7.4.0", "ruff>=0.1.9"]
server = ["uvicorn", "fastapi"]
langchain = ["langchain-core"]

[project.urls]
"Homepage" = "https://github.com/answerdotai/byaldi"
Expand Down

0 comments on commit cd1ee7c

Please sign in to comment.