Allow seed type casting #459
base: main
Changes from all commits: 17494cb, 3fc70ee, 7c98bbc, e79f7a8, 1171ef6
@@ -2,7 +2,9 @@
 import unittest
 from unittest import mock
 from unittest.mock import Mock
+import pytest
 from multiprocessing import get_context
+import agate.data_types
 from botocore.client import BaseClient
 from moto import mock_aws
 
@@ -13,6 +15,8 @@
 from dbt.adapters.glue import GlueAdapter
 from dbt.adapters.glue.gluedbapi import GlueConnection
 from dbt.adapters.glue.relation import SparkRelation
+from dbt.adapters.glue.impl import ColumnCsvMappingStrategy
+from dbt_common.clients import agate_helper
 from tests.util import config_from_parts_or_dicts
 from .util import MockAWSService
 
@@ -41,7 +45,7 @@ def setUp(self):
             "region": "us-east-1",
             "workers": 2,
             "worker_type": "G.1X",
-            "location" : "path_to_location/",
+            "location": "path_to_location/",
             "schema": "dbt_unit_test_01",
             "database": "dbt_unit_test_01",
             "use_interactive_session_role_for_api_calls": False,
@@ -71,7 +75,6 @@ def test_glue_connection(self):
             self.assertIsNotNone(connection.handle)
             self.assertIsInstance(glueSession.client, BaseClient)
 
-
     @mock_aws
     def test_get_table_type(self):
         config = self._get_config()
@@ -96,8 +99,10 @@ def test_create_csv_table_slices_big_datasets(self):
         adapter = GlueAdapter(config, get_context("spawn"))
         model = {"name": "mock_model", "schema": "mock_schema"}
         session_mock = Mock()
-        adapter.get_connection = lambda: (session_mock, 'mock_client')
-        test_table = agate.Table([(f'mock_value_{i}',f'other_mock_value_{i}') for i in range(2000)], column_names=['value', 'other_value'])
+        adapter.get_connection = lambda: (session_mock, "mock_client")
+        test_table = agate.Table(
+            [(f"mock_value_{i}", f"other_mock_value_{i}") for i in range(2000)], column_names=["value", "other_value"]
+        )
         adapter.create_csv_table(model, test_table)
 
         # test table is between 120000 and 180000 characters so it should be split three times (max chunk is 60000)
@@ -115,11 +120,95 @@ def test_get_location(self):
             connection.handle  # trigger lazy-load
             print(adapter.get_location(relation))
             self.assertEqual(adapter.get_location(relation), "LOCATION 'path_to_location/some_database/some_table'")
 
     def test_get_custom_iceberg_catalog_namespace(self):
         config = self._get_config()
         adapter = GlueAdapter(config, get_context("spawn"))
         with mock.patch("dbt.adapters.glue.connections.open"):
             connection = adapter.acquire_connection("dummy")
             connection.handle  # trigger lazy-load
-            self.assertEqual(adapter.get_custom_iceberg_catalog_namespace(), "custom_iceberg_catalog")
+            self.assertEqual(adapter.get_custom_iceberg_catalog_namespace(), "custom_iceberg_catalog")
+
+    def test_create_csv_table_provides_schema_and_casts_when_spark_seed_cast_is_enabled(self):
+        config = self._get_config()
+        config.credentials.enable_spark_seed_casting = True
+        adapter = GlueAdapter(config, get_context("spawn"))
+        csv_chunks = [{"test_column_double": "1.2345", "test_column_str": "test"}]
+        model = {
+            "name": "mock_model",
+            "schema": "mock_schema",
+            "config": {"column_types": {"test_column_double": "double", "test_column_str": "string"}},
+        }
+        column_mappings = [
+            ColumnCsvMappingStrategy("test_column_double", agate.data_types.Text, "double"),
+            ColumnCsvMappingStrategy("test_column_str", agate.data_types.Text, "string"),
+        ]
+        code = adapter._map_csv_chunks_to_code(csv_chunks, config, model, "True", column_mappings)
+        self.assertIn('spark.createDataFrame(csv, "test_column_double: string, test_column_str: string")', code[0])
+        self.assertIn(
+            'df = df.withColumn("test_column_double", df.test_column_double.cast("double"))'
+            + '.withColumn("test_column_str", df.test_column_str.cast("string"))',
+            code[0],
+        )
+
+    def test_create_csv_table_doesnt_provide_schema_when_spark_seed_cast_is_disabled(self):
+        config = self._get_config()
+        config.credentials.enable_spark_seed_casting = False
+        adapter = GlueAdapter(config, get_context("spawn"))
+        csv_chunks = [{"test_column": "1.2345"}]
+        model = {"name": "mock_model", "schema": "mock_schema"}
+        column_mappings = [ColumnCsvMappingStrategy("test_column", agate.data_types.Text, "double")]
> Reviewer: Shall we add another case to test multiple mapping strategies?
> Author: I extended the test case to include different strategies.
> Reviewer: Thank you.
+        code = adapter._map_csv_chunks_to_code(csv_chunks, config, model, "True", column_mappings)
+        self.assertIn("spark.createDataFrame(csv)", code[0])
+
+
+class TestCsvMappingStrategy:
+    @pytest.mark.parametrize(
+        "agate_type,specified_type,expected_schema_type,expected_cast_type",
+        [
+            (agate_helper.ISODateTime, None, "string", "timestamp"),
+            (agate_helper.Number, None, "double", None),
+            (agate_helper.Integer, None, "int", None),
+            (agate.data_types.Boolean, None, "boolean", None),
+            (agate.data_types.Date, None, "string", "date"),
+            (agate.data_types.DateTime, None, "string", "timestamp"),
+            (agate.data_types.Text, None, "string", None),
+            (agate.data_types.Text, "double", "string", "double"),
+        ],
+        ids=[
+            "test isodatetime cast",
+            "test number cast",
+            "test integer cast",
+            "test boolean cast",
+            "test date cast",
+            "test datetime cast",
+            "test text cast",
+            "test specified cast",
+        ],
+    )
+    def test_mapping_strategy_provides_proper_mappings(
+        self, agate_type, specified_type, expected_schema_type, expected_cast_type
+    ):
+        column_mapping = ColumnCsvMappingStrategy("test_column", agate_type, specified_type)
+        assert column_mapping.as_schema_value() == expected_schema_type
+        assert column_mapping.as_cast_value() == expected_cast_type
+
+    def test_from_model_builds_column_mappings(self):
+        expected_column_names = ["col_int", "col_str", "col_date", "col_specific"]
+        expected_agate_types = [
+            agate_helper.Integer,
+            agate.data_types.Text,
+            agate.data_types.Date,
+            agate.data_types.Text,
+        ]
+        expected_specified_types = [None, None, None, "double"]
+        agate_table = agate.Table(
+            [(111, "str_val", "2024-01-01", "1.234")],
+            column_names=expected_column_names,
+            column_types=[data_type() for data_type in expected_agate_types],
+        )
+        model = {"name": "mock_model", "config": {"column_types": {"col_specific": "double"}}}
+        mappings = ColumnCsvMappingStrategy.from_model(model, agate_table)
+        assert expected_column_names == [mapping.column_name for mapping in mappings]
+        assert expected_agate_types == [mapping.agate_type for mapping in mappings]
+        assert expected_specified_types == [mapping.specified_type for mapping in mappings]
> Reviewer: withColumn works well for simple patterns, but chaining many withColumn calls can easily cause a StackOverflowException. Do you think this happens in real-world use cases? If so, it would be safer to replace withColumn with select ... as.
>
> Author: I'd say it depends: how many withColumn calls would it take to cause a stack overflow? If it's over 100, then I'd say it's not realistic for any of the use cases we're handling. However, I can look into implementing it with select ... as, though I'm unfamiliar with that syntax. Do you mean something like this?
>
> Reviewer: You are right. df.select() or df.selectExpr() would work here. The threshold depends on multiple factors, but it can sometimes occur with fewer than 100 withColumn calls.
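For context, here is a minimal sketch of the two approaches under discussion; it is not part of the PR, and the column_mappings list of (name, type) pairs below is a hypothetical stand-in for the ColumnCsvMappingStrategy objects the adapter builds.

# Standalone PySpark sketch (assumption: a plain local Spark session, not the
# PR's Glue interactive session).
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("1.2345", "test")], "test_column_double: string, test_column_str: string"
)
# Hypothetical stand-in for the PR's column mappings: (column name, target type).
column_mappings = [("test_column_double", "double"), ("test_column_str", "string")]

# Approach 1: chained withColumn, the pattern the PR's generated code uses.
# Every call wraps the logical plan in another projection node, so a very wide
# seed table can make plan analysis overflow the stack.
cast_df = df
for name, target_type in column_mappings:
    cast_df = cast_df.withColumn(name, col(name).cast(target_type))

# Approach 2: a single select with one cast expression per column. The plan
# gets a single projection no matter how many columns are cast.
cast_df = df.select(
    *[col(name).cast(target_type).alias(name) for name, target_type in column_mappings]
)

# Equivalent with selectExpr, as the reviewer suggests:
cast_df = df.selectExpr(
    *[f"CAST({name} AS {target_type}) AS {name}" for name, target_type in column_mappings]
)

Either select form keeps the plan depth constant as the column count grows, which is what makes it safer than a long withColumn chain for wide seed files.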