diff --git a/datasette_extract/__init__.py b/datasette_extract/__init__.py index 476d7aa..559013c 100644 --- a/datasette_extract/__init__.py +++ b/datasette_extract/__init__.py @@ -1,16 +1,160 @@ -from datasette import hookimpl, Response import click +from datasette import hookimpl, Response, NotFound +import json +import openai +from sqlite_utils import Database +import sys -async def extract_web(datasette, request): +async def extract_create_table(datasette, request): + database = request.url_vars["database"] + try: + db = datasette.get_database(database) + except KeyError: + raise NotFound("Database '{}' does not exist".format(database)) + + if request.method == "POST": + post_vars = await request.post_vars() + content = (post_vars.get("content") or "").strip() + if not content: + return Response.text("No content provided", status=400) + table = post_vars.get("table") + if not table: + return Response.text("No table provided", status=400) + + properties = {} + # Build the properties out of name_0 upwards, only if populated + for key, value in post_vars.items(): + if key.startswith("name_") and value.strip(): + index = int(key.split("_")[1]) + type_ = post_vars.get("type_{}".format(index)) + hint = post_vars.get("hint_{}".format(index)) + properties[value] = { + "type": type_, + } + if hint: + properties[value]["description"] = hint + + return await extract_to_table_post( + datasette, request, content, database, table, properties + ) + return Response.html( - await datasette.render_template("extract.html", request=request) + await datasette.render_template( + "extract_create_table.html", + { + "database": database, + "fields": range(10), + }, + request=request, + ) ) +async def extract_to_table(datasette, request): + database = request.url_vars["database"] + table = request.url_vars["table"] + # Do they exist? + try: + db = datasette.get_database(database) + except KeyError: + raise NotFound("Database '{}' does not exist".format(database)) + tables = await db.table_names() + if table not in tables: + raise NotFound("Table '{}' does not exist".format(table)) + + schema = await db.execute_fn(lambda conn: Database(conn)[table].columns_dict) + + if request.method == "POST": + # Turn schema into a properties dict + properties = { + name: { + "type": get_type(type_), + # "description": "..." + } + for name, type_ in schema.items() + } + post_vars = await request.post_vars() + content = (post_vars.get("content") or "").strip() + return await extract_to_table_post( + datasette, request, content, database, table, properties + ) + + return Response.html( + await datasette.render_template( + "extract_to_table.html", + { + "database": database, + "table": table, + "schema": schema, + }, + request=request, + ) + ) + + +async def extract_to_table_post( + datasette, request, content, database, table, properties +): + # Here we go! + if not content: + return Response.text("No content provided") + + required_fields = list(properties.keys()) + try: + contents = [] + async for chunk in await openai.ChatCompletion.acreate( + stream=True, + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": content}], + functions=[ + { + "name": "extract_data", + "description": "Extract data matching this schema", + "parameters": { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": properties, + "required": required_fields, + }, + } + }, + "required": ["items"], + }, + }, + ], + function_call={"name": "extract_data"}, + ): + content = ( + chunk["choices"][0] + .get("delta", {}) + .get("function_call", {}) + .get("arguments") + ) + print(content, end="") + sys.stdout.flush() + if content is not None: + contents.append(content) + + except openai.OpenAIError as ex: + return Response.text(str(ex), status=400) + output = "".join(contents) + return Response.json(json.loads(output)) + + @click.command() -def extract(): - click.echo("Hello from extract") +@click.argument( + "database", + type=click.Path(file_okay=True, dir_okay=False, allow_dash=False), + required=True, +) +@click.argument("table", required=True) +def extract(database, table): + click.echo("Will extract to {} in {}".format(table, database)) @hookimpl @@ -27,5 +171,15 @@ def startup(): @hookimpl def register_routes(): return [ - (r"^/-/extract$", extract_web), + (r"^/-/extract/(?P[^/]+)$", extract_create_table), + (r"^/-/extract/(?P[^/]+)/(?P[^/]+)$", extract_to_table), ] + + +def get_type(type_): + if type_ is int: + return "integer" + elif type_ is float: + return "number" + else: + return "string" diff --git a/datasette_extract/templates/extract_create_table.html b/datasette_extract/templates/extract_create_table.html new file mode 100644 index 0000000..fa27760 --- /dev/null +++ b/datasette_extract/templates/extract_create_table.html @@ -0,0 +1,40 @@ +{% extends "base.html" %} + +{% block title %}Extract data and create a new table{% endblock %} + +{% block extra_head %} + +{% endblock %} + +{% block content %} +

Extract data and create a new table in {{ database }}

+ + +

+ + + +

+ {% for field in fields %} +

+ + + +

+ {% endfor %} +

+ +

+

+ +

+ + +{% endblock %} diff --git a/datasette_extract/templates/extract_to_table.html b/datasette_extract/templates/extract_to_table.html new file mode 100644 index 0000000..7b87d25 --- /dev/null +++ b/datasette_extract/templates/extract_to_table.html @@ -0,0 +1,26 @@ +{% extends "base.html" %} + +{% block title %}Extract{% endblock %} + +{% block extra_head %} + +{% endblock %} + +{% block content %} +

Extract data into {{ database }} / {{ table }}

+ +
+{{ schema}}
+
+ + +

+ + +

+

+ +

+ + +{% endblock %} diff --git a/tests/test_cli.py b/tests/test_cli.py index db664bd..311aec5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -4,6 +4,6 @@ def test_extract_command(): runner = CliRunner() - result = runner.invoke(cli, ["extract"]) + result = runner.invoke(cli, ["extract", "database", "table"]) assert result.exit_code == 0 - assert result.output == "Hello from extract\n" + assert result.output == "Will extract to table in database\n" diff --git a/tests/test_web.py b/tests/test_web.py index 18ed863..da3d0a3 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -5,6 +5,7 @@ @pytest.mark.asyncio async def test_extract_web(): ds = Datasette(memory=True) - response = await ds.client.get("/-/extract") + ds.add_memory_database("data") + response = await ds.client.get("/-/extract/data") assert response.status_code == 200 - assert "

Extract

" in response.text + assert "

Extract data and create a new table in data

" in response.text