diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 509686d7..5380ac21 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -22,6 +22,7 @@ from databricks.sql.client import Connection as DatabricksSQLConnection from databricks.sql.client import Cursor as DatabricksSQLCursor from databricks.sql.exc import Error +from databricks.sql.types import Row from dbt.adapters.base.query_headers import MacroQueryStringSetter from dbt.adapters.contracts.connection import ( DEFAULT_QUERY_COMMENT, @@ -179,12 +180,21 @@ def dbr_version(self) -> tuple[int, int]: return self._dbr_version +@dataclass +class DatabricksQueryImpact: + num_affected_rows: Optional[int] = None + num_updated_rows: Optional[int] = None + num_deleted_rows: Optional[int] = None + num_inserted_rows: Optional[int] = None + + class DatabricksSQLCursorWrapper: """Wrap a Databricks SQL cursor in a way that no-ops transactions""" _cursor: DatabricksSQLCursor _user_agent: str _creds: DatabricksCredentials + _cache_fetchone: Optional[Row] = None def __init__(self, cursor: DatabricksSQLCursor, creds: DatabricksCredentials, user_agent: str): self._cursor = cursor @@ -207,17 +217,52 @@ def close(self) -> None: except Error as exc: logger.warning(CursorCloseError(self._cursor, exc)) - def fetchall(self) -> Sequence[tuple]: + def fetchall(self) -> Sequence[Row]: return self._cursor.fetchall() - def fetchone(self) -> Optional[tuple]: + def query_impact(self) -> DatabricksQueryImpact: + """Get the number of rows affected by the last query. + + Delta returns for merge, update and insert commands a single row containing: + - num_affected_rows: the number of rows affected by the query + - num_updated_rows: the number of rows updated by the query + - num_deleted_rows: the number of rows deleted by the query + - num_inserted_rows: the number of rows inserted by the query + + This method attempts to retrieve it from the last query, while caching the result to make + sure it does not interfere with the fetchone method. + """ + if not self._cache_fetchone: + try: + # Cache the result to be able to return it if fetchone is called later + self._cache_fetchone = self._cursor.fetchone() + except Error: + return DatabricksQueryImpact() + + if not self._cache_fetchone: + return DatabricksQueryImpact() + + # Cast the result to check that is indeed query metadata + try: + return DatabricksQueryImpact(**self._cache_fetchone.asDict()) + except TypeError: + return DatabricksQueryImpact() + + def fetchone(self) -> Optional[Row]: + if self._cache_fetchone: + # If `fetchone` result was cached by `query_metadata`, return it and invalidate it + row = self._cache_fetchone + self._cache_fetchone = None + return row return self._cursor.fetchone() - def fetchmany(self, size: int) -> Sequence[tuple]: + def fetchmany(self, size: int) -> Sequence[Row]: return self._cursor.fetchmany(size) def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None: # print(f"execute: {sql}") + # Invalidate fetchone cache + self._cache_fetchone = None if sql.strip().endswith(";"): sql = sql.strip()[:-1] if bindings is not None: @@ -300,6 +345,9 @@ def _get_comment_macro(self) -> Optional[str]: @dataclass class DatabricksAdapterResponse(AdapterResponse): query_id: str = "" + rows_updated: Optional[int] = None + rows_deleted: Optional[int] = None + rows_inserted: Optional[int] = None @dataclass(init=False) @@ -531,7 +579,7 @@ def execute( sql = self._add_query_comment(sql) _, cursor = self.add_query(sql, auto_begin) try: - response = self.get_response(cursor) + response = self.get_response(cursor, include_impact=(not fetch)) if fetch: table = self.get_result_from_cursor(cursor, limit) else: @@ -693,7 +741,9 @@ def exponential_backoff(attempt: int) -> int: ) @classmethod - def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse: + def get_response( + cls, cursor: DatabricksSQLCursorWrapper, include_impact: bool = False + ) -> DatabricksAdapterResponse: _query_id = getattr(cursor, "hex_query_id", None) if cursor is None: logger.debug("No cursor was provided. Query ID not available.") @@ -701,7 +751,22 @@ def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterRe else: query_id = _query_id message = "OK" - return DatabricksAdapterResponse(_message=message, query_id=query_id) # type: ignore + + response = DatabricksAdapterResponse( + _message=message, + query_id=query_id, # type: ignore + ) + + # If some query metadata are available, add them to the adapter response + if include_impact: + query_impact = cursor.query_impact() + logger.debug(query_impact) + response.rows_affected = query_impact.num_affected_rows + response.rows_inserted = query_impact.num_inserted_rows + response.rows_updated = query_impact.num_updated_rows + response.rows_deleted = query_impact.num_deleted_rows + + return response class ExtendedSessionConnectionManager(DatabricksConnectionManager):