Skip to content

Commit

Permalink
start adding azure
Browse files Browse the repository at this point in the history
  • Loading branch information
vemonet committed Nov 4, 2024
1 parent 2d06ba2 commit c5d421f
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
matrix:
# os: ["ubuntu-latest", "windows-latest", "macos-latest"]
os: ["ubuntu-latest"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ chat = [
"pydantic >=2.0.0",
"pydantic-settings",
"openai",
"azure-ai-inference",
"qdrant_client",
"fastembed",
]
Expand Down
1 change: 1 addition & 0 deletions src/sparql_llm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ async def chat(request: ChatCompletionRequest):
raise ValueError("Invalid API key")

client = get_llm_client(request.model)
# print(client.models.list())

question: str = request.messages[-1].content if request.messages else ""

Expand Down
54 changes: 53 additions & 1 deletion src/sparql_llm/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,63 @@
from openai import OpenAI
import os
from openai import OpenAI, AzureOpenAI
from pydantic_settings import BaseSettings, SettingsConfigDict

from sparql_llm.utils import get_prefixes_for_endpoints

# import warnings
# warnings.simplefilter(action="ignore", category=UserWarning)

# https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-api?tabs=python
# https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/ai/azure-ai-inference
# Or use LlamaIndex: https://docs.llamaindex.ai/en/stable/examples/llm/azure_inference/
# Langchain does not seem to have support for Azure inference yet https://github.com/langchain-ai/langchain-azure/tree/main/libs

# from azure.ai.inference import ChatCompletionsClient
# from azure.core.credentials import AzureKeyCredential

# api_key = os.getenv("AZURE_INFERENCE_CREDENTIAL", '')
# if not api_key:
# raise Exception("A key should be provided to invoke the endpoint")

# client = ChatCompletionsClient(
# endpoint='https://mistral-large-2407-kru.swedencentral.models.ai.azure.com',
# credential=AzureKeyCredential(api_key)
# )

# model_info = client.get_model_info()
# print("Model name:", model_info.model_name)
# print("Model type:", model_info.model_type)
# print("Model provider name:", model_info.model_provider_name)

# payload = {
# "messages": [
# {
# "role": "user",
# "content": "I am going to Paris, what should I see?"
# },
# {
# "role": "assistant",
# "content": "Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:\n\n1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.\n2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.\n3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.\n\nThese are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world."
# },
# {
# "role": "user",
# "content": "What is so great about #1?"
# }
# ],
# "max_tokens": 2048,
# "temperature": 0.8,
# "top_p": 0.1
# }
# response = client.complete(payload)

# print("Response:", response.choices[0].message.content)
# print("Model:", response.model)
# print("Usage:")
# print(" Prompt tokens:", response.usage.prompt_tokens)
# print(" Total tokens:", response.usage.total_tokens)
# print(" Completion tokens:", response.usage.completion_tokens)



def get_llm_client(model: str) -> OpenAI:
if model.startswith("hf:"):
Expand Down

0 comments on commit c5d421f

Please sign in to comment.