diff --git a/pyproject.toml b/pyproject.toml index 3b7eeea5..e1f8a2a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ dependencies = [ # in the future we need not require pyarrow as pandas implements the # PyCapsule interface. See pandas PR #56587 "pyarrow>=14.0.0", - "tableauhyperapi>=0.0.19691", ] [project.urls] @@ -68,7 +67,13 @@ build = "cp39-*64 cp310-*64 cp311-*64 cp312-*64" skip = "*musllinux*" test-command = "python -m pytest {project}/tests" -test-requires = ["pytest", "pandas>=2.0.0", "polars~=1.2.0", "narwhals"] +test-requires = [ + "pytest", + "pandas>=2.0.0", + "polars~=1.2.0", + "narwhals", + "tableauhyperapi", +] [tool.ruff] line-length = 88 diff --git a/src/pantab/_reader.py b/src/pantab/_reader.py index 59597538..56b0aa65 100644 --- a/src/pantab/_reader.py +++ b/src/pantab/_reader.py @@ -1,12 +1,9 @@ import pathlib -import shutil -import tempfile from typing import Literal, Optional, Union import pyarrow as pa -import tableauhyperapi as tab_api -import pantab._types as pantab_types +import pantab._types as pt_types import pantab.libpantab as libpantab @@ -43,15 +40,21 @@ def frame_from_hyper_query( def frame_from_hyper( source: Union[str, pathlib.Path], *, - table: pantab_types.TableNameType, + table: pt_types.TableNameType, return_type: Literal["pandas", "polars", "pyarrow"] = "pandas", process_params: Optional[dict[str, str]] = None, ): """See api.rst for documentation""" - if isinstance(table, (str, tab_api.Name)) or not table.schema_name: - table = tab_api.TableName("public", table) - - query = f"SELECT * FROM {table}" + if isinstance(table, (pt_types.TableauName, pt_types.TableauTableName)): + tbl = str(table) + elif isinstance(table, tuple): + tbl = ".".join( + libpantab.escape_sql_identifier(x) for x in table + ) # check for injection + else: + tbl = libpantab.escape_sql_identifier(table) + + query = f"SELECT * FROM {tbl}" return frame_from_hyper_query( source, query, return_type=return_type, process_params=process_params ) @@ -65,17 +68,7 @@ def frames_from_hyper( """See api.rst for documentation.""" result = {} - table_names = [] - with tempfile.TemporaryDirectory() as tmp_dir, tab_api.HyperProcess( - tab_api.Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU, - parameters={"log_config": ""}, - ) as hpe: - tmp_db = shutil.copy(source, tmp_dir) - with tab_api.Connection(hpe.endpoint, tmp_db) as connection: - for schema in connection.catalog.get_schema_names(): - for table in connection.catalog.get_table_names(schema=schema): - table_names.append(table) - + table_names = libpantab.get_table_names(str(source)) for table in table_names: result[table] = frame_from_hyper( source=source, diff --git a/src/pantab/_types.py b/src/pantab/_types.py index 102fae12..fb10e42b 100644 --- a/src/pantab/_types.py +++ b/src/pantab/_types.py @@ -1,5 +1,22 @@ -from typing import Union +from typing import Optional, Protocol, Union, runtime_checkable -import tableauhyperapi as tab_api -TableNameType = Union[str, tab_api.Name, tab_api.TableName] +@runtime_checkable +class TableauName(Protocol): + @property + def unescaped(self) -> str: + ... + + +@runtime_checkable +class TableauTableName(Protocol): + @property + def name(self) -> TableauName: + ... + + @property + def schema_name(self) -> Optional[TableauName]: + ... + + +TableNameType = Union[str, TableauName, TableauTableName] diff --git a/src/pantab/_writer.py b/src/pantab/_writer.py index a44bb703..de92841d 100644 --- a/src/pantab/_writer.py +++ b/src/pantab/_writer.py @@ -4,9 +4,7 @@ import uuid from typing import Any, Literal, Optional, Union -import tableauhyperapi as tab_api - -import pantab._types as pantab_types +import pantab._types as pt_types import pantab.libpantab as libpantab @@ -52,7 +50,7 @@ def frame_to_hyper( df, database: Union[str, pathlib.Path], *, - table: pantab_types.TableNameType, + table: pt_types.TableNameType, table_mode: Literal["a", "w"] = "w", not_null_columns: Optional[set[str]] = None, json_columns: Optional[set[str]] = None, @@ -72,7 +70,7 @@ def frame_to_hyper( def frames_to_hyper( - dict_of_frames: dict[pantab_types.TableNameType, Any], + dict_of_frames: dict[pt_types.TableNameType, Any], database: Union[str, pathlib.Path], *, table_mode: Literal["a", "w"] = "w", @@ -98,12 +96,16 @@ def frames_to_hyper( if table_mode == "a" and pathlib.Path(database).exists(): shutil.copy(database, tmp_db) - def convert_to_table_name(table: pantab_types.TableNameType): - # nanobind expects a tuple of (schema, table) strings - if isinstance(table, (str, tab_api.Name)) or not table.schema_name: - table = tab_api.TableName("public", table) + def convert_to_table_name(table: pt_types.TableNameType): + if isinstance(table, pt_types.TableauTableName): + if table.schema_name: + return (table.schema_name.name.unescaped, table.name.unescaped) + else: + return table.name.unescaped + elif isinstance(table, pt_types.TableauName): + return table.unescaped - return (table.schema_name.name.unescaped, table.name.unescaped) + return table data = { convert_to_table_name(key): _get_capsule_from_obj(val) diff --git a/src/pantab/libpantab.cpp b/src/pantab/libpantab.cpp index 5f31b8cd..8e5304f9 100644 --- a/src/pantab/libpantab.cpp +++ b/src/pantab/libpantab.cpp @@ -1,17 +1,59 @@ +#include + +#include #include +#include -#include "datetime.h" #include "reader.hpp" #include "writer.hpp" namespace nb = nanobind; NB_MODULE(libpantab, m) { // NOLINT - m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"), - nb::arg("path"), nb::arg("table_mode"), nb::arg("not_null_columns"), - nb::arg("json_columns"), nb::arg("geo_columns"), - nb::arg("process_params")) + m.def("escape_sql_identifier", + [](const nb::str &str) { + const auto required_size = + hyper_quote_sql_identifier(nullptr, 0, str.c_str(), nb::len(str)); + std::string result(required_size, 'x'); + hyper_quote_sql_identifier(result.data(), required_size, str.c_str(), + nb::len(str)); + return result; + }) + .def( + "get_table_names", + [](const std::string &path) { + std::unordered_map params{ + {"log_config", ""}}; + const hyperapi::HyperProcess hyper{ + hyperapi::Telemetry::DoNotSendUsageDataToTableau, "", + std::move(params)}; + hyperapi::Connection connection(hyper.getEndpoint(), path); + + nb::list result; + for (const auto &schema_name : + connection.getCatalog().getSchemaNames()) { + for (const auto &table_name : + connection.getCatalog().getTableNames(schema_name)) { + const auto schema_prefix = table_name.getSchemaName(); + if (schema_prefix) { + const auto tup = + nb::make_tuple(schema_prefix->getName().getUnescaped(), + table_name.getName().getUnescaped()); + result.append(tup); + } else { + result.append( + nb::str(table_name.getName().getUnescaped().c_str())); + } + } + } + + return result; + }, + nb::arg("path")) + .def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"), + nb::arg("path"), nb::arg("table_mode"), nb::arg("not_null_columns"), + nb::arg("json_columns"), nb::arg("geo_columns"), + nb::arg("process_params")) .def("read_from_hyper_query", &read_from_hyper_query, nb::arg("path"), nb::arg("query"), nb::arg("process_params")); - PyDateTime_IMPORT; } diff --git a/src/pantab/writer.cpp b/src/pantab/writer.cpp index 9f24a308..3c1004b3 100644 --- a/src/pantab/writer.cpp +++ b/src/pantab/writer.cpp @@ -575,13 +575,10 @@ static void AssertColumnsEqual( } }; -using SchemaAndTableName = std::tuple; - void write_to_hyper( - const std::map &dict_of_capsules, - const std::string &path, const std::string &table_mode, - const nb::iterable not_null_columns, const nb::iterable json_columns, - const nb::iterable geo_columns, + const nb::object &dict_of_capsules, const std::string &path, + const std::string &table_mode, const nb::iterable not_null_columns, + const nb::iterable json_columns, const nb::iterable geo_columns, std::unordered_map &&process_params) { std::set not_null_set; @@ -620,10 +617,8 @@ void write_to_hyper( hyperapi::Connection connection{hyper.getEndpoint(), path, createMode}; const hyperapi::Catalog &catalog = connection.getCatalog(); - for (auto const &[schema_and_table, capsule] : dict_of_capsules) { - const auto hyper_schema = std::get<0>(schema_and_table); - const auto hyper_table = std::get<1>(schema_and_table); - + for (auto const &[name, capsule] : + nb::cast(dict_of_capsules, false)) { const auto c_stream = static_cast( PyCapsule_GetPointer(capsule.ptr(), "arrow_array_stream")); if (c_stream == nullptr) { @@ -723,9 +718,22 @@ void write_to_hyper( } } - const hyperapi::TableName table_name{hyper_schema, hyper_table}; + std::tuple schema_and_table; + std::string t_name; + const auto is_tup = nb::try_cast(name, schema_and_table, false); + const auto is_str = nb::try_cast(name, t_name, false); + if (!(is_tup || is_str)) { + throw nb::type_error("Expected string or tuple key"); + } + const auto table_name = + is_tup ? hyperapi::TableName(std::get<0>(schema_and_table), + std::get<1>(schema_and_table)) + : hyperapi::TableName(t_name); const hyperapi::TableDefinition table_def{table_name, hyper_columns}; - catalog.createSchemaIfNotExists(*table_name.getSchemaName()); + + const auto schema_name = + table_name.getSchemaName() ? *table_name.getSchemaName() : "public"; + catalog.createSchemaIfNotExists(schema_name); if ((table_mode == "a") && (catalog.hasTable(table_name))) { const auto existing_def = catalog.getTableDefinition(table_name); diff --git a/src/pantab/writer.hpp b/src/pantab/writer.hpp index 66426727..e302077e 100644 --- a/src/pantab/writer.hpp +++ b/src/pantab/writer.hpp @@ -8,11 +8,8 @@ namespace nb = nanobind; -using SchemaAndTableName = std::tuple; - void write_to_hyper( - const std::map &dict_of_capsules, - const std::string &path, const std::string &table_mode, - const nb::iterable not_null_columns, const nb::iterable json_columns, - const nb::iterable geo_columns, + const nb::object &dict_of_capsules, const std::string &path, + const std::string &table_mode, const nb::iterable not_null_columns, + const nb::iterable json_columns, const nb::iterable geo_columns, std::unordered_map &&process_params); diff --git a/tests/test_reader.py b/tests/test_reader.py index 561c053e..bf2c96b3 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -156,7 +156,7 @@ def test_reader_accepts_process_params(tmp_hyper): pt.frames_from_hyper(tmp_hyper, process_params=params) -def test_reader_invalid_process_params_raises(frame, tmp_hyper): +def test_reader_invalid_process_params_raises(tmp_hyper): frame = pd.DataFrame(list(range(10)), columns=["nums"]).astype("int8") pt.frame_to_hyper(frame, tmp_hyper, table="test") @@ -164,3 +164,23 @@ def test_reader_invalid_process_params_raises(frame, tmp_hyper): msg = r"No internal setting named 'not_a_real_parameter'" with pytest.raises(RuntimeError, match=msg): pt.frames_from_hyper(tmp_hyper, process_params=params) + + +@pytest.mark.parametrize( + "table_name", + [ + "a';DROP TABLE users;DELETE FROM foo WHERE 't' = 't", + tab_api.Name("a';DROP TABLE users;DELETE FROM foo WHERE 't' = 't"), + tab_api.TableName( + "public", "a';DROP TABLE users;DELETE FROM foo WHERE 't' = 't" + ), + tab_api.TableName( + "a';DROP TABLE users;DELETE FROM foo WHERE 't' = 't", + "a';DROP TABLE users;DELETE FROM foo WHERE 't' = 't", + ), + ], +) +def test_reader_prevents_sql_injection(tmp_hyper, table_name): + frame = pd.DataFrame(list(range(10)), columns=["nums"]).astype("int8") + pt.frame_to_hyper(frame, tmp_hyper, table=table_name) + pt.frame_from_hyper(tmp_hyper, table=table_name) diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py index 4c61ad7d..e797c83c 100644 --- a/tests/test_roundtrip.py +++ b/tests/test_roundtrip.py @@ -1,5 +1,9 @@ +import sys + +import pandas as pd import pyarrow as pa -from tableauhyperapi import TableName +import pytest +import tableauhyperapi as tab_api import pantab as pt @@ -70,10 +74,15 @@ def test_multiple_tables( expected = compat.concat_frames(expected, expected) # some test trickery here - if not isinstance(table_name, TableName) or table_name.schema_name is None: - table_name = TableName("public", table_name) - - assert set(result.keys()) == set((table_name, TableName("public", "table2"))) + if not isinstance(table_name, tab_api.TableName) or table_name.schema_name is None: + table_name = tab_api.TableName("public", table_name) + + assert set(result.keys()) == set( + ( + tuple(table_name._unescaped_components), + tuple(tab_api.TableName("public", "table2")._unescaped_components), + ) + ) for val in result.values(): compat.assert_frame_equal(val, expected) @@ -113,3 +122,35 @@ def test_empty_roundtrip( expected = compat.drop_columns(expected, ["object"]) expected = compat.empty_like(expected) compat.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "table_name", + [ + "a';DROP TABLE users;DELETE FROM foo WHERE 't' = 't", + tab_api.Name("a';DROP TABLE users;DELETE FROM foo WHERE 't' = 't"), + tab_api.TableName( + "public", "a';DROP TABLE users;DELETE FROM foo WHERE 't' = 't" + ), + tab_api.TableName( + "a';DROP TABLE users;DELETE FROM foo WHERE 't' = 't", + "a';DROP TABLE users;DELETE FROM foo WHERE 't' = 't", + ), + ], +) +def test_write_prevents_injection(tmp_hyper, table_name): + frame = pd.DataFrame(list(range(10)), columns=["nums"]).astype("int8") + frames = {table_name: frame} + pt.frames_to_hyper(frames, tmp_hyper) + pt.frames_from_hyper(tmp_hyper) + + +def test_roundtrip_works_without_tableauhyperapi(frame, tmp_hyper, monkeypatch): + libname = "tableauhyperapi" + mods = set(sys.modules.keys()) + for mod in mods: + if mod.startswith(libname): + monkeypatch.delitem(sys.modules, mod) + + pt.frame_to_hyper(frame, tmp_hyper, table="foo") + pt.frames_from_hyper(tmp_hyper)