diff --git a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py index dfb92573dbf46..bd30e676eceb8 100644 --- a/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py +++ b/metadata-ingestion-modules/airflow-plugin/src/datahub_airflow_plugin/datahub_listener.py @@ -133,6 +133,37 @@ def run_in_thread(f: _F) -> _F: @functools.wraps(f) def wrapper(*args, **kwargs): + def safe_execution(*args, **kwargs): + # Filter out database-related parameters to prevent transaction conflicts. + # These objects contain active database session references that can interfere + # with Airflow's strict transaction boundaries and cause HA lock violations. + + # Filter SQLAlchemy session objects by type, regardless of parameter name or position + # See: https://github.com/apache/airflow/blob/main/airflow-core/src/airflow/listeners/spec/taskinstance.py + # Session parameters provide direct SQLAlchemy database access which can interfere + # with Airflow's transaction boundaries and cause "UNEXPECTED COMMIT" errors + + # Import here to avoid import issues if SQLAlchemy isn't available + try: + from sqlalchemy.orm import Session + + # Filter session objects from positional arguments + filtered_args = tuple( + arg for arg in args if not isinstance(arg, Session) + ) + + # Filter session objects from keyword arguments + filtered_kwargs = { + k: v for k, v in kwargs.items() if not isinstance(v, Session) + } + + except ImportError: + # Fallback to name-based filtering if SQLAlchemy import fails + filtered_args = args + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "session"} + + f(*filtered_args, **filtered_kwargs) + try: if _RUN_IN_THREAD: # A poor-man's timeout mechanism. @@ -140,7 +171,7 @@ def wrapper(*args, **kwargs): # are slow or the DataHub API is slow to respond. thread = threading.Thread( - target=f, args=args, kwargs=kwargs, daemon=True + target=safe_execution, args=args, kwargs=kwargs, daemon=True ) thread.start() @@ -161,7 +192,9 @@ def wrapper(*args, **kwargs): f"Thread for {f.__name__} finished after {time.time() - start_time} seconds" ) else: - f(*args, **kwargs) + # Run synchronously but with database isolation to prevent transaction conflicts + # This prevents "UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!" errors + safe_execution(*args, **kwargs) except Exception as e: logger.warning(e, exc_info=True)