Skip to content

Commit

Permalink
Remove runtime dependency on tableauhyperapi (#336)
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd authored Sep 20, 2024
1 parent 19f2237 commit 5bb265c
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 65 deletions.
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ dependencies = [
# in the future we need not require pyarrow as pandas implements the
# PyCapsule interface. See pandas PR #56587
"pyarrow>=14.0.0",
"tableauhyperapi>=0.0.19691",
]

[project.urls]
Expand Down Expand Up @@ -68,7 +67,13 @@ build = "cp39-*64 cp310-*64 cp311-*64 cp312-*64"
skip = "*musllinux*"

test-command = "python -m pytest {project}/tests"
test-requires = ["pytest", "pandas>=2.0.0", "polars~=1.2.0", "narwhals"]
test-requires = [
"pytest",
"pandas>=2.0.0",
"polars~=1.2.0",
"narwhals",
"tableauhyperapi",
]

[tool.ruff]
line-length = 88
Expand Down
33 changes: 13 additions & 20 deletions src/pantab/_reader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import pathlib
import shutil
import tempfile
from typing import Literal, Optional, Union

import pyarrow as pa
import tableauhyperapi as tab_api

import pantab._types as pantab_types
import pantab._types as pt_types
import pantab.libpantab as libpantab


Expand Down Expand Up @@ -43,15 +40,21 @@ def frame_from_hyper_query(
def frame_from_hyper(
source: Union[str, pathlib.Path],
*,
table: pantab_types.TableNameType,
table: pt_types.TableNameType,
return_type: Literal["pandas", "polars", "pyarrow"] = "pandas",
process_params: Optional[dict[str, str]] = None,
):
"""See api.rst for documentation"""
if isinstance(table, (str, tab_api.Name)) or not table.schema_name:
table = tab_api.TableName("public", table)

query = f"SELECT * FROM {table}"
if isinstance(table, (pt_types.TableauName, pt_types.TableauTableName)):
tbl = str(table)
elif isinstance(table, tuple):
tbl = ".".join(
libpantab.escape_sql_identifier(x) for x in table
) # check for injection
else:
tbl = libpantab.escape_sql_identifier(table)

query = f"SELECT * FROM {tbl}"
return frame_from_hyper_query(
source, query, return_type=return_type, process_params=process_params
)
Expand All @@ -65,17 +68,7 @@ def frames_from_hyper(
"""See api.rst for documentation."""
result = {}

table_names = []
with tempfile.TemporaryDirectory() as tmp_dir, tab_api.HyperProcess(
tab_api.Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU,
parameters={"log_config": ""},
) as hpe:
tmp_db = shutil.copy(source, tmp_dir)
with tab_api.Connection(hpe.endpoint, tmp_db) as connection:
for schema in connection.catalog.get_schema_names():
for table in connection.catalog.get_table_names(schema=schema):
table_names.append(table)

table_names = libpantab.get_table_names(str(source))
for table in table_names:
result[table] = frame_from_hyper(
source=source,
Expand Down
23 changes: 20 additions & 3 deletions src/pantab/_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
from typing import Union
from typing import Optional, Protocol, Union, runtime_checkable

import tableauhyperapi as tab_api

TableNameType = Union[str, tab_api.Name, tab_api.TableName]
@runtime_checkable
class TableauName(Protocol):
    """Structural stand-in for ``tableauhyperapi.Name``.

    Declaring the shape as a protocol lets pantab recognise name objects
    without importing ``tableauhyperapi``. ``isinstance`` against a
    runtime-checkable protocol only verifies that the attribute exists; it
    does not validate the attribute's type.
    """

    @property
    def unescaped(self) -> str:
        """Return the unescaped text of the identifier."""
        ...


@runtime_checkable
class TableauSchemaName(Protocol):
    """Structural stand-in for the object held in ``schema_name``.

    The writer reads ``table.schema_name.name.unescaped``, so the schema
    object wraps a name rather than being a name itself.
    """

    @property
    def name(self) -> TableauName:
        """Return the schema's own name object."""
        ...


@runtime_checkable
class TableauTableName(Protocol):
    """Structural stand-in for ``tableauhyperapi.TableName``."""

    @property
    def name(self) -> TableauName:
        """Return the table's own name object."""
        ...

    @property
    def schema_name(self) -> Optional[TableauSchemaName]:
        """Return the schema qualifier, or ``None`` when unqualified."""
        ...


TableNameType = Union[str, TableauName, TableauTableName]
22 changes: 12 additions & 10 deletions src/pantab/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import uuid
from typing import Any, Literal, Optional, Union

import tableauhyperapi as tab_api

import pantab._types as pantab_types
import pantab._types as pt_types
import pantab.libpantab as libpantab


Expand Down Expand Up @@ -52,7 +50,7 @@ def frame_to_hyper(
df,
database: Union[str, pathlib.Path],
*,
table: pantab_types.TableNameType,
table: pt_types.TableNameType,
table_mode: Literal["a", "w"] = "w",
not_null_columns: Optional[set[str]] = None,
json_columns: Optional[set[str]] = None,
Expand All @@ -72,7 +70,7 @@ def frame_to_hyper(


def frames_to_hyper(
dict_of_frames: dict[pantab_types.TableNameType, Any],
dict_of_frames: dict[pt_types.TableNameType, Any],
database: Union[str, pathlib.Path],
*,
table_mode: Literal["a", "w"] = "w",
Expand All @@ -98,12 +96,16 @@ def frames_to_hyper(
if table_mode == "a" and pathlib.Path(database).exists():
shutil.copy(database, tmp_db)

def convert_to_table_name(table: pantab_types.TableNameType):
# nanobind expects a tuple of (schema, table) strings
if isinstance(table, (str, tab_api.Name)) or not table.schema_name:
table = tab_api.TableName("public", table)
def convert_to_table_name(table: pt_types.TableNameType):
if isinstance(table, pt_types.TableauTableName):
if table.schema_name:
return (table.schema_name.name.unescaped, table.name.unescaped)
else:
return table.name.unescaped
elif isinstance(table, pt_types.TableauName):
return table.unescaped

return (table.schema_name.name.unescaped, table.name.unescaped)
return table

data = {
convert_to_table_name(key): _get_capsule_from_obj(val)
Expand Down
54 changes: 48 additions & 6 deletions src/pantab/libpantab.cpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,59 @@
#include <unordered_map>

#include <hyperapi/hyperapi.hpp>
#include <nanobind/nanobind.h>
#include <nanobind/stl/vector.h>

#include "datetime.h"
#include "reader.hpp"
#include "writer.hpp"

namespace nb = nanobind;

NB_MODULE(libpantab, m) { // NOLINT
m.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"),
nb::arg("path"), nb::arg("table_mode"), nb::arg("not_null_columns"),
nb::arg("json_columns"), nb::arg("geo_columns"),
nb::arg("process_params"))
m.def("escape_sql_identifier",
      [](const nb::str &str) {
        // Quote a Python string so it can be embedded as a SQL identifier.
        //
        // hyper_quote_sql_identifier is handed a UTF-8 buffer together with
        // its length. nb::len() reports the number of code points, which is
        // smaller than the byte count for any non-ASCII identifier and would
        // cut the input short, so obtain the encoded size from CPython.
        Py_ssize_t n_bytes = 0;
        const char *utf8 = PyUnicode_AsUTF8AndSize(str.ptr(), &n_bytes);
        if (utf8 == nullptr) {
          // Propagate the Python exception raised by the failed encoding.
          throw nb::python_error();
        }
        const auto value_size = static_cast<size_t>(n_bytes);

        // As in the original code, a first call with a null target is used
        // only to learn how large the quoted result must be.
        const auto required_size =
            hyper_quote_sql_identifier(nullptr, 0, utf8, value_size);
        std::string result(required_size, 'x');
        hyper_quote_sql_identifier(result.data(), required_size, utf8,
                                   value_size);
        return result;
      })
// get_table_names(path): list every table found in the Hyper file at `path`.
// Returns a Python list holding one entry per table: a (schema, table) tuple
// of unescaped names when the table name carries a schema, otherwise the
// unescaped table name as a plain str.
.def(
"get_table_names",
[](const std::string &path) {
// NOTE(review): an empty "log_config" presumably disables Hyper's log
// output -- confirm against the Hyper process settings documentation.
std::unordered_map<std::string, std::string> params{
{"log_config", ""}};
// Start a Hyper process with telemetry turned off for this call.
const hyperapi::HyperProcess hyper{
hyperapi::Telemetry::DoNotSendUsageDataToTableau, "",
std::move(params)};
// Connect directly to the file at `path`; no copy is made here.
hyperapi::Connection connection(hyper.getEndpoint(), path);

nb::list result;
// Walk every schema in the catalog, then every table in that schema.
for (const auto &schema_name :
connection.getCatalog().getSchemaNames()) {
for (const auto &table_name :
connection.getCatalog().getTableNames(schema_name)) {
const auto schema_prefix = table_name.getSchemaName();
if (schema_prefix) {
// Schema-qualified table: emit a (schema, table) tuple.
const auto tup =
nb::make_tuple(schema_prefix->getName().getUnescaped(),
table_name.getName().getUnescaped());
result.append(tup);
} else {
// No schema on the table name: emit just the table name.
result.append(
nb::str(table_name.getName().getUnescaped().c_str()));
}
}
}

return result;
},
nb::arg("path"))
.def("write_to_hyper", &write_to_hyper, nb::arg("dict_of_capsules"),
nb::arg("path"), nb::arg("table_mode"), nb::arg("not_null_columns"),
nb::arg("json_columns"), nb::arg("geo_columns"),
nb::arg("process_params"))
.def("read_from_hyper_query", &read_from_hyper_query, nb::arg("path"),
nb::arg("query"), nb::arg("process_params"));
PyDateTime_IMPORT;
}
32 changes: 20 additions & 12 deletions src/pantab/writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,13 +575,10 @@ static void AssertColumnsEqual(
}
};

using SchemaAndTableName = std::tuple<std::string, std::string>;

void write_to_hyper(
const std::map<SchemaAndTableName, nb::capsule> &dict_of_capsules,
const std::string &path, const std::string &table_mode,
const nb::iterable not_null_columns, const nb::iterable json_columns,
const nb::iterable geo_columns,
const nb::object &dict_of_capsules, const std::string &path,
const std::string &table_mode, const nb::iterable not_null_columns,
const nb::iterable json_columns, const nb::iterable geo_columns,
std::unordered_map<std::string, std::string> &&process_params) {

std::set<std::string> not_null_set;
Expand Down Expand Up @@ -620,10 +617,8 @@ void write_to_hyper(
hyperapi::Connection connection{hyper.getEndpoint(), path, createMode};
const hyperapi::Catalog &catalog = connection.getCatalog();

for (auto const &[schema_and_table, capsule] : dict_of_capsules) {
const auto hyper_schema = std::get<0>(schema_and_table);
const auto hyper_table = std::get<1>(schema_and_table);

for (auto const &[name, capsule] :
nb::cast<nb::dict>(dict_of_capsules, false)) {
const auto c_stream = static_cast<struct ArrowArrayStream *>(
PyCapsule_GetPointer(capsule.ptr(), "arrow_array_stream"));
if (c_stream == nullptr) {
Expand Down Expand Up @@ -723,9 +718,22 @@ void write_to_hyper(
}
}

const hyperapi::TableName table_name{hyper_schema, hyper_table};
std::tuple<std::string, std::string> schema_and_table;
std::string t_name;
const auto is_tup = nb::try_cast(name, schema_and_table, false);
const auto is_str = nb::try_cast(name, t_name, false);
if (!(is_tup || is_str)) {
throw nb::type_error("Expected string or tuple key");
}
const auto table_name =
is_tup ? hyperapi::TableName(std::get<0>(schema_and_table),
std::get<1>(schema_and_table))
: hyperapi::TableName(t_name);
const hyperapi::TableDefinition table_def{table_name, hyper_columns};
catalog.createSchemaIfNotExists(*table_name.getSchemaName());

const auto schema_name =
table_name.getSchemaName() ? *table_name.getSchemaName() : "public";
catalog.createSchemaIfNotExists(schema_name);

if ((table_mode == "a") && (catalog.hasTable(table_name))) {
const auto existing_def = catalog.getTableDefinition(table_name);
Expand Down
9 changes: 3 additions & 6 deletions src/pantab/writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@

namespace nb = nanobind;

using SchemaAndTableName = std::tuple<std::string, std::string>;

void write_to_hyper(
const std::map<SchemaAndTableName, nanobind::capsule> &dict_of_capsules,
const std::string &path, const std::string &table_mode,
const nb::iterable not_null_columns, const nb::iterable json_columns,
const nb::iterable geo_columns,
const nb::object &dict_of_capsules, const std::string &path,
const std::string &table_mode, const nb::iterable not_null_columns,
const nb::iterable json_columns, const nb::iterable geo_columns,
std::unordered_map<std::string, std::string> &&process_params);
22 changes: 21 additions & 1 deletion tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,31 @@ def test_reader_accepts_process_params(tmp_hyper):
pt.frames_from_hyper(tmp_hyper, process_params=params)


def test_reader_invalid_process_params_raises(frame, tmp_hyper):
def test_reader_invalid_process_params_raises(tmp_hyper):
frame = pd.DataFrame(list(range(10)), columns=["nums"]).astype("int8")
pt.frame_to_hyper(frame, tmp_hyper, table="test")

params = {"not_a_real_parameter": "0"}
msg = r"No internal setting named 'not_a_real_parameter'"
with pytest.raises(RuntimeError, match=msg):
pt.frames_from_hyper(tmp_hyper, process_params=params)


# Table name crafted to break out of a naively interpolated SQL statement.
_INJECTION_NAME = "a';DROP TABLE users;DELETE FROM foo WHERE 't' = 't"


@pytest.mark.parametrize(
    "table_name",
    [
        _INJECTION_NAME,
        tab_api.Name(_INJECTION_NAME),
        tab_api.TableName("public", _INJECTION_NAME),
        tab_api.TableName(_INJECTION_NAME, _INJECTION_NAME),
    ],
)
def test_reader_prevents_sql_injection(tmp_hyper, table_name):
    """Write, then read back, a table whose name embeds SQL statements.

    The test passes when neither call raises; it covers plain strings as
    well as ``tableauhyperapi`` name objects, with and without a schema.
    """
    nums = pd.DataFrame({"nums": list(range(10))}).astype("int8")
    pt.frame_to_hyper(nums, tmp_hyper, table=table_name)
    pt.frame_from_hyper(tmp_hyper, table=table_name)
51 changes: 46 additions & 5 deletions tests/test_roundtrip.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import sys

import pandas as pd
import pyarrow as pa
from tableauhyperapi import TableName
import pytest
import tableauhyperapi as tab_api

import pantab as pt

Expand Down Expand Up @@ -70,10 +74,15 @@ def test_multiple_tables(
expected = compat.concat_frames(expected, expected)

# some test trickery here
if not isinstance(table_name, TableName) or table_name.schema_name is None:
table_name = TableName("public", table_name)

assert set(result.keys()) == set((table_name, TableName("public", "table2")))
if not isinstance(table_name, tab_api.TableName) or table_name.schema_name is None:
table_name = tab_api.TableName("public", table_name)

assert set(result.keys()) == set(
(
tuple(table_name._unescaped_components),
tuple(tab_api.TableName("public", "table2")._unescaped_components),
)
)
for val in result.values():
compat.assert_frame_equal(val, expected)

Expand Down Expand Up @@ -113,3 +122,35 @@ def test_empty_roundtrip(
expected = compat.drop_columns(expected, ["object"])
expected = compat.empty_like(expected)
compat.assert_frame_equal(result, expected)


# Table name crafted to break out of a naively interpolated SQL statement.
_INJECTION_NAME = "a';DROP TABLE users;DELETE FROM foo WHERE 't' = 't"


@pytest.mark.parametrize(
    "table_name",
    [
        _INJECTION_NAME,
        tab_api.Name(_INJECTION_NAME),
        tab_api.TableName("public", _INJECTION_NAME),
        tab_api.TableName(_INJECTION_NAME, _INJECTION_NAME),
    ],
)
def test_write_prevents_injection(tmp_hyper, table_name):
    """Store a frame under a hostile table name and list the file back.

    The test passes when both the multi-frame write and the subsequent
    read of every table complete without raising.
    """
    nums = pd.DataFrame({"nums": list(range(10))}).astype("int8")
    pt.frames_to_hyper({table_name: nums}, tmp_hyper)
    pt.frames_from_hyper(tmp_hyper)


def test_roundtrip_works_without_tableauhyperapi(frame, tmp_hyper, monkeypatch):
    """Round-trip a frame while ``tableauhyperapi`` cannot be imported.

    Removing entries from ``sys.modules`` is not enough to simulate a missing
    package: the next ``import tableauhyperapi`` would simply load the
    installed distribution again. Mapping the names to ``None`` makes any
    such import raise ``ImportError``, so the test fails if pantab still
    imports the library at call time. ``monkeypatch`` restores the original
    entries when the test finishes.
    """
    libname = "tableauhyperapi"
    # Snapshot the matching names first: sys.modules is mutated below. The
    # top-level package is always blocked, even if nothing imported it yet.
    blocked = {libname} | {mod for mod in sys.modules if mod.startswith(libname)}
    for mod in blocked:
        monkeypatch.setitem(sys.modules, mod, None)

    pt.frame_to_hyper(frame, tmp_hyper, table="foo")
    pt.frames_from_hyper(tmp_hyper)

0 comments on commit 5bb265c

Please sign in to comment.