diff --git a/wren-ai-service/tests/pytest/test_utils.py b/wren-ai-service/tests/pytest/test_utils.py index 74929d8c3..2be8222ca 100644 --- a/wren-ai-service/tests/pytest/test_utils.py +++ b/wren-ai-service/tests/pytest/test_utils.py @@ -1,10 +1,12 @@ import asyncio import os +from dataclasses import asdict import pytest from pytest_mock import MockFixture import src.utils as utils +from src.globals import ServiceMetadata, create_service_metadata def _mock(mocker: MockFixture) -> tuple: @@ -30,28 +32,28 @@ def _mock(mocker: MockFixture) -> tuple: return llm_provider, embedder_provider -@pytest.mark.skip(reason="TODO") -def test_service_metadata(mocker: MockFixture): +@pytest.fixture +def service_metadata(mocker: MockFixture): current_path = os.path.dirname(__file__) - utils.service_metadata( + + return create_service_metadata( *_mock(mocker), pyproject_path=os.path.join(current_path, "../data/mock_pyproject.toml"), ) - assert utils.MODELS_METADATA == { + +def test_service_metadata(service_metadata: ServiceMetadata): + assert service_metadata.models_metadata == { "generation_model": "mock-llm-model", "generation_model_kwargs": {}, "embedding_model": "mock-embedding-model", "embedding_model_dim": 768, } - assert utils.SERVICE_VERSION == "0.8.0-mock" + assert service_metadata.service_version == "0.8.0-mock" -@pytest.mark.skip(reason="TODO") -def test_trace_metadata(mocker: MockFixture): - metadata = mocker.patch("src.utils.MODELS_METADATA", {}) - version = mocker.patch("src.utils.SERVICE_VERSION", "0.8.0-mock") +def test_trace_metadata(service_metadata: ServiceMetadata, mocker: MockFixture): function = mocker.patch( "src.utils.langfuse_context.update_current_trace", return_value=None ) @@ -63,18 +65,18 @@ class Request: user_id = "mock-user-id" @utils.trace_metadata - async def my_function(_: str, b: Request): + async def my_function(_: str, b: Request, **kwargs): return "Hello, World!" - asyncio.run(my_function("", Request())) + asyncio.run(my_function("", Request(), service_metadata=asdict(service_metadata))) function.assert_called_once_with( user_id="mock-user-id", session_id="mock-thread-id", - release=version, + release=service_metadata.service_version, metadata={ "mdl_hash": "mock-mdl-hash", "project_id": "mock-project-id", - **metadata, + **service_metadata.models_metadata, }, )