diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6679df9..5c22e4a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,7 @@
 - Adds a helper function to retrieve the Iceberg catalog namespace from the profile.yaml file.
 - Adds merge_exclude_columns and incremental_predicates features.
 - Drop Python 3.8 support
+- Adds a configuration property to allow Spark casting of seed column types
 
 ## v1.8.6
 - Fix session provisioning timeout and delay handling
diff --git a/README.md b/README.md
index 7ef9df8..31ed7b2 100644
--- a/README.md
+++ b/README.md
@@ -244,6 +244,7 @@ The table below describes all the options.
 | glue_session_reuse | re-use the glue-session to run multiple dbt run commands: If set to true, the glue session will not be closed for re-use. If set to false, the session will be closed. The glue session will close after idle_timeout time is expired after idle_timeout time | no |
 | datalake_formats | The ACID datalake format that you want to use if you are doing merge, can be `hudi`, `ìceberg` or `delta` |no|
 | use_arrow | (experimental) use an arrow file instead of stdout to have better scalability. |no|
+| enable_spark_seed_casting | Allows Spark to cast the seed columns to the types specified in the model column_types configuration. Default `False`. |no|
 
 ## Configs
 
diff --git a/dbt/adapters/glue/credentials.py b/dbt/adapters/glue/credentials.py
index 0861832..c9a7705 100644
--- a/dbt/adapters/glue/credentials.py
+++ b/dbt/adapters/glue/credentials.py
@@ -3,9 +3,11 @@
 from dbt.adapters.contracts.connection import Credentials
 from dbt_common.exceptions import DbtRuntimeError
 
+
 @dataclass
 class GlueCredentials(Credentials):
-    """ Required connections for a Glue connection"""
+    """Required connections for a Glue connection"""
+
     role_arn: Optional[str] = None  # type: ignore
     region: Optional[str] = None  # type: ignore
     workers: Optional[int] = None  # type: ignore
@@ -36,6 +38,7 @@ class GlueCredentials(Credentials):
     enable_session_per_model: Optional[bool] = False
     use_arrow: Optional[bool] = False
     custom_iceberg_catalog_namespace: Optional[str] = "glue_catalog"
+    enable_spark_seed_casting: Optional[bool] = False
 
     @property
     def type(self):
@@ -64,34 +67,35 @@ def __post_init__(self):
         self.database = None
 
     def _connection_keys(self):
-        """ Keys to show when debugging """
+        """Keys to show when debugging"""
         return [
-            'role_arn',
-            'region',
-            'workers',
-            'worker_type',
-            'session_provisioning_timeout_in_seconds',
-            'schema',
-            'location',
-            'extra_jars',
-            'idle_timeout',
-            'query_timeout_in_minutes',
-            'glue_version',
-            'security_configuration',
-            'connections',
-            'conf',
-            'extra_py_files',
-            'delta_athena_prefix',
-            'tags',
-            'seed_format',
-            'seed_mode',
-            'default_arguments',
-            'iceberg_glue_commit_lock_table',
-            'use_interactive_session_role_for_api_calls',
-            'lf_tags',
-            'glue_session_id',
-            'glue_session_reuse',
-            'datalake_formats',
-            'enable_session_per_model',
-            'use_arrow'
+            "role_arn",
+            "region",
+            "workers",
+            "worker_type",
+            "session_provisioning_timeout_in_seconds",
+            "schema",
+            "location",
+            "extra_jars",
+            "idle_timeout",
+            "query_timeout_in_minutes",
+            "glue_version",
+            "security_configuration",
+            "connections",
+            "conf",
+            "extra_py_files",
+            "delta_athena_prefix",
+            "tags",
+            "seed_format",
+            "seed_mode",
+            "default_arguments",
+            "iceberg_glue_commit_lock_table",
+            "use_interactive_session_role_for_api_calls",
+            "lf_tags",
+            "glue_session_id",
+            "glue_session_reuse",
+            "datalake_formats",
+            "enable_session_per_model",
+            "use_arrow",
+            "enable_spark_seed_casting",
         ]
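A minimal sketch (not part of the patch) of how the new flag is switched on: enable_spark_seed_casting is just another boolean credential, so a profile target only needs one extra key. The dict below mirrors the shape the unit tests build in setUp; every value other than the new flag is a placeholder:

# Hypothetical Glue target, shown as a Python dict of credentials.
# Only enable_spark_seed_casting comes from this change; it defaults to False when omitted.
credentials = {
    "region": "us-east-1",
    "workers": 2,
    "worker_type": "G.1X",
    "location": "path_to_location/",
    "schema": "dbt_unit_test_01",
    "database": "dbt_unit_test_01",
    "enable_spark_seed_casting": True,
}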
diff --git a/dbt/adapters/glue/impl.py b/dbt/adapters/glue/impl.py
index b36fdd0..29fdb68 100644
--- a/dbt/adapters/glue/impl.py
+++ b/dbt/adapters/glue/impl.py
@@ -27,11 +27,52 @@
 from dbt_common.exceptions import DbtDatabaseError, CompilationError
 from dbt.adapters.base.impl import catch_as_completed
 from dbt_common.utils import executor
+from dbt_common.clients import agate_helper
 from dbt.adapters.events.logging import AdapterLogger
 
 logger = AdapterLogger("Glue")
 
 
+class ColumnCsvMappingStrategy:
+    _schema_mappings = {
+        agate_helper.ISODateTime: 'string',
+        agate_helper.Number: 'double',
+        agate_helper.Integer: 'int',
+        agate.data_types.Boolean: 'boolean',
+        agate.data_types.Date: 'string',
+        agate.data_types.DateTime: 'string',
+        agate.data_types.Text: 'string',
+    }
+
+    _cast_mappings = {
+        agate_helper.ISODateTime: 'timestamp',
+        agate.data_types.Date: 'date',
+        agate.data_types.DateTime: 'timestamp',
+    }
+
+    def __init__(self, column_name, agate_type, specified_type):
+        self.column_name = column_name
+        self.agate_type = agate_type
+        self.specified_type = specified_type
+
+    def as_schema_value(self):
+        return ColumnCsvMappingStrategy._schema_mappings.get(self.agate_type)
+
+    def as_cast_value(self):
+        return (
+            self.specified_type if self.specified_type else ColumnCsvMappingStrategy._cast_mappings.get(self.agate_type)
+        )
+
+    @classmethod
+    def from_model(cls, model, agate_table):
+        return [
+            ColumnCsvMappingStrategy(
+                column.name, type(column.data_type), model.get("config", {}).get("column_types", {}).get(column.name)
+            )
+            for column in agate_table.columns
+        ]
+
+
 class GlueAdapter(SQLAdapter):
     ConnectionManager = GlueConnectionManager
     Relation = SparkRelation
@@ -519,7 +560,7 @@ def create_csv_table(self, model, agate_table):
             mode = "False"
 
         csv_chunks = self._split_csv_records_into_chunks(json.loads(f.getvalue()))
-        statements = self._map_csv_chunks_to_code(csv_chunks, session, model, mode)
+        statements = self._map_csv_chunks_to_code(csv_chunks, session, model, mode, ColumnCsvMappingStrategy.from_model(model, agate_table))
         try:
             cursor = session.cursor()
             for statement in statements:
@@ -529,7 +570,14 @@
         except Exception as e:
             logger.error(e)
 
-    def _map_csv_chunks_to_code(self, csv_chunks: List[List[dict]], session: GlueConnection, model, mode):
+    def _map_csv_chunks_to_code(
+        self,
+        csv_chunks: List[List[dict]],
+        session: GlueConnection,
+        model,
+        mode,
+        column_mappings: List[ColumnCsvMappingStrategy],
+    ):
         statements = []
         for i, csv_chunk in enumerate(csv_chunks):
             is_first = i == 0
@@ -548,8 +596,31 @@ def _map_csv_chunks_to_code(self, csv_chunks: List[List[dict]], session: GlueCon
 SqlWrapper2.execute("""select 1""")
 '''
             else:
-                code += f'''
+                if session.credentials.enable_spark_seed_casting:
+                    csv_schema = ", ".join(
+                        [f"{mapping.column_name}: {mapping.as_schema_value()}" for mapping in column_mappings]
+                    )
+
+                    cast_code = ".".join(
+                        [
+                            "df",
+                            *[
+                                f'withColumn("{mapping.column_name}", df.{mapping.column_name}.cast("{cast_value}"))'
+                                for mapping in column_mappings
+                                if (cast_value := mapping.as_cast_value())
+                            ],
+                        ]
+                    )
+
+                    code += f"""
+df = spark.createDataFrame(csv, "{csv_schema}")
+df = {cast_code}
+"""
+                else:
+                    code += """
 df = spark.createDataFrame(csv)
+"""
+                code += f'''
 table_name = '{model["schema"]}.{model["name"]}'
 if (spark.sql("show tables in {model["schema"]}").where("tableName == lower('{model["name"]}')").count() > 0):
     df.write\
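A small sketch (not part of the patch) of what ColumnCsvMappingStrategy produces for a two-column seed, assuming a dbt-glue build that already contains the class above; the table, model, and column names are made up for illustration:

# Illustrative only: requires agate plus a dbt-glue install that ships
# ColumnCsvMappingStrategy. Column and model names are hypothetical.
import agate.data_types

from dbt.adapters.glue.impl import ColumnCsvMappingStrategy

table = agate.Table(
    [("2024-01-01", "1.25")],
    column_names=["created_at", "amount"],
    column_types=[agate.data_types.Date(), agate.data_types.Text()],
)
model = {"name": "my_seed", "config": {"column_types": {"amount": "double"}}}

mappings = ColumnCsvMappingStrategy.from_model(model, table)

# Every column enters the DataFrame as a permissive type first...
print([m.as_schema_value() for m in mappings])  # ['string', 'string']
# ...and is then cast: date/timestamp types come from agate, the rest from column_types.
print([m.as_cast_value() for m in mappings])    # ['date', 'double']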
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index d6a7c89..0ad070f 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -2,7 +2,9 @@
 import unittest
 from unittest import mock
 from unittest.mock import Mock
+import pytest
 from multiprocessing import get_context
+import agate.data_types
 
 from botocore.client import BaseClient
 from moto import mock_aws
@@ -13,6 +15,8 @@
 from dbt.adapters.glue import GlueAdapter
 from dbt.adapters.glue.gluedbapi import GlueConnection
 from dbt.adapters.glue.relation import SparkRelation
+from dbt.adapters.glue.impl import ColumnCsvMappingStrategy
+from dbt_common.clients import agate_helper
 from tests.util import config_from_parts_or_dicts
 from .util import MockAWSService
 
@@ -41,7 +45,7 @@ def setUp(self):
             "region": "us-east-1",
             "workers": 2,
             "worker_type": "G.1X",
-            "location" : "path_to_location/",
+            "location": "path_to_location/",
             "schema": "dbt_unit_test_01",
             "database": "dbt_unit_test_01",
             "use_interactive_session_role_for_api_calls": False,
@@ -71,7 +75,6 @@ def test_glue_connection(self):
         self.assertIsNotNone(connection.handle)
         self.assertIsInstance(glueSession.client, BaseClient)
 
-
     @mock_aws
     def test_get_table_type(self):
         config = self._get_config()
@@ -96,8 +99,10 @@ def test_create_csv_table_slices_big_datasets(self):
         adapter = GlueAdapter(config, get_context("spawn"))
         model = {"name": "mock_model", "schema": "mock_schema"}
         session_mock = Mock()
-        adapter.get_connection = lambda: (session_mock, 'mock_client')
-        test_table = agate.Table([(f'mock_value_{i}',f'other_mock_value_{i}') for i in range(2000)], column_names=['value', 'other_value'])
+        adapter.get_connection = lambda: (session_mock, "mock_client")
+        test_table = agate.Table(
+            [(f"mock_value_{i}", f"other_mock_value_{i}") for i in range(2000)], column_names=["value", "other_value"]
+        )
         adapter.create_csv_table(model, test_table)
 
         # test table is between 120000 and 180000 characters so it should be split three times (max chunk is 60000)
@@ -115,11 +120,95 @@ def test_get_location(self):
         connection.handle  # trigger lazy-load
         print(adapter.get_location(relation))
         self.assertEqual(adapter.get_location(relation), "LOCATION 'path_to_location/some_database/some_table'")
-    
+
     def test_get_custom_iceberg_catalog_namespace(self):
         config = self._get_config()
         adapter = GlueAdapter(config, get_context("spawn"))
         with mock.patch("dbt.adapters.glue.connections.open"):
             connection = adapter.acquire_connection("dummy")
             connection.handle  # trigger lazy-load
-        self.assertEqual(adapter.get_custom_iceberg_catalog_namespace(), "custom_iceberg_catalog")
\ No newline at end of file
+        self.assertEqual(adapter.get_custom_iceberg_catalog_namespace(), "custom_iceberg_catalog")
+
+    def test_create_csv_table_provides_schema_and_casts_when_spark_seed_cast_is_enabled(self):
+        config = self._get_config()
+        config.credentials.enable_spark_seed_casting = True
+        adapter = GlueAdapter(config, get_context("spawn"))
+        csv_chunks = [{"test_column_double": "1.2345", "test_column_str": "test"}]
+        model = {
+            "name": "mock_model",
+            "schema": "mock_schema",
+            "config": {"column_types": {"test_column_double": "double", "test_column_str": "string"}},
+        }
+        column_mappings = [
+            ColumnCsvMappingStrategy("test_column_double", agate.data_types.Text, "double"),
+            ColumnCsvMappingStrategy("test_column_str", agate.data_types.Text, "string"),
+        ]
+        code = adapter._map_csv_chunks_to_code(csv_chunks, config, model, "True", column_mappings)
+        self.assertIn('spark.createDataFrame(csv, "test_column_double: string, test_column_str: string")', code[0])
+        self.assertIn(
+            'df = df.withColumn("test_column_double", df.test_column_double.cast("double"))'
+            + '.withColumn("test_column_str", df.test_column_str.cast("string"))',
+            code[0],
+        )
+
+    def test_create_csv_table_doesnt_provide_schema_when_spark_seed_cast_is_disabled(self):
+        config = self._get_config()
+        config.credentials.enable_spark_seed_casting = False
+        adapter = GlueAdapter(config, get_context("spawn"))
+        csv_chunks = [{"test_column": "1.2345"}]
+        model = {"name": "mock_model", "schema": "mock_schema"}
+        column_mappings = [ColumnCsvMappingStrategy("test_column", agate.data_types.Text, "double")]
+        code = adapter._map_csv_chunks_to_code(csv_chunks, config, model, "True", column_mappings)
+        self.assertIn("spark.createDataFrame(csv)", code[0])
+
+
+class TestCsvMappingStrategy:
+    @pytest.mark.parametrize(
+        "agate_type,specified_type,expected_schema_type,expected_cast_type",
+        [
+            (agate_helper.ISODateTime, None, "string", "timestamp"),
+            (agate_helper.Number, None, "double", None),
+            (agate_helper.Integer, None, "int", None),
+            (agate.data_types.Boolean, None, "boolean", None),
+            (agate.data_types.Date, None, "string", "date"),
+            (agate.data_types.DateTime, None, "string", "timestamp"),
+            (agate.data_types.Text, None, "string", None),
+            (agate.data_types.Text, "double", "string", "double"),
+        ],
+        ids=[
+            "test isodatetime cast",
+            "test number cast",
+            "test integer cast",
+            "test boolean cast",
+            "test date cast",
+            "test datetime cast",
+            "test text cast",
+            "test specified cast",
+        ],
+    )
+    def test_mapping_strategy_provides_proper_mappings(
+        self, agate_type, specified_type, expected_schema_type, expected_cast_type
+    ):
+        column_mapping = ColumnCsvMappingStrategy("test_column", agate_type, specified_type)
+        assert column_mapping.as_schema_value() == expected_schema_type
+        assert column_mapping.as_cast_value() == expected_cast_type
+
+    def test_from_model_builds_column_mappings(self):
+        expected_column_names = ["col_int", "col_str", "col_date", "col_specific"]
+        expected_agate_types = [
+            agate_helper.Integer,
+            agate.data_types.Text,
+            agate.data_types.Date,
+            agate.data_types.Text,
+        ]
+        expected_specified_types = [None, None, None, "double"]
+        agate_table = agate.Table(
+            [(111, "str_val", "2024-01-01", "1.234")],
+            column_names=expected_column_names,
+            column_types=[data_type() for data_type in expected_agate_types],
+        )
+        model = {"name": "mock_model", "config": {"column_types": {"col_specific": "double"}}}
+        mappings = ColumnCsvMappingStrategy.from_model(model, agate_table)
+        assert expected_column_names == [mapping.column_name for mapping in mappings]
+        assert expected_agate_types == [mapping.agate_type for mapping in mappings]
+        assert expected_specified_types == [mapping.specified_type for mapping in mappings]
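For orientation (illustrative, not lifted from the patch): with the flag enabled, the statement assembled by _map_csv_chunks_to_code boils down to something like the snippet below. It runs inside the Glue interactive session, where `spark` and `csv` are already provided by the surrounding wrapper code; the column names match the sketch above, not any real seed:

# Rough shape of the generated Spark code when enable_spark_seed_casting is true.
# `spark` is the session's SparkSession and `csv` the list of row dicts.
df = spark.createDataFrame(csv, "created_at: string, amount: string")
df = df.withColumn("created_at", df.created_at.cast("date")).withColumn("amount", df.amount.cast("double"))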