Skip to content

Commit

Permalink
update feedback loop
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh committed Sep 10, 2024
1 parent 2b1a0ef commit d6b4881
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 26 deletions.
77 changes: 52 additions & 25 deletions wren-ai-service/src/pipelines/sql_explanation/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,36 +135,57 @@ def _collect_relations(relation, result, cte_names, top_level: bool = True):
return results


def _compose_sql_expression_of_select_type(select_items: List[Dict]) -> Dict:
def _compose_sql_expression_of_select_type(
select_items: List[Dict], selected_data_sources: List[List[Dict]]
) -> Dict:
def _is_select_item_existed_in_selected_data_sources(
select_item, selected_data_sources
):
for selected_data_source in selected_data_sources:
for data_source in selected_data_source:
if (
"exprSources" in select_item
and select_item["exprSources"]
and select_item["exprSources"][0]["sourceDataset"]
== data_source["sourceDataset"]
and select_item["exprSources"][0]["sourceColumn"]
== data_source["sourceColumn"]
):
return True
return False

result = {
"withFunctionCallOrMathematicalOperation": [],
"withoutFunctionCallOrMathematicalOperation": [],
}

for select_item in select_items:
if (
select_item["properties"]["includeFunctionCall"] == "true"
or select_item["properties"]["includeMathematicalOperation"] == "true"
if not _is_select_item_existed_in_selected_data_sources(
select_item, selected_data_sources
):
result["withFunctionCallOrMathematicalOperation"].append(
{
"values": {
"alias": select_item["alias"],
"expression": select_item["expression"],
},
"id": select_item.get("id", ""),
}
)
else:
result["withoutFunctionCallOrMathematicalOperation"].append(
{
"values": {
"alias": select_item["alias"],
"expression": select_item["expression"],
},
"id": select_item.get("id", ""),
}
)
if (
select_item["properties"]["includeFunctionCall"] == "true"
or select_item["properties"]["includeMathematicalOperation"] == "true"
):
result["withFunctionCallOrMathematicalOperation"].append(
{
"values": {
"alias": select_item["alias"],
"expression": select_item["expression"],
},
"id": select_item.get("id", ""),
}
)
else:
result["withoutFunctionCallOrMathematicalOperation"].append(
{
"values": {
"alias": select_item["alias"],
"expression": select_item["expression"],
},
"id": select_item.get("id", ""),
}
)

return result

Expand Down Expand Up @@ -210,6 +231,7 @@ class SQLAnalysisPreprocessor:
def run(
self,
cte_names: List[str],
selected_data_sources: List[List[Dict]],
sql_analysis_results: List[Dict],
) -> Dict[str, List[Dict]]:
preprocessed_sql_analysis_results = []
Expand Down Expand Up @@ -245,7 +267,7 @@ def run(
preprocessed_sql_analysis_result[
"selectItems"
] = _compose_sql_expression_of_select_type(
sql_analysis_result["selectItems"]
sql_analysis_result["selectItems"], selected_data_sources
)
else:
preprocessed_sql_analysis_result["selectItems"] = {
Expand Down Expand Up @@ -419,14 +441,15 @@ def run(
@timer
@observe(capture_input=False)
def preprocess(
selected_data_sources: List[List[dict]],
sql_analysis_results: List[dict],
cte_names: List[str],
pre_processor: SQLAnalysisPreprocessor,
) -> dict:
logger.debug(
f"sql_analysis_results: {orjson.dumps(sql_analysis_results, option=orjson.OPT_INDENT_2).decode()}"
)
return pre_processor.run(cte_names, sql_analysis_results)
return pre_processor.run(cte_names, selected_data_sources, sql_analysis_results)


@timer
Expand Down Expand Up @@ -561,6 +584,7 @@ def visualize(
self,
question: str,
cte_names: List[str],
selected_data_sources: List[List[dict]],
step_with_analysis_results: StepWithAnalysisResults,
) -> None:
destination = "outputs/pipelines/sql_explanation"
Expand All @@ -573,6 +597,7 @@ def visualize(
inputs={
"question": question,
"cte_names": cte_names,
"selected_data_sources": selected_data_sources,
"sql": step_with_analysis_results.sql,
"sql_analysis_results": step_with_analysis_results.sql_analysis_results,
"sql_summary": step_with_analysis_results.summary,
Expand All @@ -591,6 +616,7 @@ async def run(
self,
question: str,
cte_names: List[str],
selected_data_sources: List[List[dict]],
step_with_analysis_results: StepWithAnalysisResults,
):
logger.info("SQL Explanation Generation pipeline is running...")
Expand All @@ -600,6 +626,7 @@ async def run(
inputs={
"question": question,
"cte_names": cte_names,
"selected_data_sources": selected_data_sources,
"sql": step_with_analysis_results.sql,
"sql_analysis_results": step_with_analysis_results.sql_analysis_results,
"sql_summary": step_with_analysis_results.summary,
Expand Down
18 changes: 17 additions & 1 deletion wren-ai-service/src/web/v1/services/sql_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,41 @@ async def sql_explanation(self, sql_explanation_request: SQLExplanationRequest):
async def _task(
question: str,
cte_names: List[str],
selected_data_sources: List[List[dict]],
step_with_analysis_results: StepWithAnalysisResults,
i: int,
):
return await self._pipelines["generation"].run(
question=question,
cte_names=cte_names,
selected_data_sources=selected_data_sources[:i],
step_with_analysis_results=step_with_analysis_results,
)

cte_names = [
step_with_analysis_results.cte_name
for step_with_analysis_results in sql_explanation_request.steps_with_analysis_results
]
selected_data_sources = [
[
select_item["exprSources"][0]
for analysis_result in step_with_analysis_results.sql_analysis_results
for select_item in analysis_result.get("selectItems", [])
if select_item.get("exprSources", [])
]
for step_with_analysis_results in sql_explanation_request.steps_with_analysis_results
]
tasks = [
_task(
sql_explanation_request.question,
cte_names,
selected_data_sources,
step_with_analysis_results,
i,
)
for i, step_with_analysis_results in enumerate(
sql_explanation_request.steps_with_analysis_results
)
for step_with_analysis_results in sql_explanation_request.steps_with_analysis_results
]
generation_results = await asyncio.gather(*tasks)

Expand Down

0 comments on commit d6b4881

Please sign in to comment.