diff --git a/core/config/__init__.py b/core/config/__init__.py
index 90ac3c6a0..1c0224f38 100644
--- a/core/config/__init__.py
+++ b/core/config/__init__.py
@@ -70,7 +70,8 @@ class LLMProvider(str, Enum):
     GROQ = "groq"
     LM_STUDIO = "lm-studio"
     AZURE = "azure"
+    ONEMINAI = "1min-ai"
 
 
 class UIAdapter(str, Enum):
     """
diff --git a/core/llm/base.py b/core/llm/base.py
index 1c1143ffa..8d53c7e38 100644
--- a/core/llm/base.py
+++ b/core/llm/base.py
@@ -333,6 +333,7 @@ def for_provider(provider: LLMProvider) -> type["BaseLLMClient"]:
     from .anthropic_client import AnthropicClient
     from .azure_client import AzureClient
     from .groq_client import GroqClient
+    from .onemin_ai_client import OneMinAIClient
     from .openai_client import OpenAIClient
 
     if provider == LLMProvider.OPENAI:
@@ -343,6 +344,8 @@ def for_provider(provider: LLMProvider) -> type["BaseLLMClient"]:
         return GroqClient
     elif provider == LLMProvider.AZURE:
         return AzureClient
+    elif provider == LLMProvider.ONEMINAI:
+        return OneMinAIClient
     else:
         raise ValueError(f"Unsupported LLM provider: {provider.value}")
 
diff --git a/core/llm/onemin_ai_client.py b/core/llm/onemin_ai_client.py
new file mode 100644
index 000000000..2e0cc515f
--- /dev/null
+++ b/core/llm/onemin_ai_client.py
@@ -0,0 +1,88 @@
+import datetime
+import re
+import requests
+from typing import Optional
+from core.config import LLMProvider
+from core.llm.base import BaseLLMClient
+from core.llm.convo import Convo
+from core.log import get_logger
+
+log = get_logger(__name__)
+
+class OneMinAIClient(BaseLLMClient):
+    provider = LLMProvider.ONEMINAI
+    stream_options = {"include_usage": True}
+
+    def _init_client(self):
+        self.headers = {
+            "API-KEY": self.config.api_key,
+            "Content-Type": "application/json",
+        }
+        self.base_url = self.config.base_url
+
+    async def _make_request(
+        self,
+        convo: Convo,
+        temperature: Optional[float] = None,
+        json_mode: bool = False,
+    ) -> tuple[str, int, int]:
+        # Convert the list of message dicts into a single prompt string
+        combined_prompt = " ".join([msg.get("content", "") for msg in convo.messages])
+
+        # Prepare the request body for 1min.ai
+        request_body = {
+            "type": "CHAT_WITH_AI",
+            "conversationId": self.config.extra.get("conversation_id"),
+            "model": self.config.model,
+            "promptObject": {
+                "prompt": combined_prompt,
+                "isMixed": False,
+                "webSearch": False
+            }
+        }
+        # Send the request using the requests library
+        response = requests.post(
+            self.base_url,
+            json=request_body,
+            headers=self.headers,
+            timeout=(self.config.connect_timeout, self.config.read_timeout),
+        )
+
+        # Check if the request was successful
+        if response.status_code != 200:
+            log.error(f"Request failed with status {response.status_code}: {response.text}")
+            response.raise_for_status()
+
+        # Return the raw response body; token usage is not reported, so both counts are zero
+        response_str = response.text
+
+        return response_str, 0, 0
+
+    def rate_limit_sleep(self, err: requests.exceptions.RequestException) -> Optional[datetime.timedelta]:
+        """
+        Rate limit handling logic, adjusted to work with the 1min.ai response format.
+        """
+        headers = err.response.headers
+        if "x-ratelimit-remaining-tokens" not in headers:
+            return None
+
+        remaining_tokens = headers.get("x-ratelimit-remaining-tokens", 0)
+        time_regex = r"(?:(\d+)h)?(?:(\d+)m)?(?:(\d+)s)?"
+
+        if int(remaining_tokens) == 0:
+            match = re.search(time_regex, headers.get("x-ratelimit-reset-tokens", ""))
+        else:
+            match = re.search(time_regex, headers.get("x-ratelimit-reset-requests", ""))
+
+        if match:
+            hours = int(match.group(1)) if match.group(1) else 0
+            minutes = int(match.group(2)) if match.group(2) else 0
+            seconds = int(match.group(3)) if match.group(3) else 0
+            total_seconds = hours * 3600 + minutes * 60 + seconds
+        else:
+            total_seconds = 5
+
+        return datetime.timedelta(seconds=total_seconds)
+
+
+__all__ = ["OneMinAIClient"]
\ No newline at end of file
diff --git a/example-config.json b/example-config.json
index b6c00729f..75d582e2d 100644
--- a/example-config.json
+++ b/example-config.json
@@ -27,6 +27,16 @@
                 "azure_deployment": "your-azure-deployment-id",
                 "api_version": "2024-02-01"
             }
+        },
+        // Example config for 1min.ai (https://gleaming-wren-2af.notion.site/1min-AI-API-Docs-111af080bd8f8046a4e6e1053c95e047#111af080bd8f8027be32e1bbb5957921)
+        "1min-ai": {
+            "base_url": "https://api.1min.ai/api/features?isStreaming=true",
+            "api_key": "your-api-key",
+            "connect_timeout": 60.0,
+            "read_timeout": 20.0,
+            "extra": {
+                "conversation_id": null // Leave empty to start a new conversation
+            }
         }
     },
     // Each agent can use a different model or configuration. The default, as before, is GPT4 Turbo
diff --git a/tests/llm/test_one_min_ai.py b/tests/llm/test_one_min_ai.py
new file mode 100644
index 000000000..af2a5f257
--- /dev/null
+++ b/tests/llm/test_one_min_ai.py
@@ -0,0 +1,137 @@
+from unittest.mock import MagicMock, patch
+import pytest
+from core.config import LLMConfig
+from core.llm.convo import Convo
+from core.llm.onemin_ai_client import OneMinAIClient
+from requests.exceptions import HTTPError
+
+
+@pytest.mark.asyncio
+@patch("requests.post")  # Mock `requests.post` instead of `AsyncOpenAI`
+async def test_oneminai_calls_model(mock_post):
+    cfg = LLMConfig(model="1minai-model")
+    convo = Convo("system hello").user("user hello")
+
+    # Mock the return value of `requests.post`
+    mock_post.return_value.status_code = 200
+    mock_post.return_value.text = "helloworld"  # Simulate plain text response
+
+    llm = OneMinAIClient(cfg)
+    response, _, _ = await llm._make_request(convo)
+
+    assert response == "helloworld"
+
+    mock_post.assert_called_once_with(
+        cfg.base_url,
+        json={
+            "type": "CHAT_WITH_AI",
+            "conversationId": cfg.extra.get("conversation_id"),
+            "model": cfg.model,
+            "promptObject": {
+                "prompt": "system hello user hello",  # Combined messages
+                "isMixed": False,
+                "webSearch": False
+            }
+        },
+        headers={"API-KEY": cfg.api_key, "Content-Type": "application/json"},
+        timeout=(cfg.connect_timeout, cfg.read_timeout),
+    )
+
+
+@pytest.mark.asyncio
+@patch("requests.post")
+async def test_oneminai_error_handler(mock_post):
+    cfg = LLMConfig(model="1minai-model")
+    convo = Convo("system hello").user("user hello")
+
+    # Simulate a failed request
+    mock_post.return_value.status_code = 500
+    mock_post.return_value.text = "Internal Server Error"
+    mock_post.return_value.raise_for_status.side_effect = HTTPError("500 Server Error")
+
+    llm = OneMinAIClient(cfg)
+
+    with pytest.raises(HTTPError):
+        await llm._make_request(convo)
+
+
+@pytest.mark.asyncio
+@patch("requests.post")
+async def test_oneminai_retry_logic(mock_post):
+    cfg = LLMConfig(model="1minai-model")
+    convo = Convo("system hello").user("user hello")
+
+    # Simulate failure on the first attempt and success on the second
+    error_response = MagicMock(status_code=500, text="Error")  # First call fails
+    error_response.raise_for_status.side_effect = HTTPError("500 Server Error")
+    mock_post.side_effect = [
+        error_response,
+        MagicMock(status_code=200, text="Hello"),  # Second call succeeds
+    ]
+
+    llm = OneMinAIClient(cfg)
+    with pytest.raises(HTTPError):
+        await llm._make_request(convo)
+    response, _, _ = await llm._make_request(convo)
+
+    assert response == "Hello"
+    assert mock_post.call_count == 2
+
+
+@pytest.mark.parametrize(
+    ("remaining_tokens", "reset_tokens", "reset_requests", "expected"),
+    [
+        (0, "1h1m1s", "", 3661),
+        (0, "1h1s", "", 3601),
+        (0, "1m", "", 60),
+        (0, "", "1h1m1s", 0),
+        (1, "", "1h1m1s", 3661),
+    ],
+)
+@patch("requests.post")
+def test_oneminai_rate_limit_parser(mock_post, remaining_tokens, reset_tokens, reset_requests, expected):
+    headers = {
+        "x-ratelimit-remaining-tokens": remaining_tokens,
+        "x-ratelimit-reset-tokens": reset_tokens,
+        "x-ratelimit-reset-requests": reset_requests,
+    }
+    err = MagicMock(response=MagicMock(headers=headers))
+
+    llm = OneMinAIClient(LLMConfig(model="1minai-model"))
+    assert int(llm.rate_limit_sleep(err).total_seconds()) == expected
+
+
+@pytest.mark.asyncio
+@patch("requests.post")
+async def test_oneminai_response_success(mock_post):
+    cfg = LLMConfig(model="1minai-model")
+    convo = Convo("system hello").user("user hello")
+
+    # Simulate a successful response
+    mock_post.return_value.status_code = 200
+    mock_post.return_value.text = "Success"
+
+    llm = OneMinAIClient(cfg)
+    response, _, _ = await llm._make_request(convo)
+
+    assert response == "Success"
+    mock_post.assert_called_once()
+
+
+@pytest.mark.asyncio
+@patch("requests.post")
+async def test_oneminai_handle_non_200_response(mock_post):
+    cfg = LLMConfig(model="1minai-model")
+    convo = Convo("system hello").user("user hello")
+
+    # Simulate a non-200 response
+    mock_post.return_value.status_code = 400
+    mock_post.return_value.text = "Bad Request"
+    mock_post.return_value.raise_for_status.side_effect = HTTPError("400 Bad Request")
+
+    llm = OneMinAIClient(cfg)
+
+    with pytest.raises(HTTPError):
+        await llm._make_request(convo)
+
+    mock_post.assert_called_once()
\ No newline at end of file
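
A minimal usage sketch of the wiring added above, for reference; it is not part of the patch. It assumes `for_provider` is exposed as `BaseLLMClient.for_provider` (the hunk does not show whether it is a staticmethod or a module-level helper), assumes `LLMConfig` accepts the same fields shown in the "1min-ai" block of example-config.json, reuses the placeholder model name from the tests, and calls the private `_make_request` directly, as the tests do, rather than the normal calling path.

    import asyncio

    from core.config import LLMConfig, LLMProvider
    from core.llm.base import BaseLLMClient
    from core.llm.convo import Convo

    async def main():
        # Resolve the client class registered for the new provider.
        client_class = BaseLLMClient.for_provider(LLMProvider.ONEMINAI)

        # Build a config mirroring the "1min-ai" block in example-config.json
        # (model name and API key are placeholders).
        config = LLMConfig(
            model="1minai-model",
            base_url="https://api.1min.ai/api/features?isStreaming=true",
            api_key="your-api-key",
            extra={"conversation_id": None},
        )
        client = client_class(config)

        convo = Convo("You are a helpful assistant.").user("Hello!")
        response, prompt_tokens, completion_tokens = await client._make_request(convo)
        print(response)

    asyncio.run(main())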