Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(wren-ai-service): fix generation eval #656

Merged
merged 2 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 173 additions & 3 deletions wren-ai-service/eval/data_curation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import List, Tuple
from typing import Any, Dict, List, Tuple

import aiohttp
import orjson
Expand Down Expand Up @@ -100,6 +100,176 @@ async def get_validated_question_sql_pairs(
]


def get_ddl_commands(mdl: Dict[str, Any]) -> List[str]:
def _convert_models_and_relationships(
models: List[Dict[str, Any]], relationships: List[Dict[str, Any]]
) -> List[str]:
ddl_commands = []

# A map to store model primary keys for foreign key relationships
primary_keys_map = {model["name"]: model["primaryKey"] for model in models}

for model in models:
table_name = model["name"]
columns_ddl = []
for column in model["columns"]:
if "relationship" not in column:
if "properties" in column:
column["properties"]["alias"] = column["properties"].pop(
"displayName", ""
)
comment = f"-- {orjson.dumps(column['properties']).decode("utf-8")}\n "
else:
comment = ""
if "isCalculated" in column and column["isCalculated"]:
comment = (
comment
+ f"-- This column is a Calculated Field\n -- column expression: {column["expression"]}\n "
)
column_name = column["name"]
column_type = column["type"]
column_ddl = f"{comment}{column_name} {column_type}"

# If column is a primary key
if column_name == model.get("primaryKey", ""):
column_ddl += " PRIMARY KEY"

columns_ddl.append(column_ddl)

# Add foreign key constraints based on relationships
for relationship in relationships:
comment = f'-- {{"condition": {relationship["condition"]}, "joinType": {relationship["joinType"]}}}\n '
if (
table_name == relationship["models"][0]
and relationship["joinType"].upper() == "MANY_TO_ONE"
):
related_table = relationship["models"][1]
fk_column = relationship["condition"].split(" = ")[0].split(".")[1]
fk_constraint = f"FOREIGN KEY ({fk_column}) REFERENCES {related_table}({primary_keys_map[related_table]})"
columns_ddl.append(f"{comment}{fk_constraint}")
elif (
table_name == relationship["models"][1]
and relationship["joinType"].upper() == "ONE_TO_MANY"
):
related_table = relationship["models"][0]
fk_column = relationship["condition"].split(" = ")[1].split(".")[1]
fk_constraint = f"FOREIGN KEY ({fk_column}) REFERENCES {related_table}({primary_keys_map[related_table]})"
columns_ddl.append(f"{comment}{fk_constraint}")
elif (
table_name in relationship["models"]
and relationship["joinType"].upper() == "ONE_TO_ONE"
):
index = relationship["models"].index(table_name)
related_table = [
m for m in relationship["models"] if m != table_name
][0]
fk_column = (
relationship["condition"].split(" = ")[index].split(".")[1]
)
fk_constraint = f"FOREIGN KEY ({fk_column}) REFERENCES {related_table}({primary_keys_map[related_table]})"
columns_ddl.append(f"{comment}{fk_constraint}")

if "properties" in model:
model["properties"]["alias"] = model["properties"].pop(
"displayName", ""
)
comment = (
f"\n/* {orjson.dumps(model['properties']).decode("utf-8")} */\n"
)
else:
comment = ""

create_table_ddl = (
f"{comment}CREATE TABLE {table_name} (\n "
+ ",\n ".join(columns_ddl)
+ "\n);"
)
ddl_commands.append(create_table_ddl)

return ddl_commands

def _convert_views(views: List[Dict[str, Any]]) -> List[str]:
def _format(view: Dict[str, Any]) -> str:
properties = view["properties"] if "properties" in view else ""
return f"/* {properties} */\nCREATE VIEW {view['name']}\nAS ({view['statement']})"

return [_format(view) for view in views]

def _convert_metrics(metrics: List[Dict[str, Any]]) -> List[str]:
ddl_commands = []

for metric in metrics:
table_name = metric["name"]
columns_ddl = []
for dimension in metric["dimension"]:
column_name = dimension["name"]
column_type = dimension["type"]
comment = "-- This column is a dimension\n "
column_ddl = f"{comment}{column_name} {column_type}"
columns_ddl.append(column_ddl)

for measure in metric["measure"]:
column_name = measure["name"]
column_type = measure["type"]
comment = f"-- This column is a measure\n -- expression: {measure["expression"]}\n "
column_ddl = f"{comment}{column_name} {column_type}"
columns_ddl.append(column_ddl)

comment = f"\n/* This table is a metric */\n/* Metric Base Object: {metric["baseObject"]} */\n"
create_table_ddl = (
f"{comment}CREATE TABLE {table_name} (\n "
+ ",\n ".join(columns_ddl)
+ "\n);"
)

ddl_commands.append(create_table_ddl)

return ddl_commands

semantics = {
"models": [],
"relationships": mdl["relationships"],
"views": mdl["views"],
"metrics": mdl["metrics"],
}

for model in mdl["models"]:
columns = []
for column in model["columns"]:
ddl_column = {
"name": column["name"],
"type": column["type"],
}
if "properties" in column:
ddl_column["properties"] = column["properties"]
if "relationship" in column:
ddl_column["relationship"] = column["relationship"]
if "expression" in column:
ddl_column["expression"] = column["expression"]
if "isCalculated" in column:
ddl_column["isCalculated"] = column["isCalculated"]

columns.append(ddl_column)

semantics["models"].append(
{
"type": "model",
"name": model["name"],
"properties": model["properties"] if "properties" in model else {},
"columns": columns,
"primaryKey": model["primaryKey"],
}
)

return (
_convert_models_and_relationships(
semantics["models"], semantics["relationships"]
)
+ _convert_metrics(semantics["metrics"])
+ _convert_views(semantics["views"])
)


async def get_contexts_from_sqls(
sqls: list[str],
mdl_json: dict,
Expand Down Expand Up @@ -230,7 +400,7 @@ def _build_partial_mdl_json(
"content": ddl_command,
}
for new_mdl_json in new_mdl_jsons
for i, ddl_command in enumerate(ddl_converter._get_ddl_commands(new_mdl_json))
for i, ddl_command in enumerate(get_ddl_commands(new_mdl_json))
]


Expand Down Expand Up @@ -274,7 +444,7 @@ async def get_question_sql_pairs(
{custom_instructions}
### Input ###
Data Model: {"\n\n".join(ddl_converter._get_ddl_commands(mdl_json))}
Data Model: {"\n\n".join(get_ddl_commands(mdl_json))}
Generate {num_pairs} of the questions and corresponding SQL queries according to the Output Format in JSON
Think step by step
Expand Down
1 change: 1 addition & 0 deletions wren-ai-service/src/pipelines/ask/components/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
### ALERT ###
- ONLY USE SELECT statements, NO DELETE, UPDATE OR INSERT etc. statements that might change the data in the database.
- ONLY USE the tables and columns mentioned in the database schema.
- ONLY USE "*" if the user query asks for all the columns of a table.
- ONLY CHOOSE columns belong to the tables mentioned in the database schema.
- YOU MUST USE "JOIN" if you choose columns from multiple tables!
- YOU MUST USE "lower(<column_name>) = lower(<value>)" function for case-insensitive comparison!
Expand Down
Loading