Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

override get_columns_in_relation with driver api call #554

6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20230725-141628.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: use redshift driver api calls instead of hard coding get_columns
time: 2023-07-25T14:16:28.492828-07:00
# Author and Issue are nested under `custom`; left flat they would parse as
# top-level keys and `custom` would be null.
custom:
  Author: jiezhen-chen
  Issue: "555"
26 changes: 26 additions & 0 deletions dbt/adapters/redshift/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dbt.adapters.base import PythonJobHelper
from dbt.adapters.base.impl import AdapterConfig, ConstraintSupport
from dbt.adapters.base.meta import available
from dbt.adapters.base.column import Column
from dbt.adapters.sql import SQLAdapter
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.graph.nodes import ConstraintType
Expand Down Expand Up @@ -122,6 +123,31 @@ def valid_incremental_strategies(self):
def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str:
    """Return a SQL expression that adds an interval to ``add_to``.

    :param add_to: SQL expression the interval is added to.
    :param number: Count of ``interval`` units to add.
    :param interval: Unit name placed after the count inside the literal.
    :return: ``"<add_to> + interval '<number> <interval>'"``.
    """
    interval_literal = f"'{number} {interval}'"
    return f"{add_to} + interval {interval_literal}"

def _get_cursor(self):
    """Return ``cursor()`` called on the current thread connection's handle."""
    thread_connection = self.connections.get_thread_connection()
    driver_handle = thread_connection.handle
    return driver_handle.cursor()

def get_columns_in_relation(self, relation):
    """Return the columns of ``relation`` using the driver cursor's ``get_columns``.

    The cursor is asked for column metadata matching the relation's database,
    schema and identifier. Each returned row is unpacked positionally: index 3
    is the column name, 4 the numeric type code, 5 the type name, 6 the size
    and 8 the scale; all other positions are ignored.

    :param relation: Object exposing ``database``, ``schema`` and ``identifier``.
    :return: A list of ``Column`` objects; empty when the driver returns
        nothing (or any falsy value) for the relation.
    """
    # Type codes whose size is reported as numeric precision rather than
    # character size.
    # NOTE(review): these look like java.sql.Types codes (SMALLINT, INTEGER,
    # BIGINT, DECIMAL, REAL, DOUBLE, FLOAT, NUMERIC); 2003 would be ARRAY in
    # that scheme - confirm it belongs here.
    numeric_types = {5, 4, -5, 3, 7, 8, 6, 2, 2003}

    cursor = self._get_cursor()
    rows = cursor.get_columns(
        catalog=relation.database,
        schema_pattern=relation.schema,
        tablename_pattern=relation.identifier,
    )
    if not rows:
        return []

    results = []
    for row in rows:
        _, _, _, name, dtype_number, dtype_name, size, _, scale, *_ = row
        if dtype_number in numeric_types:
            # Numeric: size is the precision, scale follows; no char size.
            results.append(Column(name, dtype_name, None, size, scale))
        else:
            # Character types (codes 1 and 12) and every other type are built
            # the same way: size goes to char_size, no precision or scale.
            results.append(Column(name, dtype_name, size, None, None))
    return results

def _link_cached_database_relations(self, schemas: Set[str]):
"""
:param schemas: The set of schemas that should have links added.
Expand Down
19 changes: 18 additions & 1 deletion tests/functional/adapter/test_adapter_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@

"""

# Jinja/SQL test body: builds a relation from ref('model_regression'), calls
# adapter.get_columns_in_relation on it, and selects zero rows when the first
# column is named 'id' with dtype 'int4', otherwise one row.
# NOTE(review): presumably dbt treats any returned row as a test failure - confirm.
tests__get_columns_invalid = """
{% set model = ref('model_regression') %}
{% set relation = api.Relation.create(database=model.database, schema=model.schema, identifier=model.identifier) %}
{% set cols = adapter.get_columns_in_relation(relation) %}
{%- if ( (cols[0].dtype == 'int4') and (cols[0].column == 'id') ) -%}
{% set limit_query = 0 %}
{% else %}
{% set limit_query = 1 %}
{% endif %}

select 1 as id limit {{ limit_query }}

"""

models__upstream_sql = """
select 1 as id

Expand Down Expand Up @@ -72,7 +86,10 @@
class RedshiftAdapterMethod:
@pytest.fixture(scope="class")
def tests(self):
    """Map test file names to the SQL test bodies used by this class."""
    # The earlier single-entry return made this dict unreachable; only the
    # dict with both tests is kept.
    return {
        "get_relation_invalid.sql": tests__get_relation_invalid,
        "get_columns_invalid.sql": tests__get_columns_invalid,
    }

@pytest.fixture(scope="class")
def models(self):
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_redshift_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,32 @@ def mock_cursor(self, mock_get_thread_conn):
mock_handle.return_value = mock_cursor
return mock_cursor

@mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter._get_cursor")
def test_get_columns_in_relation_char(self, mock_cursor):
    """A row with type code 12 yields a column with only ``char_size`` set."""
    mock_relation = mock.MagicMock(database="somedb", schema="someschema", identifier="iden")
    # ``mock_cursor`` is the patched ``_get_cursor``; its return value stands
    # in for the driver cursor whose ``get_columns`` supplies the rows.
    mock_cursor.return_value.get_columns.return_value = [
        ("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "char_column", 12, "char", 10, 10, None)
    ]
    result = self.adapter.get_columns_in_relation(mock_relation)
    column = result[0]
    # assertEqual/assertIsNone report the offending values on failure,
    # unlike assertTrue on a comparison.
    self.assertEqual(column.column, "char_column")
    self.assertEqual(column.dtype, "char")
    self.assertEqual(column.char_size, 10)
    self.assertIsNone(column.numeric_scale)
    self.assertIsNone(column.numeric_precision)

@mock.patch("dbt.adapters.redshift.impl.RedshiftAdapter._get_cursor")
def test_get_columns_in_relation_int(self, mock_cursor):
    """A row with type code 4 yields precision and scale but no ``char_size``."""
    mock_relation = mock.MagicMock(database="somedb", schema="someschema", identifier="iden")
    # ``mock_cursor`` is the patched ``_get_cursor``; its return value stands
    # in for the driver cursor whose ``get_columns`` supplies the rows.
    mock_cursor.return_value.get_columns.return_value = [
        ("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "int_column", 4, "integer", 10, 10, 20)
    ]
    result = self.adapter.get_columns_in_relation(mock_relation)
    column = result[0]
    # Index 6 of the row (10) becomes the precision, index 8 (20) the scale.
    self.assertEqual(column.column, "int_column")
    self.assertEqual(column.dtype, "integer")
    self.assertIsNone(column.char_size)
    self.assertEqual(column.numeric_scale, 20)
    self.assertEqual(column.numeric_precision, 10)


class TestRedshiftAdapterConversions(TestAdapterConversions):
def test_convert_text_type(self):
Expand Down
Loading