Skip to content

Commit

Permalink
chore(wren-ai-service): sql2answer-minor-update (#691)
Browse files Browse the repository at this point in the history
* allow auto-refresh

* update var name

* leave comment for vertex ai model support
  • Loading branch information
cyyeh authored Oct 2, 2024
1 parent f8f9ec6 commit 6104776
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 3 deletions.
106 changes: 105 additions & 1 deletion wren-ai-service/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions wren-ai-service/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ sqlglot = "==25.18.0"
cachetools = "==5.5.0"
pyyaml = "==6.0.2"
pydantic-settings = "==2.5.2"
google-auth = "==2.35.0"

[tool.poetry.group.dev.dependencies]
pre-commit = "==3.7.1"
Expand Down
29 changes: 27 additions & 2 deletions wren-ai-service/src/providers/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import Any, Callable, Dict, List, Optional, Union

import backoff
import google.auth
import google.auth.transport.requests
import openai
import orjson
from haystack import component
Expand Down Expand Up @@ -57,6 +59,29 @@ def __init__(
base_url=api_base_url,
)

# check if the model is actually Vertex AI model
# currently we support Vertex AI through openai api compatible way
# in the near future, we might use litellm instead, so we can more easily support different kinds of llm providers
# this is workaround as of now
self._vertexai_creds = None
if model.startswith("google/"):
self._vertexai_creds, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)

def __getattr__(self, name: str) -> Any:
# dealing with auto-refreshing expired credential for Vertex AI model
# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-vertex-using-openai-library#refresh_your_credentials
if self._vertexai_creds and not self._vertexai_creds.valid:
auth_req = google.auth.transport.requests.Request()
self._vertexai_creds.refresh(auth_req)

if not self._vertexai_creds.valid:
raise RuntimeError("Unable to refresh auth")

self.client.api_key = self._vertexai_creds.token
return getattr(self.client, name)

@component.output_types(replies=List[str], meta=List[Dict[str, Any]])
@backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=60, max_tries=3)
async def run(
Expand Down Expand Up @@ -151,11 +176,11 @@ def get_generator(
):
if self._api_base == LLM_OPENAI_API_BASE:
logger.info(
f"Creating OpenAI generator with model kwargs: {self._model_kwargs}"
f"Creating OpenAI generator {self._generation_model} with model kwargs: {self._model_kwargs}"
)
else:
logger.info(
f"Creating OpenAI API-compatible generator with model kwargs: {self._model_kwargs}"
f"Creating OpenAI API-compatible generator {self._generation_model} with model kwargs: {self._model_kwargs}"
)

return AsyncGenerator(
Expand Down

0 comments on commit 6104776

Please sign in to comment.