Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PECO-1436] [sqlalchemy] Include sqlalchemy __version__ in user-agent #332

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
7 changes: 4 additions & 3 deletions src/databricks/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,14 +411,15 @@ def receive_do_connect(dialect, conn_rec, cargs, cparams):
ua = cparams.get("_user_agent_entry", "")

def add_sqla_tag_if_not_present(val: str):
tag = f"sqlalchemy=={sqlalchemy.__version__}"
if not val:
output = "sqlalchemy"
output = tag

if val and "sqlalchemy" in val:
if val and tag in val:
output = val

else:
output = f"sqlalchemy + {val}"
output = f"{tag} + {val}"

return output

Expand Down
64 changes: 37 additions & 27 deletions src/databricks/sqlalchemy/test_local/e2e/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,20 +120,6 @@ def test_can_connect(db_engine):
assert len(result) == 1


def test_connect_args(db_engine):
"""Verify that extra connect args passed to sqlalchemy.create_engine are passed to DBAPI

This will most commonly happen when partners supply a user agent entry
"""

conn = db_engine.connect()
connection_headers = conn.connection.thrift_backend._transport._headers
user_agent = connection_headers["User-Agent"]

expected = f"(sqlalchemy + {USER_AGENT_TOKEN})"
assert expected in user_agent


@pytest.mark.skipif(sqlalchemy_1_3(), reason="Pandas requires SQLAlchemy >= 1.4")
@pytest.mark.skip(
reason="DBR is currently limited to 256 parameters per call to .execute(). Test cannot pass."
Expand Down Expand Up @@ -449,25 +435,49 @@ def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine):
conn.execute(text("DROP TABLE test_has_table;"))


def test_user_agent_adjustment(db_engine):
# If .connect() is called multiple times on an engine, don't keep pre-pending the user agent
# https://github.com/databricks/databricks-sql-python/issues/192
c1 = db_engine.connect()
c2 = db_engine.connect()

def get_conn_user_agent(conn):
class TestUserAgent:
def get_conn_user_agent(self, conn):
return conn.connection.dbapi_connection.thrift_backend._transport._headers.get(
"User-Agent"
)

ua1 = get_conn_user_agent(c1)
ua2 = get_conn_user_agent(c2)
same_ua = ua1 == ua2
def test_user_agent_adjustment(self, db_engine):
# If .connect() is called multiple times on an engine, don't keep pre-pending the user agent
# https://github.com/databricks/databricks-sql-python/issues/192
c1 = db_engine.connect()
c2 = db_engine.connect()

ua1 = self.get_conn_user_agent(c1)
ua2 = self.get_conn_user_agent(c2)
same_ua = ua1 == ua2

c1.close()
c2.close()

assert same_ua, f"User agents didn't match \n {ua1} \n {ua2}"

def test_sqlalchemy_user_agent_includes_version(self, db_engine):
"""So that we know when we can safely deprecate support for sqlalchemy 1.x"""

import sqlalchemy

version_str = sqlalchemy.__version__
c = db_engine.connect()
ua = self.get_conn_user_agent(c)

assert version_str in ua

def test_user_supplied_string(self, db_engine):
"""Verify that extra connect args passed to sqlalchemy.create_engine are passed to DBAPI

This will most commonly happen when partners supply a user agent entry
"""

c1.close()
c2.close()
conn = db_engine.connect()
connection_headers = conn.connection.thrift_backend._transport._headers
user_agent = connection_headers["User-Agent"]

assert same_ua, f"User agents didn't match \n {ua1} \n {ua2}"
assert USER_AGENT_TOKEN in user_agent


@pytest.fixture
Expand Down
Loading