Skip to content

Commit

Permalink
together: fix chat model and embedding classes (langchain-ai#21353)
Browse files — browse the repository at this point in the history
  • Branch information: efriis authored May 7, 2024
1 parent d6ef5fe commit bb81ae5
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
2 changes: 1 addition & 1 deletion libs/partners/together/langchain_together/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _llm_type(self) -> str:
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env are `TOGETHER_API_KEY` if not provided."""
together_api_base: Optional[str] = Field(
default="https://api.together.ai/v1/chat/completions", alias="base_url"
default="https://api.together.ai/v1/", alias="base_url"
)

@root_validator()
Expand Down
24 changes: 14 additions & 10 deletions libs/partners/together/langchain_together/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model: str = "togethercomputer/m2-bert-80M-8k-retrieval"
"""Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
"""Embeddings model name to use.
Instead, use 'togethercomputer/m2-bert-80M-8k-retrieval' for example.
"""
dimensions: Optional[int] = None
Expand All @@ -62,7 +62,7 @@ class TogetherEmbeddings(BaseModel, Embeddings):
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""API Key for Solar API."""
together_api_base: str = Field(
default="https://api.together.ai/v1/embeddings", alias="base_url"
default="https://api.together.ai/v1/", alias="base_url"
)
"""Endpoint URL to use."""
embedding_ctx_length: int = 4096
Expand Down Expand Up @@ -166,21 +166,25 @@ def validate_environment(cls, values: Dict) -> Dict:
"default_query": values["default_query"],
}
if not values.get("client"):
sync_specific = {"http_client": values["http_client"]}
sync_specific = (
{"http_client": values["http_client"]} if values["http_client"] else {}
)
values["client"] = openai.OpenAI(
**client_params, **sync_specific
).embeddings
if not values.get("async_client"):
async_specific = {"http_client": values["http_async_client"]}
async_specific = (
{"http_client": values["http_async_client"]}
if values["http_async_client"]
else {}
)
values["async_client"] = openai.AsyncOpenAI(
**client_params, **async_specific
).embeddings
return values

@property
def _invocation_params(self) -> Dict[str, Any]:
self.model = self.model.replace("-query", "").replace("-passage", "")

params: Dict = {"model": self.model, **self.model_kwargs}
if self.dimensions is not None:
params["dimensions"] = self.dimensions
Expand All @@ -197,7 +201,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
embeddings = []
params = self._invocation_params
params["model"] = params["model"] + "-passage"
params["model"] = params["model"]

for text in texts:
response = self.client.create(input=text, **params)
Expand All @@ -217,7 +221,7 @@ def embed_query(self, text: str) -> List[float]:
Embedding for the text.
"""
params = self._invocation_params
params["model"] = params["model"] + "-query"
params["model"] = params["model"]

response = self.client.create(input=text, **params)

Expand All @@ -236,7 +240,7 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""
embeddings = []
params = self._invocation_params
params["model"] = params["model"] + "-passage"
params["model"] = params["model"]

for text in texts:
response = await self.async_client.create(input=text, **params)
Expand All @@ -256,7 +260,7 @@ async def aembed_query(self, text: str) -> List[float]:
Embedding for the text.
"""
params = self._invocation_params
params["model"] = params["model"] + "-query"
params["model"] = params["model"]

response = await self.async_client.create(input=text, **params)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ def chat_model_class(self) -> Type[BaseChatModel]:
@pytest.fixture
def chat_model_params(self) -> dict:
return {
"model": "meta-llama/Llama-3-8b-chat-hf",
"model": "mistralai/Mistral-7B-Instruct-v0.1",
}

0 comments on commit bb81ae5

Please sign in to comment.