Skip to content

Commit

Permalink
fix skipped tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Sep 18, 2024
1 parent 3886dc9 commit 1f4d0ff
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions wren-ai-service/tests/pytest/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import os
from dataclasses import asdict

import pytest
from pytest_mock import MockFixture

import src.utils as utils
from src.globals import ServiceMetadata, create_service_metadata


def _mock(mocker: MockFixture) -> tuple:
Expand All @@ -30,28 +32,28 @@ def _mock(mocker: MockFixture) -> tuple:
return llm_provider, embedder_provider


@pytest.mark.skip(reason="TODO")
def test_service_metadata(mocker: MockFixture):
@pytest.fixture
def service_metadata(mocker: MockFixture):
current_path = os.path.dirname(__file__)
utils.service_metadata(

return create_service_metadata(
*_mock(mocker),
pyproject_path=os.path.join(current_path, "../data/mock_pyproject.toml"),
)

assert utils.MODELS_METADATA == {

def test_service_metadata(service_metadata: ServiceMetadata):
assert service_metadata.models_metadata == {
"generation_model": "mock-llm-model",
"generation_model_kwargs": {},
"embedding_model": "mock-embedding-model",
"embedding_model_dim": 768,
}

assert utils.SERVICE_VERSION == "0.8.0-mock"
assert service_metadata.service_version == "0.8.0-mock"


@pytest.mark.skip(reason="TODO")
def test_trace_metadata(mocker: MockFixture):
metadata = mocker.patch("src.utils.MODELS_METADATA", {})
version = mocker.patch("src.utils.SERVICE_VERSION", "0.8.0-mock")
def test_trace_metadata(service_metadata: ServiceMetadata, mocker: MockFixture):
function = mocker.patch(
"src.utils.langfuse_context.update_current_trace", return_value=None
)
Expand All @@ -63,18 +65,18 @@ class Request:
user_id = "mock-user-id"

@utils.trace_metadata
async def my_function(_: str, b: Request):
async def my_function(_: str, b: Request, **kwargs):
return "Hello, World!"

asyncio.run(my_function("", Request()))
asyncio.run(my_function("", Request(), service_metadata=asdict(service_metadata)))

function.assert_called_once_with(
user_id="mock-user-id",
session_id="mock-thread-id",
release=version,
release=service_metadata.service_version,
metadata={
"mdl_hash": "mock-mdl-hash",
"project_id": "mock-project-id",
**metadata,
**service_metadata.models_metadata,
},
)

0 comments on commit 1f4d0ff

Please sign in to comment.