Allow seed type casting #459

Open · wants to merge 5 commits into main
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@
- Adds a helper function to retrieve the Iceberg catalog namespace from the profile.yaml file.
- Adds merge_exclude_columns and incremental_predicates features.
- Drop Python 3.8 support
- Added configuration property to allow Spark casting of seed column types

## v1.8.6
- Fix session provisioning timeout and delay handling
1 change: 1 addition & 0 deletions README.md
@@ -244,6 +244,7 @@ The table below describes all the options.
| glue_session_reuse | Re-use the Glue session across multiple dbt run commands: if set to true, the session is left open for re-use (it still closes once idle_timeout expires); if set to false, the session is closed after the run. | no |
| datalake_formats | The ACID datalake format that you want to use if you are doing merge, can be `hudi`, `iceberg` or `delta` |no|
| use_arrow | (experimental) use an arrow file instead of stdout to have better scalability. |no|
| enable_spark_seed_casting | Allows Spark to cast seed columns to the column types specified for the model. Default `False`. |no|

## Configs

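For illustration, the new option sits alongside the existing profile settings documented in the table above. A minimal target sketch follows; every value (account ID, bucket, schema name, and so on) is illustrative rather than taken from this PR:

```yaml
my_dbt_project:
  target: dev
  outputs:
    dev:
      type: glue
      role_arn: arn:aws:iam::123456789012:role/dbt-glue-role   # illustrative
      region: us-east-1
      workers: 2
      worker_type: G.1X
      schema: dbt_seed_demo
      session_provisioning_timeout_in_seconds: 120
      location: s3://my-bucket/dbt/                             # illustrative
      enable_spark_seed_casting: true   # new option from this PR; defaults to false
```

With the flag enabled, the per-column target types come from the seed's column_types config, which ColumnCsvMappingStrategy.from_model reads in the impl.py changes below.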
64 changes: 34 additions & 30 deletions dbt/adapters/glue/credentials.py
@@ -3,9 +3,11 @@
from dbt.adapters.contracts.connection import Credentials
from dbt_common.exceptions import DbtRuntimeError


@dataclass
class GlueCredentials(Credentials):
""" Required connections for a Glue connection"""
"""Required connections for a Glue connection"""

role_arn: Optional[str] = None # type: ignore
region: Optional[str] = None # type: ignore
workers: Optional[int] = None # type: ignore
@@ -36,6 +38,7 @@ class GlueCredentials(Credentials):
enable_session_per_model: Optional[bool] = False
use_arrow: Optional[bool] = False
custom_iceberg_catalog_namespace: Optional[str] = "glue_catalog"
enable_spark_seed_casting: Optional[bool] = False

@property
def type(self):
@@ -64,34 +67,35 @@ def __post_init__(self):
self.database = None

def _connection_keys(self):
""" Keys to show when debugging """
"""Keys to show when debugging"""
return [
'role_arn',
'region',
'workers',
'worker_type',
'session_provisioning_timeout_in_seconds',
'schema',
'location',
'extra_jars',
'idle_timeout',
'query_timeout_in_minutes',
'glue_version',
'security_configuration',
'connections',
'conf',
'extra_py_files',
'delta_athena_prefix',
'tags',
'seed_format',
'seed_mode',
'default_arguments',
'iceberg_glue_commit_lock_table',
'use_interactive_session_role_for_api_calls',
'lf_tags',
'glue_session_id',
'glue_session_reuse',
'datalake_formats',
'enable_session_per_model',
'use_arrow'
"role_arn",
"region",
"workers",
"worker_type",
"session_provisioning_timeout_in_seconds",
"schema",
"location",
"extra_jars",
"idle_timeout",
"query_timeout_in_minutes",
"glue_version",
"security_configuration",
"connections",
"conf",
"extra_py_files",
"delta_athena_prefix",
"tags",
"seed_format",
"seed_mode",
"default_arguments",
"iceberg_glue_commit_lock_table",
"use_interactive_session_role_for_api_calls",
"lf_tags",
"glue_session_id",
"glue_session_reuse",
"datalake_formats",
"enable_session_per_model",
"use_arrow",
"enable_spark_seed_casting",
]
77 changes: 74 additions & 3 deletions dbt/adapters/glue/impl.py
@@ -27,11 +27,52 @@
from dbt_common.exceptions import DbtDatabaseError, CompilationError
from dbt.adapters.base.impl import catch_as_completed
from dbt_common.utils import executor
from dbt_common.clients import agate_helper
from dbt.adapters.events.logging import AdapterLogger

logger = AdapterLogger("Glue")


class ColumnCsvMappingStrategy:
_schema_mappings = {
agate_helper.ISODateTime: 'string',
agate_helper.Number: 'double',
agate_helper.Integer: 'int',
agate.data_types.Boolean: 'boolean',
agate.data_types.Date: 'string',
agate.data_types.DateTime: 'string',
agate.data_types.Text: 'string',
}

_cast_mappings = {
agate_helper.ISODateTime: 'timestamp',
agate.data_types.Date: 'date',
agate.data_types.DateTime: 'timestamp',
}

def __init__(self, column_name, agate_type, specified_type):
self.column_name = column_name
self.agate_type = agate_type
self.specified_type = specified_type

def as_schema_value(self):
return ColumnCsvMappingStrategy._schema_mappings.get(self.agate_type)

def as_cast_value(self):
return (
self.specified_type if self.specified_type else ColumnCsvMappingStrategy._cast_mappings.get(self.agate_type)
)

@classmethod
def from_model(cls, model, agate_table):
return [
ColumnCsvMappingStrategy(
column.name, type(column.data_type), model.get("config", {}).get("column_types", {}).get(column.name)
)
for column in agate_table.columns
]


class GlueAdapter(SQLAdapter):
ConnectionManager = GlueConnectionManager
Relation = SparkRelation
@@ -519,7 +560,7 @@ def create_csv_table(self, model, agate_table):
mode = "False"

csv_chunks = self._split_csv_records_into_chunks(json.loads(f.getvalue()))
statements = self._map_csv_chunks_to_code(csv_chunks, session, model, mode)
statements = self._map_csv_chunks_to_code(csv_chunks, session, model, mode, ColumnCsvMappingStrategy.from_model(model, agate_table))
try:
cursor = session.cursor()
for statement in statements:
@@ -529,7 +570,14 @@ def create_csv_table(self, model, agate_table):
except Exception as e:
logger.error(e)

def _map_csv_chunks_to_code(self, csv_chunks: List[List[dict]], session: GlueConnection, model, mode):
def _map_csv_chunks_to_code(
self,
csv_chunks: List[List[dict]],
session: GlueConnection,
model,
mode,
column_mappings: List[ColumnCsvMappingStrategy],
):
statements = []
for i, csv_chunk in enumerate(csv_chunks):
is_first = i == 0
@@ -548,8 +596,31 @@ def _map_csv_chunks_to_code(self, csv_chunks: List[List[dict]], session: GlueConnection, model, mode):
SqlWrapper2.execute("""select 1""")
'''
else:
code += f'''
if session.credentials.enable_spark_seed_casting:
csv_schema = ", ".join(
[f"{mapping.column_name}: {mapping.as_schema_value()}" for mapping in column_mappings]
)

cast_code = ".".join(
[
"df",
*[
f'withColumn("{mapping.column_name}", df.{mapping.column_name}.cast("{cast_value}"))'
[Review thread on the withColumn line above]

Collaborator: withColumn works well for a simple pattern, but if we have many withColumn calls, it easily causes a StackOverflowException. Do you think this happens in real-world use cases? If yes, it will be safer to replace withColumn with select..as.

Contributor (author): I'd say it depends: how many withColumn calls would cause a stack overflow? If it's over 100, then I'd say it's not realistic for any of the use cases we're handling. However, I can look into implementing it with "select..as", but I'm unfamiliar with that syntax. Do you mean something like this?

df = df.select(col("foo").cast("double"), col("bar").cast("string"))

moomindani (Collaborator), Oct 31, 2024: You are right. df.select() or df.selectExpr() work here. The condition depends on multiple factors, but sometimes it can occur with fewer than 100.

[End of review thread; the diff continues below. A select()-based sketch follows the impl.py diff.]

for mapping in column_mappings
if (cast_value := mapping.as_cast_value())
],
]
)

code += f"""
df = spark.createDataFrame(csv, "{csv_schema}")
df = {cast_code}
"""
else:
code += """
df = spark.createDataFrame(csv)
"""
code += f'''
table_name = '{model["schema"]}.{model["name"]}'
if (spark.sql("show tables in {model["schema"]}").where("tableName == lower('{model["name"]}')").count() > 0):
df.write\
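As discussed in the review thread above, chaining many withColumn calls can trigger a StackOverflowException, and a single select() projection was suggested as the safer alternative. The sketch below shows that alternative in standalone PySpark; it is not part of this PR, and the column names and the cast_targets mapping are illustrative stand-ins for what ColumnCsvMappingStrategy would provide:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

# Illustrative seed chunk loaded with an all-string schema, mirroring the generated code.
df = spark.createDataFrame(
    [{"test_column_double": "1.2345", "test_column_date": "2024-01-01"}],
    "test_column_double: string, test_column_date: string",
)

# Target types that would come from ColumnCsvMappingStrategy.as_cast_value() (illustrative).
cast_targets = {"test_column_double": "double", "test_column_date": "date"}

# One select() instead of chained withColumn calls: every column is projected exactly once,
# and a cast is applied only where a target type is known.
df = df.select(
    *[
        col(name).cast(cast_targets[name]).alias(name) if name in cast_targets else col(name)
        for name in df.columns
    ]
)
df.printSchema()
```

Either form is driven by the same mapping information; the thread leaves the final choice open.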
101 changes: 95 additions & 6 deletions tests/unit/test_adapter.py
@@ -2,7 +2,9 @@
import unittest
from unittest import mock
from unittest.mock import Mock
import pytest
from multiprocessing import get_context
import agate.data_types
from botocore.client import BaseClient
from moto import mock_aws

@@ -13,6 +15,8 @@
from dbt.adapters.glue import GlueAdapter
from dbt.adapters.glue.gluedbapi import GlueConnection
from dbt.adapters.glue.relation import SparkRelation
from dbt.adapters.glue.impl import ColumnCsvMappingStrategy
from dbt_common.clients import agate_helper
from tests.util import config_from_parts_or_dicts
from .util import MockAWSService

@@ -41,7 +45,7 @@ def setUp(self):
"region": "us-east-1",
"workers": 2,
"worker_type": "G.1X",
"location" : "path_to_location/",
"location": "path_to_location/",
"schema": "dbt_unit_test_01",
"database": "dbt_unit_test_01",
"use_interactive_session_role_for_api_calls": False,
@@ -71,7 +75,6 @@ def test_glue_connection(self):
self.assertIsNotNone(connection.handle)
self.assertIsInstance(glueSession.client, BaseClient)


@mock_aws
def test_get_table_type(self):
config = self._get_config()
@@ -96,8 +99,10 @@ def test_create_csv_table_slices_big_datasets(self):
adapter = GlueAdapter(config, get_context("spawn"))
model = {"name": "mock_model", "schema": "mock_schema"}
session_mock = Mock()
adapter.get_connection = lambda: (session_mock, 'mock_client')
test_table = agate.Table([(f'mock_value_{i}',f'other_mock_value_{i}') for i in range(2000)], column_names=['value', 'other_value'])
adapter.get_connection = lambda: (session_mock, "mock_client")
test_table = agate.Table(
[(f"mock_value_{i}", f"other_mock_value_{i}") for i in range(2000)], column_names=["value", "other_value"]
)
adapter.create_csv_table(model, test_table)

# test table is between 120000 and 180000 characters so it should be split three times (max chunk is 60000)
@@ -115,11 +120,95 @@ def test_get_location(self):
connection.handle # trigger lazy-load
print(adapter.get_location(relation))
self.assertEqual(adapter.get_location(relation), "LOCATION 'path_to_location/some_database/some_table'")

def test_get_custom_iceberg_catalog_namespace(self):
config = self._get_config()
adapter = GlueAdapter(config, get_context("spawn"))
with mock.patch("dbt.adapters.glue.connections.open"):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load
self.assertEqual(adapter.get_custom_iceberg_catalog_namespace(), "custom_iceberg_catalog")
self.assertEqual(adapter.get_custom_iceberg_catalog_namespace(), "custom_iceberg_catalog")

def test_create_csv_table_provides_schema_and_casts_when_spark_seed_cast_is_enabled(self):
config = self._get_config()
config.credentials.enable_spark_seed_casting = True
adapter = GlueAdapter(config, get_context("spawn"))
csv_chunks = [{"test_column_double": "1.2345", "test_column_str": "test"}]
model = {
"name": "mock_model",
"schema": "mock_schema",
"config": {"column_types": {"test_column_double": "double", "test_column_str": "string"}},
}
column_mappings = [
ColumnCsvMappingStrategy("test_column_double", agate.data_types.Text, "double"),
ColumnCsvMappingStrategy("test_column_str", agate.data_types.Text, "string"),
]
code = adapter._map_csv_chunks_to_code(csv_chunks, config, model, "True", column_mappings)
self.assertIn('spark.createDataFrame(csv, "test_column_double: string, test_column_str: string")', code[0])
self.assertIn(
'df = df.withColumn("test_column_double", df.test_column_double.cast("double"))'
+ '.withColumn("test_column_str", df.test_column_str.cast("string"))',
code[0],
)

def test_create_csv_table_doesnt_provide_schema_when_spark_seed_cast_is_disabled(self):
config = self._get_config()
config.credentials.enable_spark_seed_casting = False
adapter = GlueAdapter(config, get_context("spawn"))
csv_chunks = [{"test_column": "1.2345"}]
model = {"name": "mock_model", "schema": "mock_schema"}
column_mappings = [ColumnCsvMappingStrategy("test_column", agate.data_types.Text, "double")]
[Review thread on the column_mappings line above]

Collaborator: Shall we add another case to test multiple mapping strategies?

Contributor (author): I extended the test case to include different strategies.

Collaborator: Thank you.

[End of review thread; the diff continues below.]

code = adapter._map_csv_chunks_to_code(csv_chunks, config, model, "True", column_mappings)
self.assertIn("spark.createDataFrame(csv)", code[0])


class TestCsvMappingStrategy:
@pytest.mark.parametrize(
"agate_type,specified_type,expected_schema_type,expected_cast_type",
[
(agate_helper.ISODateTime, None, "string", "timestamp"),
(agate_helper.Number, None, "double", None),
(agate_helper.Integer, None, "int", None),
(agate.data_types.Boolean, None, "boolean", None),
(agate.data_types.Date, None, "string", "date"),
(agate.data_types.DateTime, None, "string", "timestamp"),
(agate.data_types.Text, None, "string", None),
(agate.data_types.Text, "double", "string", "double"),
],
ids=[
"test isodatetime cast",
"test number cast",
"test integer cast",
"test boolean cast",
"test date cast",
"test datetime cast",
"test text cast",
"test specified cast",
],
)
def test_mapping_strategy_provides_proper_mappings(
self, agate_type, specified_type, expected_schema_type, expected_cast_type
):
column_mapping = ColumnCsvMappingStrategy("test_column", agate_type, specified_type)
assert column_mapping.as_schema_value() == expected_schema_type
assert column_mapping.as_cast_value() == expected_cast_type

def test_from_model_builds_column_mappings(self):
expected_column_names = ["col_int", "col_str", "col_date", "col_specific"]
expected_agate_types = [
agate_helper.Integer,
agate.data_types.Text,
agate.data_types.Date,
agate.data_types.Text,
]
expected_specified_types = [None, None, None, "double"]
agate_table = agate.Table(
[(111, "str_val", "2024-01-01", "1.234")],
column_names=expected_column_names,
column_types=[data_type() for data_type in expected_agate_types],
)
model = {"name": "mock_model", "config": {"column_types": {"col_specific": "double"}}}
mappings = ColumnCsvMappingStrategy.from_model(model, agate_table)
assert expected_column_names == [mapping.column_name for mapping in mappings]
assert expected_agate_types == [mapping.agate_type for mapping in mappings]
assert expected_specified_types == [mapping.specified_type for mapping in mappings]
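For reference, here is roughly what the generated Spark snippet asserted in test_create_csv_table_provides_schema_and_casts_when_spark_seed_cast_is_enabled expands to for that two-column seed. It is a standalone sketch assembled from the test expectations; the surrounding session plumbing, SqlWrapper2 calls, and table write in the real generated code are omitted:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# One chunk of seed rows, as parsed from JSON by _map_csv_chunks_to_code.
csv = [{"test_column_double": "1.2345", "test_column_str": "test"}]

# Schema string built from ColumnCsvMappingStrategy.as_schema_value(): both columns are
# agate Text, so they load as strings.
df = spark.createDataFrame(csv, "test_column_double: string, test_column_str: string")

# Casts built from ColumnCsvMappingStrategy.as_cast_value(); the model's column_types
# config ("double" for test_column_double) takes precedence over the agate-derived cast.
df = df.withColumn("test_column_double", df.test_column_double.cast("double")) \
       .withColumn("test_column_str", df.test_column_str.cast("string"))

df.printSchema()  # test_column_double is now double; test_column_str remains string
```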