Skip to content

Commit

Permalink
chore: add mock pipe for testing and test case for query cache ttl
Browse files Browse the repository at this point in the history
  • Loading branch information
paopa committed Sep 11, 2024
1 parent 1526349 commit 5028df3
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 0 deletions.
66 changes: 66 additions & 0 deletions wren-ai-service/tests/pytest/services/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import Optional

from src.pipelines.ask import (
generation,
historical_question,
retrieval,
sql_summary,
)


class RetrievalMock(retrieval.Retrieval):
def __init__(self, documents: list = []):
self._documents = documents

async def run(self, query: str, id: Optional[str] = None):
return {"construct_retrieval_results": self._documents}


class HistoricalQuestionMock(historical_question.HistoricalQuestion):
def __init__(self, documents: list = []):
self._documents = documents

async def run(self, query: str, id: Optional[str] = None):
return {"formatted_output": {"documents": self._documents}}


class GenerationMock(generation.Generation):
def __init__(self, valid: list = [], invalid: list = []):
self._valid = valid
self._invalid = invalid

async def run(
self,
query: str,
contexts: list[str],
exclude: list[dict],
project_id: str | None = None,
):
return {
"post_process": {
"valid_generation_results": self._valid,
"invalid_generation_results": self._invalid,
}
}


class SQLSummaryMock(sql_summary.SQLSummary):
"""
Example for the results:
[
{
"sql": "select 1",
"summary": "the description of the sql",
}
]
"""

def __init__(self, results: list = []):
self._results = results

async def run(
self,
query: str,
sqls: list[str],
):
return {"post_process": {"sql_summary_results": self._results}}
73 changes: 73 additions & 0 deletions wren-ai-service/tests/pytest/services/test_ask.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import re
import time
import uuid

import orjson
Expand All @@ -16,12 +18,19 @@
from src.web.v1.services.ask import (
AskRequest,
AskResultRequest,
AskResultResponse,
AskService,
)
from src.web.v1.services.indexing import (
IndexingService,
SemanticsPreparationRequest,
)
from tests.pytest.services.mocks import (
GenerationMock,
HistoricalQuestionMock,
RetrievalMock,
SQLSummaryMock,
)


@pytest.fixture
Expand Down Expand Up @@ -120,3 +129,67 @@ async def test_ask_with_successful_query(
# assert ask_result_response.response[0].sql != ""
# assert ask_result_response.response[0].summary != ""
# assert ask_result_response.response[0].type == "llm" or "view"


def _ask_service_ttl_mock(query: str):
return AskService(
{
"retrieval": RetrievalMock(
[
f"mock document 1 for {query}",
f"mock document 2 for {query}",
]
),
"historical_question": HistoricalQuestionMock(),
"generation": GenerationMock(
valid=["select count(*) from books"],
),
"sql_summary": SQLSummaryMock(
results=[
{
"sql": "select count(*) from books",
"summary": "mock summary",
}
]
),
},
ttl=3,
)


@pytest.mark.asyncio
async def test_ask_query_ttl():
query = "How many books are there?"
query_id = str(uuid.uuid4())

ask_service = _ask_service_ttl_mock(query)
ask_service._ask_results[query_id] = AskResultResponse(
status="understanding",
)

request = AskRequest(
query=query,
mdl_hash="mock mdl hash",
)
request.query_id = query_id

await ask_service.ask(request)

time.sleep(1)
response = ask_service.get_ask_result(
AskResultRequest(
query_id=query_id,
)
)
assert response.status == "finished"

time.sleep(3)
response = ask_service.get_ask_result(
AskResultRequest(
query_id=query_id,
)
)

assert response.status == "failed"
assert response.error.code == "OTHERS"
assert re.match(r".+ is not found", response.error.message)

0 comments on commit 5028df3

Please sign in to comment.