From 17494cbd5c67605cd3feff594f5f98cfb38c553f Mon Sep 17 00:00:00 2001
From: Jaume Sanjuan
Date: Fri, 4 Oct 2024 20:04:59 +0200
Subject: [PATCH 1/3] allow cast on seeds

---
 dbt/adapters/glue/credentials.py |  5 ++-
 dbt/adapters/glue/impl.py        | 77 ++++++++++++++++++++++++++++++--
 tests/unit/test_adapter.py       | 71 ++++++++++++++++++++++++++++-
 3 files changed, 147 insertions(+), 6 deletions(-)

diff --git a/dbt/adapters/glue/credentials.py b/dbt/adapters/glue/credentials.py
index d94929b..a272e7c 100644
--- a/dbt/adapters/glue/credentials.py
+++ b/dbt/adapters/glue/credentials.py
@@ -35,7 +35,7 @@ class GlueCredentials(Credentials):
     datalake_formats: Optional[str] = None
     enable_session_per_model: Optional[bool] = False
     use_arrow: Optional[bool] = False
-
+    enable_spark_seed_casting: Optional[bool] = False
 
     @property
     def type(self):
@@ -93,5 +93,6 @@ def _connection_keys(self):
         'glue_session_reuse',
         'datalake_formats',
         'enable_session_per_model',
-        'use_arrow'
+        'use_arrow',
+        'enable_spark_seed_casting',
     ]
diff --git a/dbt/adapters/glue/impl.py b/dbt/adapters/glue/impl.py
index 2ba68cf..d5f382f 100644
--- a/dbt/adapters/glue/impl.py
+++ b/dbt/adapters/glue/impl.py
@@ -27,11 +27,52 @@
 from dbt_common.exceptions import DbtDatabaseError, CompilationError
 from dbt.adapters.base.impl import catch_as_completed
 from dbt_common.utils import executor
+from dbt_common.clients import agate_helper
 from dbt.adapters.events.logging import AdapterLogger
 
 logger = AdapterLogger("Glue")
 
 
+class ColumnCsvMappingStrategy:
+    _schema_mappings = {
+        agate_helper.ISODateTime: 'string',
+        agate_helper.Number: 'double',
+        agate_helper.Integer: 'int',
+        agate.data_types.Boolean: 'boolean',
+        agate.data_types.Date: 'string',
+        agate.data_types.DateTime: 'string',
+        agate.data_types.Text: 'string',
+    }
+
+    _cast_mappings = {
+        agate_helper.ISODateTime: 'timestamp',
+        agate.data_types.Date: 'date',
+        agate.data_types.DateTime: 'timestamp',
+    }
+
+    def __init__(self, column_name, agate_type, specified_type):
+        self.column_name = column_name
+        self.agate_type = agate_type
+        self.specified_type = specified_type
+
+    def as_schema_value(self):
+        return ColumnCsvMappingStrategy._schema_mappings.get(self.agate_type)
+
+    def as_cast_value(self):
+        return (
+            self.specified_type if self.specified_type else ColumnCsvMappingStrategy._cast_mappings.get(self.agate_type)
+        )
+
+    @classmethod
+    def from_model(cls, model, agate_table):
+        return [
+            ColumnCsvMappingStrategy(
+                column.name, type(column.data_type), model.get("config", {}).get("column_types", {}).get(column.name)
+            )
+            for column in agate_table.columns
+        ]
+
+
 class GlueAdapter(SQLAdapter):
     ConnectionManager = GlueConnectionManager
     Relation = SparkRelation
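The strategy above separates two concerns: as_schema_value() returns the permissive type used to load the raw CSV strings into Spark, and as_cast_value() returns the type the column is converted to afterwards, with an explicit column_types entry from the model taking precedence over the agate-derived mapping. A minimal illustrative sketch (hypothetical column names, not part of the patch):

    import agate.data_types
    from dbt.adapters.glue.impl import ColumnCsvMappingStrategy

    # Agate sniffed the seed column as a DateTime; no column_types override given.
    mapping = ColumnCsvMappingStrategy("created_at", agate.data_types.DateTime, None)
    mapping.as_schema_value()  # 'string'    -> loaded as text first
    mapping.as_cast_value()    # 'timestamp' -> then cast via withColumn(...).cast(...)

    # An explicit model column_types entry overrides the agate-derived cast.
    mapping = ColumnCsvMappingStrategy("price", agate.data_types.Text, "double")
    mapping.as_cast_value()    # 'double'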
@@ -535,7 +576,7 @@ def create_csv_table(self, model, agate_table):
             mode = "False"
 
         csv_chunks = self._split_csv_records_into_chunks(json.loads(f.getvalue()))
-        statements = self._map_csv_chunks_to_code(csv_chunks, session, model, mode)
+        statements = self._map_csv_chunks_to_code(csv_chunks, session, model, mode, ColumnCsvMappingStrategy.from_model(model, agate_table))
         try:
             cursor = session.cursor()
             for statement in statements:
@@ -545,7 +586,14 @@ def create_csv_table(self, model, agate_table):
         except Exception as e:
             logger.error(e)
 
-    def _map_csv_chunks_to_code(self, csv_chunks: List[List[dict]], session: GlueConnection, model, mode):
+    def _map_csv_chunks_to_code(
+        self,
+        csv_chunks: List[List[dict]],
+        session: GlueConnection,
+        model,
+        mode,
+        column_mappings: List[ColumnCsvMappingStrategy],
+    ):
         statements = []
         for i, csv_chunk in enumerate(csv_chunks):
             is_first = i == 0
@@ -564,8 +612,31 @@ def _map_csv_chunks_to_code(self, csv_chunks: List[List[dict]], session: GlueConnection,
 SqlWrapper2.execute("""select 1""")
 '''
             else:
-                code += f'''
+                if session.credentials.enable_spark_seed_casting:
+                    csv_schema = ", ".join(
+                        [f"{mapping.column_name}: {mapping.as_schema_value()}" for mapping in column_mappings]
+                    )
+
+                    cast_code = ".".join(
+                        [
+                            "df",
+                            *[
+                                f'withColumn("{mapping.column_name}", df.{mapping.column_name}.cast("{cast_value}"))'
+                                for mapping in column_mappings
+                                if (cast_value := mapping.as_cast_value())
+                            ],
+                        ]
+                    )
+
+                    code += f"""
+df = spark.createDataFrame(csv, "{csv_schema}")
+df = {cast_code}
+"""
+                else:
+                    code += """
 df = spark.createDataFrame(csv)
+"""
+                code += f'''
 table_name = '{model["schema"]}.{model["name"]}'
 if (spark.sql("show tables in {model["schema"]}").where("tableName == lower('{model["name"]}')").count() > 0):
     df.write\
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index cd229b6..174e15a 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -2,7 +2,9 @@
 import unittest
 from unittest import mock
 from unittest.mock import Mock
+import pytest
 from multiprocessing import get_context
+import agate.data_types
 from botocore.client import BaseClient
 from moto import mock_aws
@@ -13,6 +15,8 @@
 from dbt.adapters.glue import GlueAdapter
 from dbt.adapters.glue.gluedbapi import GlueConnection
 from dbt.adapters.glue.relation import SparkRelation
+from dbt.adapters.glue.impl import ColumnCsvMappingStrategy
+from dbt_common.clients import agate_helper
 from tests.util import config_from_parts_or_dicts
 from .util import MockAWSService
@@ -99,4 +103,69 @@ def test_create_csv_table_slices_big_datasets(self):
         adapter.create_csv_table(model, test_table)
 
         # test table is between 120000 and 180000 characters so it should be split three times (max chunk is 60000)
-        self.assertEqual(session_mock.cursor().execute.call_count, 3)
\ No newline at end of file
+        self.assertEqual(session_mock.cursor().execute.call_count, 3)
+
+    def test_create_csv_table_provides_schema_and_casts_when_spark_seed_cast_is_enabled(self):
+        config = self._get_config()
+        config.credentials.enable_spark_seed_casting = True
+        adapter = GlueAdapter(config, get_context("spawn"))
+        csv_chunks = [{'test_column': '1.2345'}]
+        model = {"name": "mock_model", "schema": "mock_schema", "config": {"column_types": {"test_column": "double"}}}
+        column_mappings = [ColumnCsvMappingStrategy('test_column', agate.data_types.Text, 'double')]
+        code = adapter._map_csv_chunks_to_code(csv_chunks, config, model, 'True', column_mappings)
+        self.assertIn('spark.createDataFrame(csv, "test_column: string")', code[0])
+        self.assertIn('df = df.withColumn("test_column", df.test_column.cast("double"))', code[0])
+
+    def test_create_csv_table_doesnt_provide_schema_when_spark_seed_cast_is_disabled(self):
+        config = self._get_config()
+        config.credentials.enable_spark_seed_casting = False
+        adapter = GlueAdapter(config, get_context("spawn"))
+        csv_chunks = [{'test_column': '1.2345'}]
+        model = {"name": "mock_model", "schema": "mock_schema"}
+        column_mappings = [ColumnCsvMappingStrategy('test_column', agate.data_types.Text, 'double')]
+        code = adapter._map_csv_chunks_to_code(csv_chunks, config, model, 'True', column_mappings)
+        self.assertIn('spark.createDataFrame(csv)', code[0])
+
+class TestCsvMappingStrategy:
+    @pytest.mark.parametrize(
+        'agate_type,specified_type,expected_schema_type,expected_cast_type',
+        [
+            (agate_helper.ISODateTime, None, 'string', 'timestamp'),
+            (agate_helper.Number, None, 'double', None),
+            (agate_helper.Integer, None, 'int', None),
+            (agate.data_types.Boolean, None, 'boolean', None),
+            (agate.data_types.Date, None, 'string', 'date'),
+            (agate.data_types.DateTime, None, 'string', 'timestamp'),
+            (agate.data_types.Text, None, 'string', None),
+            (agate.data_types.Text, 'double', 'string', 'double'),
+        ],
+        ids=[
+            'test isodatetime cast',
+            'test number cast',
+            'test integer cast',
+            'test boolean cast',
+            'test date cast',
+            'test datetime cast',
+            'test text cast',
+            'test specified cast',
+        ]
+    )
+    def test_mapping_strategy_provides_proper_mappings(self, agate_type, specified_type, expected_schema_type, expected_cast_type):
+        column_mapping = ColumnCsvMappingStrategy('test_column', agate_type, specified_type)
+        assert column_mapping.as_schema_value() == expected_schema_type
+        assert column_mapping.as_cast_value() == expected_cast_type
+
+    def test_from_model_builds_column_mappings(self):
+        expected_column_names = ['col_int', 'col_str', 'col_date', 'col_specific']
+        expected_agate_types = [agate_helper.Integer, agate.data_types.Text, agate.data_types.Date, agate.data_types.Text]
+        expected_specified_types = [None, None, None, 'double']
+        agate_table = agate.Table(
+            [(111, 'str_val', '2024-01-01', '1.234')],
+            column_names=expected_column_names,
+            column_types=[data_type() for data_type in expected_agate_types]
+        )
+        model = {"name": "mock_model", "config": {"column_types": {"col_specific": "double"}}}
+        mappings = ColumnCsvMappingStrategy.from_model(model, agate_table)
+        assert expected_column_names == [mapping.column_name for mapping in mappings]
+        assert expected_agate_types == [mapping.agate_type for mapping in mappings]
+        assert expected_specified_types == [mapping.specified_type for mapping in mappings]
\ No newline at end of file
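Taken together, the impl.py changes above determine the statement each seed chunk generates inside the Glue session. A sketch of the emitted code for a hypothetical seed where agate reads price as text (overridden to double via the model's column_types) and created_at as a DateTime; the column names are illustrative, not from the patch:

    # One chunk of the seed, as JSON-decoded records; every value is still a string.
    csv = [{"price": "1.2345", "created_at": "2024-01-01 09:30:00"}]
    # Load phase: each column gets its permissive schema type, so nothing fails to parse.
    df = spark.createDataFrame(csv, "price: string, created_at: string")
    # Cast phase: 'double' comes from column_types, 'timestamp' from the DateTime mapping.
    df = df.withColumn("price", df.price.cast("double")).withColumn("created_at", df.created_at.cast("timestamp"))
    # With enable_spark_seed_casting left at its default (False), the adapter still
    # emits plain df = spark.createDataFrame(csv) and lets Spark infer the types.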
From 3fc70eed43f7d077e561ee235e792ff4e1372bfe Mon Sep 17 00:00:00 2001
From: Jaume Sanjuan
Date: Mon, 7 Oct 2024 10:44:44 +0200
Subject: [PATCH 2/3] Update changelog and readme

---
 CHANGELOG.md | 1 +
 README.md    | 1 +
 2 files changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f797ee2..6f56ffb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,6 @@
 ## New version
 - Allow to load big seed files
+- Added a configuration property to allow Spark casting of seed column types
 
 ## v1.8.6
 - Fix session provisioning timeout and delay handling
diff --git a/README.md b/README.md
index 9ed7798..b16d6c3 100644
--- a/README.md
+++ b/README.md
@@ -244,6 +244,7 @@ The table below describes all the options.
 | glue_session_reuse | re-use the glue-session to run multiple dbt run commands: If set to true, the glue session will not be closed for re-use. If set to false, the session will be closed. The glue session will close after the idle_timeout time has expired | no |
 | datalake_formats | The ACID datalake format that you want to use if you are doing merge, can be `hudi`, `iceberg` or `delta` |no|
 | use_arrow | (experimental) use an arrow file instead of stdout to have better scalability. |no|
+| enable_spark_seed_casting | Allows Spark to cast seed columns to the types specified in the model's `column_types` config. Default `False`. |no|
 
 ## Configs
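The README row above documents the switch; it is set per output in profiles.yml. A minimal sketch of a Glue target with the flag enabled (the profile name, role ARN, region, and schema are placeholders, and other required connection options are omitted for brevity):

    my_glue_project:
      target: dev
      outputs:
        dev:
          type: glue
          role_arn: arn:aws:iam::111122223333:role/GlueInteractiveSessionRole
          region: eu-west-1
          workers: 2
          worker_type: G.1X
          schema: analytics_dev
          enable_spark_seed_casting: true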
From 1171ef656bfa6c644539fff0a78ec2b66c56b597 Mon Sep 17 00:00:00 2001
From: Jaume Sanjuan
Date: Thu, 31 Oct 2024 15:46:24 +0100
Subject: [PATCH 3/3] extend seed casting unit test

---
 tests/unit/test_adapter.py | 21 ++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index ba4b661..0ad070f 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -133,12 +133,23 @@ def test_create_csv_table_provides_schema_and_casts_when_spark_seed_cast_is_enab
         config = self._get_config()
         config.credentials.enable_spark_seed_casting = True
         adapter = GlueAdapter(config, get_context("spawn"))
-        csv_chunks = [{"test_column": "1.2345"}]
-        model = {"name": "mock_model", "schema": "mock_schema", "config": {"column_types": {"test_column": "double"}}}
-        column_mappings = [ColumnCsvMappingStrategy("test_column", agate.data_types.Text, "double")]
+        csv_chunks = [{"test_column_double": "1.2345", "test_column_str": "test"}]
+        model = {
+            "name": "mock_model",
+            "schema": "mock_schema",
+            "config": {"column_types": {"test_column_double": "double", "test_column_str": "string"}},
+        }
+        column_mappings = [
+            ColumnCsvMappingStrategy("test_column_double", agate.data_types.Text, "double"),
+            ColumnCsvMappingStrategy("test_column_str", agate.data_types.Text, "string"),
+        ]
         code = adapter._map_csv_chunks_to_code(csv_chunks, config, model, "True", column_mappings)
-        self.assertIn('spark.createDataFrame(csv, "test_column: string")', code[0])
-        self.assertIn('df = df.withColumn("test_column", df.test_column.cast("double"))', code[0])
+        self.assertIn('spark.createDataFrame(csv, "test_column_double: string, test_column_str: string")', code[0])
+        self.assertIn(
+            'df = df.withColumn("test_column_double", df.test_column_double.cast("double"))'
+            + '.withColumn("test_column_str", df.test_column_str.cast("string"))',
+            code[0],
+        )
 
     def test_create_csv_table_doesnt_provide_schema_when_spark_seed_cast_is_disabled(self):
         config = self._get_config()
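Reassembling the strings asserted above, the code generated for the extended two-column seed chunk reads:

    df = spark.createDataFrame(csv, "test_column_double: string, test_column_str: string")
    df = df.withColumn("test_column_double", df.test_column_double.cast("double")).withColumn("test_column_str", df.test_column_str.cast("string"))

Both columns are loaded as strings and cast afterwards; for test_column_str the cast to "string" is a no-op, emitted because an explicit column_types entry always produces a cast.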