From 509d5595c43f47886401fce97cf27b917d6ff7a4 Mon Sep 17 00:00:00 2001 From: betaboon Date: Fri, 23 Dec 2022 15:25:57 +0100 Subject: [PATCH 1/2] Support for schema in CompositeType --- sqlalchemy_utils/types/pg_composite.py | 16 ++++-- tests/types/test_composite.py | 75 ++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 4 deletions(-) diff --git a/sqlalchemy_utils/types/pg_composite.py b/sqlalchemy_utils/types/pg_composite.py index 72e4b5de..3f6c53a1 100644 --- a/sqlalchemy_utils/types/pg_composite.py +++ b/sqlalchemy_utils/types/pg_composite.py @@ -183,7 +183,7 @@ def __getattr__(self, key): return CompositeElement(self.expr, key, type_) - def __init__(self, name, columns, quote=None): + def __init__(self, name, columns, schema=None, quote=None): if psycopg2 is None: raise ImproperlyConfigured( "'psycopg2' package is required in order to use CompositeType." @@ -191,6 +191,7 @@ def __init__(self, name, columns, quote=None): SchemaType.__init__( self, name=name, + schema=schema, quote=quote ) self.columns = columns @@ -210,7 +211,7 @@ def make(obj, values): attach_composite_listeners() def get_col_spec(self): - return self.name + return f"{self.schema}.{self.name}" if self.schema else self.name def bind_processor(self, dialect): def process(value): @@ -269,7 +270,7 @@ def drop(self, bind=None, checkfirst=True): def register_psycopg2_composite(dbapi_connection, composite): psycopg2.extras.register_composite( - composite.name, + composite.get_col_spec(), dbapi_connection, globally=True, factory=composite.caster @@ -296,10 +297,17 @@ def adapt_composite(value): value.getquoted().decode(dbapi_connection.encoding) for value in adapted ] + quoted_name = dialect.identifier_preparer.quote(composite.name) + typename = ( + f"{composite.schema}.{quoted_name}" + if composite.schema + else quoted_name + ) + return AsIs( '({})::{}'.format( ', '.join(values), - dialect.identifier_preparer.quote(composite.name) + typename ) ) diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py index 3ea9af42..5a9f1fae 100644 --- a/tests/types/test_composite.py +++ b/tests/types/test_composite.py @@ -476,3 +476,78 @@ def test_parameter_processing(self, session, Account): account = session.query(Account).first() assert account.balance.currency == 'USD' assert account.balance.amount == 15 + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestCompositeTypeWithSchema: + + @pytest.fixture + def Account(self, Base): + pg_composite.registered_composites = {} + + type_ = CompositeType( + name='MoneyType', + schema='my_schema', + columns=[ + sa.Column('currency', sa.String), + sa.Column('amount', sa.Integer) + ] + ) + + class Account(Base): + __tablename__ = 'account' + __table_args__ = {'schema': 'my_schema'} + id = sa.Column(sa.Integer, primary_key=True) + balance = sa.Column(type_) + + return Account + + @pytest.fixture + def session(self, request, engine, connection, Base, Account): + sa.orm.configure_mappers() + + Session = sessionmaker(bind=connection) + try: + # Enable sqlalchemy 2.0 behavior + session = Session(future=True) + except TypeError: + # sqlalchemy 1.3 + session = Session() + session.execute(sa.text('CREATE SCHEMA my_schema')) + session.execute(sa.text( + """CREATE TYPE my_schema."MoneyType" AS ( + currency VARCHAR, amount INTEGER + )""" + )) + session.execute(sa.text( + """CREATE TABLE my_schema.account ( + id SERIAL, balance my_schema."MoneyType", PRIMARY KEY(id) + )""" + )) + + def teardown(): + session.execute(sa.text('DROP TABLE my_schema.account')) + session.execute(sa.text('DROP TYPE my_schema."MoneyType"')) + session.execute(sa.text('DROP SCHEMA my_schema')) + session.commit() + close_all_sessions() + connection.close() + remove_composite_listeners() + engine.dispose() + + register_composites(connection) + request.addfinalizer(teardown) + + return session + + def test_parameter_processing(self, session, Account): + account = Account( + balance=('USD', 15), + ) + + session.add(account) + session.commit() + + account = session.query(Account).first() + assert account.balance.currency == 'USD' + assert account.balance.amount == 15 From ae1c3470c247f08788eef3eeddd73ec39f0d5c32 Mon Sep 17 00:00:00 2001 From: betaboon Date: Fri, 23 Dec 2022 15:33:50 +0100 Subject: [PATCH 2/2] Improve adaptation and result-processing code in CompositeType --- sqlalchemy_utils/types/pg_composite.py | 41 ++++++------ tests/types/test_composite.py | 89 ++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 18 deletions(-) diff --git a/sqlalchemy_utils/types/pg_composite.py b/sqlalchemy_utils/types/pg_composite.py index 3f6c53a1..2e422f33 100644 --- a/sqlalchemy_utils/types/pg_composite.py +++ b/sqlalchemy_utils/types/pg_composite.py @@ -244,12 +244,17 @@ def process(value): cls = value.__class__ kwargs = {} for column in self.columns: + column_value = getattr(value, column.name) + result_processor = column.type.result_processor(dialect, column.type) if isinstance(column.type, TypeDecorator): - kwargs[column.name] = column.type.process_result_value( - getattr(value, column.name), dialect + processed_value = column.type.process_result_value( + column_value, dialect ) + elif result_processor: + processed_value = result_processor(column_value) else: - kwargs[column.name] = getattr(value, column.name) + processed_value = column_value + kwargs[column.name] = processed_value return cls(**kwargs) return process @@ -278,21 +283,21 @@ def register_psycopg2_composite(dbapi_connection, composite): def adapt_composite(value): dialect = PGDialect_psycopg2() - adapted = [ - adapt( - getattr(value, column.name) - if not isinstance(column.type, TypeDecorator) - else column.type.process_bind_param( - getattr(value, column.name), - dialect - ) - ) - for column in - composite.columns - ] - for value in adapted: - if hasattr(value, 'prepare'): - value.prepare(dbapi_connection) + adapted = [] + for column in composite.columns: + column_value = getattr(value, column.name) + bind_processor = column.type.bind_processor(dialect) + if isinstance(column.type, TypeDecorator): + processed_value = column.type.process_bind_param(column_value, dialect) + elif bind_processor: + processed_value = bind_processor(column_value) + else: + processed_value = column_value + adapted_value = adapt(processed_value) + if hasattr(adapted_value, 'prepare'): + adapted_value.prepare(dbapi_connection) + adapted.append(adapted_value) + values = [ value.getquoted().decode(dbapi_connection.encoding) for value in adapted diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py index 5a9f1fae..6e655e56 100644 --- a/tests/types/test_composite.py +++ b/tests/types/test_composite.py @@ -1,3 +1,5 @@ +import enum + import pytest import sqlalchemy as sa from sqlalchemy.dialects.postgresql import ARRAY @@ -551,3 +553,90 @@ def test_parameter_processing(self, session, Account): account = session.query(Account).first() assert account.balance.currency == 'USD' assert account.balance.amount == 15 + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestCompositeTypeWithEnumColumn: + @pytest.fixture + def PaintColor(self): + class PaintColor(enum.Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + return PaintColor + + @pytest.fixture + def PaintJob(self, Base, PaintColor): + pg_composite.registered_composites = {} + + type_ = CompositeType( + name='PaintType', + columns=[ + sa.Column('name', sa.Text), + sa.Column('color', sa.Enum( + PaintColor, + values_callable=lambda o: [e.value for e in o], + )), + ] + ) + + class PaintJob(Base): + __tablename__ = 'paint_job' + id = sa.Column(sa.Integer, primary_key=True) + paint = sa.Column(type_) + + return PaintJob + + @pytest.fixture + def session(self, request, engine, connection, Base, PaintJob): + sa.orm.configure_mappers() + + Session = sessionmaker(bind=connection) + try: + # Enable sqlalchemy 2.0 behavior + session = Session(future=True) + except TypeError: + # sqlalchemy 1.3 + session = Session() + session.execute(sa.text( + """CREATE TYPE "PaintColor" AS ENUM( + 'red', + 'green', + 'blue' + )""" + )) + session.execute( + sa.text('CREATE TYPE "PaintType" AS (name TEXT, color "PaintColor")') + ) + session.execute(sa.text( + """CREATE TABLE paint_job ( + id SERIAL, paint "PaintType", PRIMARY KEY(id) + )""" + )) + + def teardown(): + session.execute(sa.text('DROP TABLE paint_job')) + session.execute(sa.text('DROP TYPE "PaintType"')) + session.execute(sa.text('DROP TYPE "PaintColor"')) + session.commit() + close_all_sessions() + connection.close() + remove_composite_listeners() + engine.dispose() + + register_composites(connection) + request.addfinalizer(teardown) + + return session + + def test_parameter_processing(self, session, PaintJob, PaintColor): + paint_job = PaintJob( + paint=('awesome red', PaintColor.RED), + ) + + session.add(paint_job) + session.commit() + + paint_job = session.query(PaintJob).first() + assert paint_job.paint.name == 'awesome red' + assert paint_job.paint.color == PaintColor.RED