From 96afd7ce68ddd9aba1c7de1ccb592daf2e65c931 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Wed, 31 Jul 2024 10:12:47 +0800 Subject: [PATCH] update feedback loop --- .../pipelines/generation/sql_explanation.py | 77 +++++++++++++------ .../src/web/v1/services/sql_explanation.py | 18 ++++- 2 files changed, 69 insertions(+), 26 deletions(-) diff --git a/wren-ai-service/src/pipelines/generation/sql_explanation.py b/wren-ai-service/src/pipelines/generation/sql_explanation.py index 9e4e6162f..bcf7220cb 100644 --- a/wren-ai-service/src/pipelines/generation/sql_explanation.py +++ b/wren-ai-service/src/pipelines/generation/sql_explanation.py @@ -218,36 +218,57 @@ def _collect_relations(relation, result, cte_names, top_level: bool = True): return results -def _compose_sql_expression_of_select_type(select_items: List[Dict]) -> Dict: +def _compose_sql_expression_of_select_type( + select_items: List[Dict], selected_data_sources: List[List[Dict]] +) -> Dict: + def _is_select_item_existed_in_selected_data_sources( + select_item, selected_data_sources + ): + for selected_data_source in selected_data_sources: + for data_source in selected_data_source: + if ( + "exprSources" in select_item + and select_item["exprSources"] + and select_item["exprSources"][0]["sourceDataset"] + == data_source["sourceDataset"] + and select_item["exprSources"][0]["sourceColumn"] + == data_source["sourceColumn"] + ): + return True + return False + result = { "withFunctionCallOrMathematicalOperation": [], "withoutFunctionCallOrMathematicalOperation": [], } for select_item in select_items: - if ( - select_item["properties"]["includeFunctionCall"] == "true" - or select_item["properties"]["includeMathematicalOperation"] == "true" + if not _is_select_item_existed_in_selected_data_sources( + select_item, selected_data_sources ): - result["withFunctionCallOrMathematicalOperation"].append( - { - "values": { - "alias": select_item["alias"], - "expression": select_item["expression"], - }, - "id": select_item.get("id", ""), - } - ) - else: - result["withoutFunctionCallOrMathematicalOperation"].append( - { - "values": { - "alias": select_item["alias"], - "expression": select_item["expression"], - }, - "id": select_item.get("id", ""), - } - ) + if ( + select_item["properties"]["includeFunctionCall"] == "true" + or select_item["properties"]["includeMathematicalOperation"] == "true" + ): + result["withFunctionCallOrMathematicalOperation"].append( + { + "values": { + "alias": select_item["alias"], + "expression": select_item["expression"], + }, + "id": select_item.get("id", ""), + } + ) + else: + result["withoutFunctionCallOrMathematicalOperation"].append( + { + "values": { + "alias": select_item["alias"], + "expression": select_item["expression"], + }, + "id": select_item.get("id", ""), + } + ) return result @@ -293,6 +314,7 @@ class SQLAnalysisPreprocessor: def run( self, cte_names: List[str], + selected_data_sources: List[List[Dict]], sql_analysis_results: List[Dict], ) -> Dict[str, List[Dict]]: preprocessed_sql_analysis_results = [] @@ -328,7 +350,7 @@ def run( preprocessed_sql_analysis_result[ "selectItems" ] = _compose_sql_expression_of_select_type( - sql_analysis_result["selectItems"] + sql_analysis_result["selectItems"], selected_data_sources ) else: preprocessed_sql_analysis_result["selectItems"] = { @@ -502,6 +524,7 @@ def run( @timer @observe(capture_input=False) def preprocess( + selected_data_sources: List[List[dict]], sql_analysis_results: List[dict], cte_names: List[str], pre_processor: SQLAnalysisPreprocessor, @@ -509,7 +532,7 @@ def preprocess( logger.debug( f"sql_analysis_results: {orjson.dumps(sql_analysis_results, option=orjson.OPT_INDENT_2).decode()}" ) - return pre_processor.run(cte_names, sql_analysis_results) + return pre_processor.run(cte_names, selected_data_sources, sql_analysis_results) @timer @@ -647,6 +670,7 @@ def visualize( self, question: str, cte_names: List[str], + selected_data_sources: List[List[dict]], step_with_analysis_results: StepWithAnalysisResults, ) -> None: destination = "outputs/pipelines/generation" @@ -659,6 +683,7 @@ def visualize( inputs={ "question": question, "cte_names": cte_names, + "selected_data_sources": selected_data_sources, "sql": step_with_analysis_results.sql, "sql_analysis_results": step_with_analysis_results.sql_analysis_results, "sql_summary": step_with_analysis_results.summary, @@ -674,6 +699,7 @@ async def run( self, question: str, cte_names: List[str], + selected_data_sources: List[List[dict]], step_with_analysis_results: StepWithAnalysisResults, ): logger.info("SQL Explanation Generation pipeline is running...") @@ -683,6 +709,7 @@ async def run( inputs={ "question": question, "cte_names": cte_names, + "selected_data_sources": selected_data_sources, "sql": step_with_analysis_results.sql, "sql_analysis_results": step_with_analysis_results.sql_analysis_results, "sql_summary": step_with_analysis_results.summary, diff --git a/wren-ai-service/src/web/v1/services/sql_explanation.py b/wren-ai-service/src/web/v1/services/sql_explanation.py index 358a5ef82..66900d76a 100644 --- a/wren-ai-service/src/web/v1/services/sql_explanation.py +++ b/wren-ai-service/src/web/v1/services/sql_explanation.py @@ -88,11 +88,14 @@ async def sql_explanation( async def _task( question: str, cte_names: List[str], + selected_data_sources: List[List[dict]], step_with_analysis_results: StepWithAnalysisResults, + i: int, ): return await self._pipelines["sql_explanation"].run( question=question, cte_names=cte_names, + selected_data_sources=selected_data_sources[:i], step_with_analysis_results=step_with_analysis_results, ) @@ -100,13 +103,26 @@ async def _task( step_with_analysis_results.cte_name for step_with_analysis_results in sql_explanation_request.steps_with_analysis_results ] + selected_data_sources = [ + [ + select_item["exprSources"][0] + for analysis_result in step_with_analysis_results.sql_analysis_results + for select_item in analysis_result.get("selectItems", []) + if select_item.get("exprSources", []) + ] + for step_with_analysis_results in sql_explanation_request.steps_with_analysis_results + ] tasks = [ _task( sql_explanation_request.question, cte_names, + selected_data_sources, step_with_analysis_results, + i, + ) + for i, step_with_analysis_results in enumerate( + sql_explanation_request.steps_with_analysis_results ) - for step_with_analysis_results in sql_explanation_request.steps_with_analysis_results ] generation_results = await asyncio.gather(*tasks)