Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Sep 18, 2024
1 parent e94eb57 commit 82cdce7
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 47 deletions.
12 changes: 6 additions & 6 deletions wren-ai-service/src/pipelines/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


@component
class SQLBreakdownGenerationPostProcessor:
class SQLBreakdownGenPostProcessor:
def __init__(self, engine: Engine):
self._engine = engine

Expand Down Expand Up @@ -54,8 +54,8 @@ async def run(
}

sql = self._build_cte_query(steps)
logger.debug(f"SQLBreakdownGenerationPostProcessor: steps: {pformat(steps)}")
logger.debug(f"SQLBreakdownGenerationPostProcessor: final sql: {sql}")
logger.debug(f": steps: {pformat(steps)}")
logger.debug(f"SQLBreakdownGenPostProcessor: final sql: {sql}")

if not await self._check_if_sql_executable(sql, project_id=project_id):
return {
Expand Down Expand Up @@ -100,7 +100,7 @@ async def _check_if_sql_executable(


@component
class SQLGenerationPostProcessor:
class SQLGenPostProcessor:
def __init__(self, engine: Engine):
self._engine = engine

Expand Down Expand Up @@ -133,7 +133,7 @@ async def run(
"invalid_generation_results": invalid_generation_results,
}
except Exception as e:
logger.exception(f"Error in SQLGenerationPostProcessor: {e}")
logger.exception(f"Error in SQLGenPostProcessor: {e}")

return {
"valid_generation_results": [],
Expand Down Expand Up @@ -212,7 +212,7 @@ async def _task(result: Dict[str, str]):
"""


text_to_sql_system_prompt = """
sql_generation_system_prompt = """
You are a Trino SQL expert with exceptional logical thinking skills. Your main task is to generate SQL from given DB schema and user-input natrual language queries.
Before the main task, you need to learn about some specific structures in the given DB schema.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from src.core.provider import LLMProvider
from src.pipelines.common import (
TEXT_TO_SQL_RULES,
SQLGenerationPostProcessor,
text_to_sql_system_prompt,
SQLGenPostProcessor,
sql_generation_system_prompt,
)
from src.utils import async_timer, timer
from src.web.v1.services.ask import AskHistory
Expand Down Expand Up @@ -149,7 +149,7 @@ async def generate_sql_in_followup(prompt: dict, generator: Any) -> dict:
@observe(capture_input=False)
async def post_process(
generate_sql_in_followup: dict,
post_processor: SQLGenerationPostProcessor,
post_processor: SQLGenPostProcessor,
project_id: str | None = None,
) -> dict:
logger.debug(
Expand All @@ -171,12 +171,12 @@ def __init__(
):
self._components = {
"generator": llm_provider.get_generator(
system_prompt=text_to_sql_system_prompt
system_prompt=sql_generation_system_prompt
),
"prompt_builder": PromptBuilder(
template=text_to_sql_with_followup_user_prompt_template
),
"post_processor": SQLGenerationPostProcessor(engine=engine),
"post_processor": SQLGenPostProcessor(engine=engine),
}

self._configs = {
Expand Down
6 changes: 3 additions & 3 deletions wren-ai-service/src/pipelines/generation/sql_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from src.core.engine import Engine
from src.core.pipeline import BasicPipeline, async_validate
from src.core.provider import LLMProvider
from src.pipelines.common import SQLBreakdownGenerationPostProcessor
from src.pipelines.common import SQLBreakdownGenPostProcessor
from src.utils import (
async_timer,
timer,
Expand Down Expand Up @@ -130,7 +130,7 @@ async def generate_sql_details(prompt: dict, generator: Any) -> dict:
@observe(capture_input=False)
async def post_process(
generate_sql_details: dict,
post_processor: SQLBreakdownGenerationPostProcessor,
post_processor: SQLBreakdownGenPostProcessor,
project_id: str | None = None,
) -> dict:
logger.debug(
Expand All @@ -157,7 +157,7 @@ def __init__(
"prompt_builder": PromptBuilder(
template=sql_breakdown_user_prompt_template
),
"post_processor": SQLBreakdownGenerationPostProcessor(engine=engine),
"post_processor": SQLBreakdownGenPostProcessor(engine=engine),
}

super().__init__(
Expand Down
10 changes: 5 additions & 5 deletions wren-ai-service/src/pipelines/generation/sql_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from src.core.provider import LLMProvider
from src.pipelines.common import (
TEXT_TO_SQL_RULES,
SQLGenerationPostProcessor,
text_to_sql_system_prompt,
SQLGenPostProcessor,
sql_generation_system_prompt,
)
from src.utils import async_timer, timer

Expand Down Expand Up @@ -91,7 +91,7 @@ async def generate_sql_correction(prompt: dict, generator: Any) -> dict:
@observe(capture_input=False)
async def post_process(
generate_sql_correction: dict,
post_processor: SQLGenerationPostProcessor,
post_processor: SQLGenPostProcessor,
project_id: str | None = None,
) -> dict:
logger.debug(
Expand All @@ -113,12 +113,12 @@ def __init__(
):
self._components = {
"generator": llm_provider.get_generator(
system_prompt=text_to_sql_system_prompt
system_prompt=sql_generation_system_prompt
),
"prompt_builder": PromptBuilder(
template=sql_correction_user_prompt_template
),
"post_processor": SQLGenerationPostProcessor(engine=engine),
"post_processor": SQLGenPostProcessor(engine=engine),
}

self._configs = {
Expand Down
6 changes: 3 additions & 3 deletions wren-ai-service/src/pipelines/generation/sql_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from src.core.engine import Engine
from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
from src.pipelines.common import SQLGenerationPostProcessor
from src.pipelines.common import SQLGenPostProcessor
from src.utils import async_timer, timer
from src.web.v1.services.ask import AskHistory

Expand Down Expand Up @@ -75,7 +75,7 @@ async def generate_sql_expansion(prompt: dict, generator: Any) -> dict:
@observe(capture_input=False)
async def post_process(
generate_sql_expansion: dict,
post_processor: SQLGenerationPostProcessor,
post_processor: SQLGenPostProcessor,
project_id: str | None = None,
) -> dict:
logger.debug(
Expand All @@ -102,7 +102,7 @@ def __init__(
"prompt_builder": PromptBuilder(
template=sql_expansion_user_prompt_template
),
"post_processor": SQLGenerationPostProcessor(engine=engine),
"post_processor": SQLGenPostProcessor(engine=engine),
}

super().__init__(
Expand Down
16 changes: 9 additions & 7 deletions wren-ai-service/src/pipelines/generation/sql_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
from src.core.provider import LLMProvider
from src.pipelines.common import (
TEXT_TO_SQL_RULES,
SQLGenerationPostProcessor,
text_to_sql_system_prompt,
SQLGenPostProcessor,
sql_generation_system_prompt,
)
from src.utils import async_timer, timer

logger = logging.getLogger("wren-ai-service")


text_to_sql_user_prompt_template = """
sql_generation_user_prompt_template = """
### TASK ###
Given a user query that is ambiguous in nature, your task is to interpret the query in various plausible ways and
generate three SQL statements that could potentially answer each interpreted version of the queries.
Expand Down Expand Up @@ -119,7 +119,7 @@ async def generate_sql(prompt: dict, generator: Any) -> dict:
@observe(capture_input=False)
async def post_process(
generate_sql: dict,
post_processor: SQLGenerationPostProcessor,
post_processor: SQLGenPostProcessor,
project_id: str | None = None,
) -> dict:
logger.debug(
Expand All @@ -139,10 +139,12 @@ def __init__(
):
self._components = {
"generator": llm_provider.get_generator(
system_prompt=text_to_sql_system_prompt
system_prompt=sql_generation_system_prompt
),
"prompt_builder": PromptBuilder(template=text_to_sql_user_prompt_template),
"post_processor": SQLGenerationPostProcessor(engine=engine),
"prompt_builder": PromptBuilder(
template=sql_generation_user_prompt_template
),
"post_processor": SQLGenPostProcessor(engine=engine),
}

self._configs = {
Expand Down
32 changes: 14 additions & 18 deletions wren-ai-service/src/pipelines/generation/sql_regeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from src.core.engine import Engine
from src.core.pipeline import BasicPipeline
from src.core.provider import LLMProvider
from src.pipelines.common import SQLBreakdownGenerationPostProcessor
from src.pipelines.common import SQLBreakdownGenPostProcessor
from src.utils import async_timer, timer
from src.web.v1.services.sql_regeneration import (
SQLExplanationWithUserCorrections,
Expand Down Expand Up @@ -79,7 +79,7 @@


@component
class SQLRegenerationRreprocesser:
class SQLRegenerationPreprocesser:
@component.output_types(
results=Dict[str, Any],
)
Expand All @@ -102,11 +102,11 @@ def run(
def preprocess(
description: str,
steps: List[SQLExplanationWithUserCorrections],
sql_regeneration_preprocesser: SQLRegenerationRreprocesser,
preprocesser: SQLRegenerationPreprocesser,
) -> dict[str, Any]:
logger.debug(f"steps: {steps}")
logger.debug(f"description: {description}")
return sql_regeneration_preprocesser.run(
return preprocesser.run(
description=description,
steps=steps,
)
Expand All @@ -116,37 +116,35 @@ def preprocess(
@observe(capture_input=False)
def sql_regeneration_prompt(
preprocess: Dict[str, Any],
sql_regeneration_prompt_builder: PromptBuilder,
prompt_builder: PromptBuilder,
) -> dict:
logger.debug(f"preprocess: {preprocess}")
return sql_regeneration_prompt_builder.run(results=preprocess["results"])
return prompt_builder.run(results=preprocess["results"])


@async_timer
@observe(as_type="generation", capture_input=False)
async def generate_sql_regeneration(
sql_regeneration_prompt: dict,
sql_regeneration_generator: Any,
generator: Any,
) -> dict:
logger.debug(
f"sql_regeneration_prompt: {orjson.dumps(sql_regeneration_prompt, option=orjson.OPT_INDENT_2).decode()}"
)
return await sql_regeneration_generator.run(
prompt=sql_regeneration_prompt.get("prompt")
)
return await generator.run(prompt=sql_regeneration_prompt.get("prompt"))


@async_timer
@observe(capture_input=False)
async def sql_regeneration_post_process(
generate_sql_regeneration: dict,
sql_regeneration_post_processor: SQLBreakdownGenerationPostProcessor,
post_processor: SQLBreakdownGenPostProcessor,
project_id: str | None = None,
) -> dict:
logger.debug(
f"generate_sql_regeneration: {orjson.dumps(generate_sql_regeneration, option=orjson.OPT_INDENT_2).decode()}"
)
return await sql_regeneration_post_processor.run(
return await post_processor.run(
replies=generate_sql_regeneration.get("replies"),
project_id=project_id,
)
Expand All @@ -162,16 +160,14 @@ def __init__(
engine: Engine,
):
self._components = {
"sql_regeneration_preprocesser": SQLRegenerationRreprocesser(),
"sql_regeneration_prompt_builder": PromptBuilder(
"preprocesser": SQLRegenerationPreprocesser(),
"prompt_builder": PromptBuilder(
template=sql_regeneration_user_prompt_template
),
"sql_regeneration_generator": llm_provider.get_generator(
"generator": llm_provider.get_generator(
system_prompt=sql_regeneration_system_prompt
),
"sql_regeneration_post_processor": SQLBreakdownGenerationPostProcessor(
engine=engine
),
"post_processor": SQLBreakdownGenPostProcessor(engine=engine),
}

super().__init__(
Expand Down

0 comments on commit 82cdce7

Please sign in to comment.