Skip to content

Commit

Permalink
Support the -x flag in env.py
Browse files Browse the repository at this point in the history
  • Loading branch information
luolingchun committed Oct 17, 2023
1 parent d70e25c commit 986bb25
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 4 deletions.
23 changes: 21 additions & 2 deletions src/flask_migrate/templates/aioflask-multidb/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from logging.config import fileConfig

from sqlalchemy import MetaData
from sqlalchemy import MetaData, text
from flask import current_app

from alembic import context
Expand Down Expand Up @@ -135,8 +135,27 @@ def process_revision_directives(context, revision, directives):
if conf_args.get("process_revision_directives") is None:
conf_args["process_revision_directives"] = process_revision_directives

current_schema = context.get_x_argument(as_dictionary=True).get("schema")

for name, rec in engines.items():
rec['sync_connection'] = conn = rec['connection']._sync_connection()
connection = rec['connection']
if current_schema and connection.dialect.name == "postgresql":
# set search path on the connection, which ensures that
# PostgreSQL will emit all CREATE / ALTER / DROP statements
# in terms of this schema by default
connection.execute(
text('set search_path to "%s"' % current_schema)
)
# in SQLAlchemy v2+ the search path change
# needs to be committed
connection.commit()

# make use of non-supported SQLAlchemy attribute to ensure
# the dialect reflects tables in terms of
# the current schema name
connection.dialect.default_schema_name = current_schema

rec['sync_connection'] = conn = connection._sync_connection()
if USE_TWOPHASE:
rec['transaction'] = conn.begin_twophase()
else:
Expand Down
17 changes: 17 additions & 0 deletions src/flask_migrate/templates/aioflask/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from flask import current_app

from alembic import context
from sqlalchemy import text

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
Expand Down Expand Up @@ -88,6 +89,22 @@ def process_revision_directives(context, revision, directives):
if conf_args.get("process_revision_directives") is None:
conf_args["process_revision_directives"] = process_revision_directives

current_schema = context.get_x_argument(as_dictionary=True).get("schema")

if current_schema and connection.dialect.name == "postgresql":
# set search path on the connection, which ensures that
# PostgreSQL will emit all CREATE / ALTER / DROP statements
# in terms of this schema by default
connection.execute(
text('set search_path to "%s"' % current_schema)
)
# in SQLAlchemy v2+ the search path change needs to be committed
connection.commit()

# make use of non-supported SQLAlchemy attribute to ensure
# the dialect reflects tables in terms of the current schema name
connection.dialect.default_schema_name = current_schema

context.configure(
connection=connection,
target_metadata=get_metadata(),
Expand Down
23 changes: 21 additions & 2 deletions src/flask_migrate/templates/flask-multidb/env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from logging.config import fileConfig

from sqlalchemy import MetaData
from sqlalchemy import MetaData, text
from flask import current_app

from alembic import context
Expand Down Expand Up @@ -158,11 +158,30 @@ def process_revision_directives(context, revision, directives):
else:
rec['transaction'] = conn.begin()

current_schema = context.get_x_argument(as_dictionary=True).get("schema")

try:
for name, rec in engines.items():
connection = rec['connection']
if current_schema and connection.dialect.name == "postgresql":
# set search path on the connection, which ensures that
# PostgreSQL will emit all CREATE / ALTER / DROP statements
# in terms of this schema by default
connection.execute(
text('set search_path to "%s"' % current_schema)
)
# in SQLAlchemy v2+ the search path change
# needs to be committed
connection.commit()

# make use of non-supported SQLAlchemy attribute to ensure
# the dialect reflects tables in terms of
# the current schema name
connection.dialect.default_schema_name = current_schema

logger.info("Migrating database %s" % (name or '<default>'))
context.configure(
connection=rec['connection'],
connection=connection,
upgrade_token="%s_upgrades" % name,
downgrade_token="%s_downgrades" % name,
target_metadata=get_metadata(name),
Expand Down
17 changes: 17 additions & 0 deletions src/flask_migrate/templates/flask/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from flask import current_app

from alembic import context
from sqlalchemy import text

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
Expand Down Expand Up @@ -96,7 +97,23 @@ def process_revision_directives(context, revision, directives):

connectable = get_engine()

current_schema = context.get_x_argument(as_dictionary=True).get("schema")

with connectable.connect() as connection:
if current_schema and connection.dialect.name == "postgresql":
# set search path on the connection, which ensures that
# PostgreSQL will emit all CREATE / ALTER / DROP statements
# in terms of this schema by default
connection.execute(
text('set search_path to "%s"' % current_schema)
)
# in SQLAlchemy v2+ the search path change needs to be committed
connection.commit()

# make use of non-supported SQLAlchemy attribute to ensure
# the dialect reflects tables in terms of the current schema name
connection.dialect.default_schema_name = current_schema

context.configure(
connection=connection,
target_metadata=get_metadata(),
Expand Down

0 comments on commit 986bb25

Please sign in to comment.