Skip to content

Commit

Permalink
chore(wren-ai-service): fix generation eval (#656)
Browse files Browse the repository at this point in the history
* fix eval data generation

* update prompt
  • Loading branch information
cyyeh authored Sep 9, 2024
1 parent 9d4e819 commit 6775e02
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 3 deletions.
176 changes: 173 additions & 3 deletions wren-ai-service/eval/data_curation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import List, Tuple
from typing import Any, Dict, List, Tuple

import aiohttp
import orjson
Expand Down Expand Up @@ -100,6 +100,176 @@ async def get_validated_question_sql_pairs(
]


def get_ddl_commands(mdl: Dict[str, Any]) -> List[str]:
def _convert_models_and_relationships(
models: List[Dict[str, Any]], relationships: List[Dict[str, Any]]
) -> List[str]:
ddl_commands = []

# A map to store model primary keys for foreign key relationships
primary_keys_map = {model["name"]: model["primaryKey"] for model in models}

for model in models:
table_name = model["name"]
columns_ddl = []
for column in model["columns"]:
if "relationship" not in column:
if "properties" in column:
column["properties"]["alias"] = column["properties"].pop(
"displayName", ""
)
comment = f"-- {orjson.dumps(column['properties']).decode("utf-8")}\n "
else:
comment = ""
if "isCalculated" in column and column["isCalculated"]:
comment = (
comment
+ f"-- This column is a Calculated Field\n -- column expression: {column["expression"]}\n "
)
column_name = column["name"]
column_type = column["type"]
column_ddl = f"{comment}{column_name} {column_type}"

# If column is a primary key
if column_name == model.get("primaryKey", ""):
column_ddl += " PRIMARY KEY"

columns_ddl.append(column_ddl)

# Add foreign key constraints based on relationships
for relationship in relationships:
comment = f'-- {{"condition": {relationship["condition"]}, "joinType": {relationship["joinType"]}}}\n '
if (
table_name == relationship["models"][0]
and relationship["joinType"].upper() == "MANY_TO_ONE"
):
related_table = relationship["models"][1]
fk_column = relationship["condition"].split(" = ")[0].split(".")[1]
fk_constraint = f"FOREIGN KEY ({fk_column}) REFERENCES {related_table}({primary_keys_map[related_table]})"
columns_ddl.append(f"{comment}{fk_constraint}")
elif (
table_name == relationship["models"][1]
and relationship["joinType"].upper() == "ONE_TO_MANY"
):
related_table = relationship["models"][0]
fk_column = relationship["condition"].split(" = ")[1].split(".")[1]
fk_constraint = f"FOREIGN KEY ({fk_column}) REFERENCES {related_table}({primary_keys_map[related_table]})"
columns_ddl.append(f"{comment}{fk_constraint}")
elif (
table_name in relationship["models"]
and relationship["joinType"].upper() == "ONE_TO_ONE"
):
index = relationship["models"].index(table_name)
related_table = [
m for m in relationship["models"] if m != table_name
][0]
fk_column = (
relationship["condition"].split(" = ")[index].split(".")[1]
)
fk_constraint = f"FOREIGN KEY ({fk_column}) REFERENCES {related_table}({primary_keys_map[related_table]})"
columns_ddl.append(f"{comment}{fk_constraint}")

if "properties" in model:
model["properties"]["alias"] = model["properties"].pop(
"displayName", ""
)
comment = (
f"\n/* {orjson.dumps(model['properties']).decode("utf-8")} */\n"
)
else:
comment = ""

create_table_ddl = (
f"{comment}CREATE TABLE {table_name} (\n "
+ ",\n ".join(columns_ddl)
+ "\n);"
)
ddl_commands.append(create_table_ddl)

return ddl_commands

def _convert_views(views: List[Dict[str, Any]]) -> List[str]:
def _format(view: Dict[str, Any]) -> str:
properties = view["properties"] if "properties" in view else ""
return f"/* {properties} */\nCREATE VIEW {view['name']}\nAS ({view['statement']})"

return [_format(view) for view in views]

def _convert_metrics(metrics: List[Dict[str, Any]]) -> List[str]:
ddl_commands = []

for metric in metrics:
table_name = metric["name"]
columns_ddl = []
for dimension in metric["dimension"]:
column_name = dimension["name"]
column_type = dimension["type"]
comment = "-- This column is a dimension\n "
column_ddl = f"{comment}{column_name} {column_type}"
columns_ddl.append(column_ddl)

for measure in metric["measure"]:
column_name = measure["name"]
column_type = measure["type"]
comment = f"-- This column is a measure\n -- expression: {measure["expression"]}\n "
column_ddl = f"{comment}{column_name} {column_type}"
columns_ddl.append(column_ddl)

comment = f"\n/* This table is a metric */\n/* Metric Base Object: {metric["baseObject"]} */\n"
create_table_ddl = (
f"{comment}CREATE TABLE {table_name} (\n "
+ ",\n ".join(columns_ddl)
+ "\n);"
)

ddl_commands.append(create_table_ddl)

return ddl_commands

semantics = {
"models": [],
"relationships": mdl["relationships"],
"views": mdl["views"],
"metrics": mdl["metrics"],
}

for model in mdl["models"]:
columns = []
for column in model["columns"]:
ddl_column = {
"name": column["name"],
"type": column["type"],
}
if "properties" in column:
ddl_column["properties"] = column["properties"]
if "relationship" in column:
ddl_column["relationship"] = column["relationship"]
if "expression" in column:
ddl_column["expression"] = column["expression"]
if "isCalculated" in column:
ddl_column["isCalculated"] = column["isCalculated"]

columns.append(ddl_column)

semantics["models"].append(
{
"type": "model",
"name": model["name"],
"properties": model["properties"] if "properties" in model else {},
"columns": columns,
"primaryKey": model["primaryKey"],
}
)

return (
_convert_models_and_relationships(
semantics["models"], semantics["relationships"]
)
+ _convert_metrics(semantics["metrics"])
+ _convert_views(semantics["views"])
)


async def get_contexts_from_sqls(
sqls: list[str],
mdl_json: dict,
Expand Down Expand Up @@ -230,7 +400,7 @@ def _build_partial_mdl_json(
"content": ddl_command,
}
for new_mdl_json in new_mdl_jsons
for i, ddl_command in enumerate(ddl_converter._get_ddl_commands(new_mdl_json))
for i, ddl_command in enumerate(get_ddl_commands(new_mdl_json))
]


Expand Down Expand Up @@ -274,7 +444,7 @@ async def get_question_sql_pairs(
{custom_instructions}
### Input ###
Data Model: {"\n\n".join(ddl_converter._get_ddl_commands(mdl_json))}
Data Model: {"\n\n".join(get_ddl_commands(mdl_json))}
Generate {num_pairs} of the questions and corresponding SQL queries according to the Output Format in JSON
Think step by step
Expand Down
1 change: 1 addition & 0 deletions wren-ai-service/src/pipelines/ask/components/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
### ALERT ###
- ONLY USE SELECT statements, NO DELETE, UPDATE OR INSERT etc. statements that might change the data in the database.
- ONLY USE the tables and columns mentioned in the database schema.
- ONLY USE "*" if the user query asks for all the columns of a table.
- ONLY CHOOSE columns belong to the tables mentioned in the database schema.
- YOU MUST USE "JOIN" if you choose columns from multiple tables!
- YOU MUST USE "lower(<column_name>) = lower(<value>)" function for case-insensitive comparison!
Expand Down

0 comments on commit 6775e02

Please sign in to comment.