From ccdc630a285500302dfff48f8ec8f1fdd44c1b20 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Tue, 25 Jul 2023 13:29:18 -0700 Subject: [PATCH 1/4] override get_columns_in_relation with driver api call --- dbt/adapters/redshift/impl.py | 29 +++++++++++++++++++++++++++++ tests/unit/test_redshift_adapter.py | 26 ++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index ae9f18392..427eb48cd 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -4,6 +4,7 @@ from dbt.adapters.base import PythonJobHelper from dbt.adapters.base.impl import AdapterConfig, ConstraintSupport from dbt.adapters.base.meta import available +from dbt.adapters.base.column import Column from dbt.adapters.sql import SQLAdapter from dbt.contracts.connection import AdapterResponse from dbt.contracts.graph.nodes import ConstraintType @@ -115,6 +116,34 @@ def valid_incremental_strategies(self): def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str: return f"{add_to} + interval '{number} {interval}'" + def _get_cursor(self): + return self.connections.get_thread_connection().handle.cursor() + + def get_columns_in_relation(self, relation): + cursor = self._get_cursor() + results = [] + if relation.identifier: + columns = cursor.get_columns( + catalog=relation.database, + schema_pattern=relation.schema, + tablename_pattern=relation.identifier, + ) + else: + columns = cursor.get_columns(catalog=relation.database, schema_pattern=relation.schema) + if columns is not None and len(columns) > 0: + for column in columns: + if column[4] == 1 or column[4] == 12: # if column type is character + results.append(Column(column[3], column[5], column[6], None, None)) + # elif column[4] == 5 or column[4] == 4 or column[4] == -5 or column[4] == 3 or column[4] == 7\ + # or column[4] == 8 or column[4] == 6 or column[4] == 2 or column[4] == 2003:#if column type is numeric + elif any( 
+ column[4] == type_int for type_int in [5, 4, -5, 3, 7, 8, 6, 2, 2003] + ): # if column type is numeric + results.append(Column(column[3], column[5], None, column[6], column[8])) + else: + results.append(Column(column[3], column[5], column[6], None, None)) + return results + def _link_cached_database_relations(self, schemas: Set[str]): """ :param schemas: The set of schemas that should have links added. diff --git a/tests/unit/test_redshift_adapter.py b/tests/unit/test_redshift_adapter.py index c31366a1e..354356252 100644 --- a/tests/unit/test_redshift_adapter.py +++ b/tests/unit/test_redshift_adapter.py @@ -587,6 +587,32 @@ def mock_cursor(self, mock_get_thread_conn): mock_handle.return_value = mock_cursor return mock_cursor + @mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter._get_cursor") + def test_get_columns_in_relation_char(self, mock_cursor): + mock_relation = mock.MagicMock(database="somedb", schema="someschema", identifier="iden") + mock_cursor.return_value.get_columns.return_value = [ + ("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "char_column", 12, "char", 10, 10, None) + ] + result = self.adapter.get_columns_in_relation(mock_relation) + self.assertTrue(result[0].column == "char_column") + self.assertTrue(result[0].dtype == "char") + self.assertTrue(result[0].char_size == 10) + self.assertTrue(result[0].numeric_scale is None) + self.assertTrue(result[0].numeric_precision is None) + + @mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter._get_cursor") + def test_get_columns_in_relation_int(self, mock_cursor): + mock_relation = mock.MagicMock(database="somedb", schema="someschema", identifier="iden") + mock_cursor.return_value.get_columns.return_value = [ + ("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "int_column", 4, "integer", 10, 10, 20) + ] + result = self.adapter.get_columns_in_relation(mock_relation) + self.assertTrue(result[0].column == "int_column") + self.assertTrue(result[0].dtype == "integer") + self.assertTrue(result[0].char_size is None) 
+ self.assertTrue(result[0].numeric_scale == 20) + self.assertTrue(result[0].numeric_precision == 10) + class TestRedshiftAdapterConversions(TestAdapterConversions): def test_convert_text_type(self): From 122ce603189d3ee37eb0b9dd59d1d5a5d3cfc182 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Tue, 25 Jul 2023 14:16:46 -0700 Subject: [PATCH 2/4] add changelog --- .changes/unreleased/Under the Hood-20230725-141628.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Under the Hood-20230725-141628.yaml diff --git a/.changes/unreleased/Under the Hood-20230725-141628.yaml b/.changes/unreleased/Under the Hood-20230725-141628.yaml new file mode 100644 index 000000000..55922946c --- /dev/null +++ b/.changes/unreleased/Under the Hood-20230725-141628.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: use redshift driver api calls instead of hard coding get_columns +time: 2023-07-25T14:16:28.492828-07:00 +custom: + Author: jiezhen-chen + Issue: "555" From 4f49325e2136a54490c197eec429c8f57cdac823 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Tue, 25 Jul 2023 20:54:46 -0700 Subject: [PATCH 3/4] add functional tests --- .../adapter/test_adapter_methods.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/functional/adapter/test_adapter_methods.py b/tests/functional/adapter/test_adapter_methods.py index 9907ccb72..591e14f14 100644 --- a/tests/functional/adapter/test_adapter_methods.py +++ b/tests/functional/adapter/test_adapter_methods.py @@ -16,6 +16,20 @@ """ +tests__get_columns_invalid = """ +{% set model = ref('model_regression') %} +{% set relation = api.Relation.create(database=model.database, schema=model.schema, identifier=model.identifier) %} +{% set cols = adapter.get_columns_in_relation(relation) %} +{%- if ( (cols[0].dtype == 'int4') and (cols[0].column == 'id') ) -%} + {% set limit_query = 0 %} +{% else %} + {% set limit_query = 1 %} +{% endif %} + +select 1 as id limit {{ limit_query }} + 
+""" + models__upstream_sql = """ select 1 as id @@ -72,7 +86,10 @@ class RedshiftAdapterMethod: @pytest.fixture(scope="class") def tests(self): - return {"get_relation_invalid.sql": tests__get_relation_invalid} + return { + "get_relation_invalid.sql": tests__get_relation_invalid, + "get_columns_invalid.sql": tests__get_columns_invalid, + } @pytest.fixture(scope="class") def models(self): From c0979217e063225650714bfbc9911b31b6693874 Mon Sep 17 00:00:00 2001 From: Jessie Chen Date: Thu, 27 Jul 2023 22:29:01 -0700 Subject: [PATCH 4/4] use constants to unpack tuple --- dbt/adapters/redshift/impl.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index 427eb48cd..c5f6e66f1 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -121,27 +121,24 @@ def _get_cursor(self): def get_columns_in_relation(self, relation): cursor = self._get_cursor() - results = [] - if relation.identifier: - columns = cursor.get_columns( - catalog=relation.database, - schema_pattern=relation.schema, - tablename_pattern=relation.identifier, - ) - else: - columns = cursor.get_columns(catalog=relation.database, schema_pattern=relation.schema) - if columns is not None and len(columns) > 0: + if columns := cursor.get_columns( + catalog=relation.database, + schema_pattern=relation.schema, + tablename_pattern=relation.identifier, + ): + results = [] + CHAR_TYPES = {1, 12} + NUMERIC_TYPES = {5, 4, -5, 3, 7, 8, 6, 2, 2003} for column in columns: - if column[4] == 1 or column[4] == 12: # if column type is character - results.append(Column(column[3], column[5], column[6], None, None)) - # elif column[4] == 5 or column[4] == 4 or column[4] == -5 or column[4] == 3 or column[4] == 7\ - # or column[4] == 8 or column[4] == 6 or column[4] == 2 or column[4] == 2003:#if column type is numeric - elif any( - column[4] == type_int for type_int in [5, 4, -5, 3, 7, 8, 6, 2, 
2003] - ): # if column type is numeric - results.append(Column(column[3], column[5], None, column[6], column[8])) + _, _, _, name, dtype_number, dtype_name, size, _, numeric_precision, *_ = column + if dtype_number in CHAR_TYPES: + results.append(Column(name, dtype_name, size, None, None)) # if column type is character + elif dtype_number in NUMERIC_TYPES: + results.append(Column(name, dtype_name, None, size, numeric_precision)) else: - results.append(Column(column[3], column[5], column[6], None, None)) + results.append(Column(name, dtype_name, size, None, None)) + else: + results = [] return results def _link_cached_database_relations(self, schemas: Set[str]):