diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..9cf1b75
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,98 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Project Overview
+
+Dagger is a configuration-driven framework that transforms YAML definitions into Apache Airflow DAGs. It uses dataset lineage (matching inputs/outputs) to automatically build dependency graphs across workflows.
+
+## Common Commands
+
+### Development Setup
+```bash
+make install-dev   # Create venv, install package in editable mode with dev/test deps
+source venv/bin/activate
+```
+
+### Testing
+```bash
+make test   # Run all tests with coverage (sets AIRFLOW_HOME automatically)
+
+# Run a single test file
+AIRFLOW_HOME=$(pwd)/tests/fixtures/config_finder/root/ ENV=local pytest -s tests/path/to/test_file.py
+
+# Run a specific test
+AIRFLOW_HOME=$(pwd)/tests/fixtures/config_finder/root/ ENV=local pytest -s tests/path/to/test_file.py::test_function_name
+```
+
+### Linting
+```bash
+make lint            # Run flake8 on dagger and tests directories
+black dagger tests   # Format code
+```
+
+### Local Airflow Testing
+```bash
+make test-airflow   # Build and start Airflow in Docker (localhost:8080, user: dev_user, pass: dev_user)
+make stop-airflow   # Stop Airflow containers
+```
+
+### CLI
+```bash
+dagger --help
+dagger list-tasks                     # Show available task types
+dagger list-ios                       # Show available IO types
+dagger init-pipeline                  # Create a new pipeline.yaml
+dagger init-task --type=<task_type>   # Add a task configuration
+dagger init-io --type=<io_type>       # Add an IO definition
+dagger print-graph                    # Visualize dependency graph
+```
+
+## Architecture
+
+### Core Flow
+1. **ConfigFinder** discovers pipeline directories (each with `pipeline.yaml` + task YAML files)
+2. **ConfigProcessor** loads YAML configs with environment variable support
+3. **TaskFactory/IOFactory** use reflection to instantiate task/IO objects from YAML
+4. **TaskGraph** builds a 3-layer graph: Pipeline → Task → Dataset nodes
+5. **DagCreator** traverses the graph and generates Airflow DAGs using **OperatorFactory**
+
+### Key Directories
+- `dagger/pipeline/tasks/` - Task type definitions (DbtTask, SparkTask, AthenaTransformTask, etc.)
+- `dagger/pipeline/ios/` - IO type definitions (S3, Redshift, Athena, Databricks, etc.)
+- `dagger/dag_creator/airflow/operator_creators/` - One creator per task type, translates tasks to Airflow operators
+- `dagger/graph/` - Graph construction from task inputs/outputs
+- `dagger/config_finder/` - YAML discovery and loading
+- `tests/fixtures/config_finder/root/dags/` - Example DAG configurations for testing
+
+### Adding a New Task Type
+1. Create task definition in `dagger/pipeline/tasks/` (subclass of Task)
+2. Create any needed IOs in `dagger/pipeline/ios/` (if new data sources)
+3. Create operator creator in `dagger/dag_creator/airflow/operator_creators/`
+4. Register in `dagger/dag_creator/airflow/operator_factory.py` (see the sketch below)
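+
+The following sketch illustrates steps 1, 3 and 4 with a hypothetical `foo` task type. The module paths, the `FooTask`/`FooCreator` names, the `command` parameter and the use of `BashOperator` are illustrative assumptions rather than existing code; they mirror the patterns used by `DatabricksDLTTask` and `DatabricksDLTCreator` in this repository.
+
+```python
+# dagger/pipeline/tasks/foo_task.py  (step 1, illustrative)
+from dagger.pipeline.task import Task
+from dagger.utilities.config_validator import Attribute
+
+
+class FooTask(Task):
+    ref_name = "foo"  # value used as `type:` in the task YAML
+
+    @classmethod
+    def init_attributes(cls, orig_cls):
+        cls.add_config_attributes(
+            [
+                Attribute(
+                    attribute_name="command",
+                    parent_fields=["task_parameters"],
+                    comment="Command the task should run",
+                ),
+            ]
+        )
+
+    def __init__(self, name, pipeline_name, pipeline, job_config):
+        super().__init__(name, pipeline_name, pipeline, job_config)
+        self._command = self.parse_attribute("command")
+
+    @property
+    def command(self):
+        return self._command
+
+
+# dagger/dag_creator/airflow/operator_creators/foo_creator.py  (step 3, illustrative)
+from airflow.operators.bash import BashOperator
+
+from dagger.dag_creator.airflow.operator_creator import OperatorCreator
+
+
+class FooCreator(OperatorCreator):
+    ref_name = "foo"  # matched against FooTask.ref_name by OperatorFactory
+
+    def _create_operator(self, **kwargs):
+        # Any Airflow operator works here; BashOperator keeps the example small
+        return BashOperator(
+            dag=self._dag,
+            task_id=self._task.name,
+            bash_command=self._task.command,
+            **kwargs,
+        )
+```
+
+Step 4 is then a one-line change: import `foo_creator` alongside the other creators in `operator_factory.py` (this change does the same for `databricks_dlt_creator`) so the factory's reflection-based discovery picks it up.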
+
+### Configuration Files
+- `pipeline.yaml` - Pipeline metadata (owner, schedule, alerts, airflow_parameters)
+- `[taskname].yaml` - Task configs (type, inputs, outputs, task-specific params)
+- `dagger_config.yaml` - System config (Neo4j, Elasticsearch, Spark settings)
+
+### Key Patterns
+- **Factory Pattern**: TaskFactory/IOFactory auto-discover types via reflection
+- **Strategy Pattern**: OperatorCreator subclasses handle task-specific operator creation
+- **Dataset Aliasing**: IO `alias()` method enables automatic dependency detection across pipelines
+
+## Coding Standards
+
+### Avoid getattr
+Do not use `getattr` for accessing task or IO properties. Instead, define explicit properties on the class. This ensures:
+- Type safety and IDE autocompletion
+- Clear interface contracts
+- Easier debugging and testing
+
+```python
+# Bad - avoid this pattern
+value = getattr(self._task, 'some_property', default)
+
+# Good - use explicit properties
+value = self._task.some_property  # Property defined on task class
+```
diff --git a/dagger/cli/module.py b/dagger/cli/module.py
index 931e809..d77897d 100644
--- a/dagger/cli/module.py
+++ b/dagger/cli/module.py
@@ -1,19 +1,34 @@
+import json
+
 import click
+import yaml
+
 from dagger.utilities.module import Module
 from dagger.utils import Printer
-import json
 
 
 def parse_key_value(ctx, param, value):
-    #print('YYY', value)
+    """Parse key=value pairs where value is a path to a JSON or YAML file.
+
+    Args:
+        ctx: Click context.
+        param: Click parameter.
+        value: List of key=value pairs.
+
+    Returns:
+        Dictionary mapping variable names to parsed file contents.
+    """
     if not value:
         return {}
     key_value_dict = {}
     for pair in value:
         try:
             key, val_file_path = pair.split('=', 1)
-            #print('YYY', key, val_file_path, pair)
-            val = json.load(open(val_file_path))
+            with open(val_file_path, 'r') as f:
+                if val_file_path.endswith(('.yaml', '.yml')):
+                    val = yaml.safe_load(f)
+                else:
+                    val = json.load(f)
             key_value_dict[key] = val
         except ValueError:
             raise click.BadParameter(f"Key-value pair '{pair}' is not in the format key=value")
@@ -22,7 +37,7 @@ def parse_key_value(ctx, param, value):
 @click.command()
 @click.option("--config_file", "-c", help="Path to module config file")
 @click.option("--target_dir", "-t", help="Path to directory to generate the task configs to")
-@click.option("--jinja_parameters", "-j", callback=parse_key_value, multiple=True, default=None, help="Path to jinja parameters json file in the format: =")
+@click.option("--jinja_parameters", "-j", callback=parse_key_value, multiple=True, default=None, help="Jinja parameters file in the format: <variable>=<path to JSON or YAML file>")
 def generate_tasks(config_file: str, target_dir: str, jinja_parameters: dict) -> None:
     """
     Generating tasks for a module based on config
diff --git a/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py b/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py
new file mode 100644
index 0000000..87a11ac
--- /dev/null
+++ b/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py
@@ -0,0 +1,129 @@
+"""Operator creator for Databricks DLT (Delta Live Tables) pipelines."""
+
+import logging
+from typing import Any
+
+from airflow.models import BaseOperator, DAG
+
+from dagger.dag_creator.airflow.operator_creator import OperatorCreator
+from dagger.pipeline.tasks.databricks_dlt_task import DatabricksDLTTask
+
+_logger = logging.getLogger(__name__)
+
+
+def _cancel_databricks_run(context: dict[str, Any])
-> None: + """Cancel a Databricks job run when task fails or is cleared. + + This callback retrieves the run_id from XCom and cancels the corresponding + Databricks job run. Used as on_failure_callback to ensure jobs are cancelled + when tasks are marked as failed. + + Args: + context: Airflow context dictionary containing task instance and other metadata. + """ + ti = context.get("task_instance") + if not ti: + _logger.warning("No task instance in context, cannot cancel Databricks run") + return + + # Get run_id from XCom (pushed by DatabricksRunNowOperator) + run_id = ti.xcom_pull(task_ids=ti.task_id, key="run_id") + if not run_id: + _logger.warning(f"No run_id found in XCom for task {ti.task_id}") + return + + # Get the databricks_conn_id from the operator (set during operator creation) + databricks_conn_id = ti.task.databricks_conn_id + + # Import here to avoid import errors if databricks provider not installed + # and to only import when actually needed (after early returns) + try: + from airflow.providers.databricks.hooks.databricks import DatabricksHook + + hook = DatabricksHook(databricks_conn_id=databricks_conn_id) + hook.cancel_run(run_id) + _logger.info(f"Cancelled Databricks run {run_id} for task {ti.task_id}") + except ImportError: + _logger.error( + "airflow-providers-databricks is not installed, cannot cancel run" + ) + except Exception as e: + _logger.error(f"Failed to cancel Databricks run {run_id}: {e}") + + +class DatabricksDLTCreator(OperatorCreator): + """Creates operators for triggering Databricks DLT pipelines via Jobs. + + This creator uses DatabricksRunNowOperator to trigger a Databricks Job + that wraps the DLT pipeline. The job is identified by name and must be + defined in the Databricks Asset Bundle. + + Attributes: + ref_name: Reference name used by OperatorFactory to match this creator + with DatabricksDLTTask instances. + """ + + ref_name: str = "databricks_dlt" + + def __init__(self, task: DatabricksDLTTask, dag: DAG) -> None: + """Initialize the DatabricksDLTCreator. + + Args: + task: The DatabricksDLTTask containing pipeline configuration. + dag: The Airflow DAG this operator will belong to. + """ + super().__init__(task, dag) + + def _create_operator(self, **kwargs: Any) -> BaseOperator: + """Create a DatabricksRunNowOperator for the DLT pipeline. + + Creates an Airflow operator that triggers an existing Databricks Job + by name. The job must have a pipeline_task that references the DLT + pipeline. + + Args: + **kwargs: Additional keyword arguments passed to the operator. + + Returns: + A configured DatabricksRunNowOperator instance. + + Raises: + ValueError: If job_name is empty or not provided. 
+ """ + # Import here to avoid import errors if databricks provider not installed + from datetime import timedelta + + from airflow.providers.databricks.operators.databricks import ( + DatabricksRunNowOperator, + ) + + # Get task parameters - defaults are handled in DatabricksDLTTask + job_name: str = self._task.job_name + if not job_name: + raise ValueError( + f"job_name is required for DatabricksDLTTask '{self._task.name}'" + ) + databricks_conn_id: str = self._task.databricks_conn_id + wait_for_completion: bool = self._task.wait_for_completion + poll_interval_seconds: int = self._task.poll_interval_seconds + timeout_seconds: int = self._task.timeout_seconds + + # DatabricksRunNowOperator triggers an existing Databricks Job by name + # The job must have a pipeline_task that references the DLT pipeline + # Note: timeout is handled via Airflow's execution_timeout, not a direct parameter + # Note: on_kill() is already implemented in DatabricksRunNowOperator to cancel runs + # We add on_failure_callback to also cancel when task is marked as failed + operator: BaseOperator = DatabricksRunNowOperator( + dag=self._dag, + task_id=self._task.name, + databricks_conn_id=databricks_conn_id, + job_name=job_name, + wait_for_termination=wait_for_completion, + polling_period_seconds=poll_interval_seconds, + execution_timeout=timedelta(seconds=timeout_seconds), + do_xcom_push=True, # Required to store run_id for cancellation callback + on_failure_callback=_cancel_databricks_run, + **kwargs, + ) + + return operator diff --git a/dagger/dag_creator/airflow/operator_factory.py b/dagger/dag_creator/airflow/operator_factory.py index 2a1654a..dd7344e 100644 --- a/dagger/dag_creator/airflow/operator_factory.py +++ b/dagger/dag_creator/airflow/operator_factory.py @@ -4,6 +4,7 @@ airflow_op_creator, athena_transform_creator, batch_creator, + databricks_dlt_creator, dbt_creator, dummy_creator, python_creator, diff --git a/dagger/pipeline/ios/databricks_io.py b/dagger/pipeline/ios/databricks_io.py index 15be2c1..7c7b4d2 100644 --- a/dagger/pipeline/ios/databricks_io.py +++ b/dagger/pipeline/ios/databricks_io.py @@ -1,12 +1,45 @@ +"""IO representation for Databricks Unity Catalog tables.""" + +from typing import Any + from dagger.pipeline.io import IO from dagger.utilities.config_validator import Attribute class DatabricksIO(IO): - ref_name = "databricks" + """IO representation for Databricks Unity Catalog tables. + + Represents a table in Databricks Unity Catalog with catalog.schema.table naming. + Used to define inputs and outputs for tasks that read from or write to + Databricks tables. + + Attributes: + ref_name: Reference name used by IOFactory to instantiate this IO type. + catalog: Databricks Unity Catalog name. + schema: Schema/database name within the catalog. + table: Table name. + + Example YAML configuration: + type: databricks + name: my_output_table + catalog: prod_catalog + schema: analytics + table: user_metrics + """ + + ref_name: str = "databricks" @classmethod - def init_attributes(cls, orig_cls): + def init_attributes(cls, orig_cls: type) -> None: + """Initialize configuration attributes for YAML parsing. + + Registers all attributes that can be specified in the YAML configuration. + Called by the IO metaclass during class creation. + + Args: + orig_cls: The original class being initialized (used for attribute + registration). 
+ """ cls.add_config_attributes( [ Attribute(attribute_name="catalog"), @@ -15,32 +48,81 @@ def init_attributes(cls, orig_cls): ] ) - def __init__(self, io_config, config_location): + def __init__(self, io_config: dict[str, Any], config_location: str) -> None: + """Initialize a DatabricksIO instance. + + Args: + io_config: Dictionary containing the IO configuration from YAML. + config_location: Path to the configuration file for error reporting. + + Raises: + DaggerMissingFieldException: If required fields (catalog, schema, table) + are missing from the configuration. + """ super().__init__(io_config, config_location) - self._catalog = self.parse_attribute("catalog") - self._schema = self.parse_attribute("schema") - self._table = self.parse_attribute("table") + self._catalog: str = self.parse_attribute("catalog") + self._schema: str = self.parse_attribute("schema") + self._table: str = self.parse_attribute("table") - def alias(self): + def alias(self) -> str: + """Return the unique alias for this IO in databricks:// URI format. + + The alias is used for dataset lineage tracking and dependency resolution + across pipelines. + + Returns: + A unique identifier string in the format + 'databricks://{catalog}/{schema}/{table}'. + """ return f"databricks://{self._catalog}/{self._schema}/{self._table}" @property - def rendered_name(self): + def rendered_name(self) -> str: + """Return the fully qualified table name in dot notation. + + This format is used in SQL queries and Databricks API calls. + + Returns: + The table name in '{catalog}.{schema}.{table}' format. + """ return f"{self._catalog}.{self._schema}.{self._table}" @property - def airflow_name(self): + def airflow_name(self) -> str: + """Return an Airflow-safe identifier for this table. + + Airflow task/dataset IDs cannot contain dots, so this returns a + hyphen-separated format suitable for use in Airflow contexts. + + Returns: + The table name in 'databricks-{catalog}-{schema}-{table}' format. + """ return f"databricks-{self._catalog}-{self._schema}-{self._table}" @property - def catalog(self): + def catalog(self) -> str: + """Return the Databricks Unity Catalog name. + + Returns: + The catalog name. + """ return self._catalog @property - def schema(self): + def schema(self) -> str: + """Return the schema/database name within the catalog. + + Returns: + The schema name. + """ return self._schema @property - def table(self): + def table(self) -> str: + """Return the table name. + + Returns: + The table name. + """ return self._table diff --git a/dagger/pipeline/task_factory.py b/dagger/pipeline/task_factory.py index 9ed79e7..f5f80bb 100644 --- a/dagger/pipeline/task_factory.py +++ b/dagger/pipeline/task_factory.py @@ -3,6 +3,7 @@ airflow_op_task, athena_transform_task, batch_task, + databricks_dlt_task, dbt_task, dummy_task, python_task, @@ -12,7 +13,7 @@ reverse_etl_task, spark_task, sqoop_task, - soda_task + soda_task, ) from dagger.utilities.classes import get_deep_obj_subclasses diff --git a/dagger/pipeline/tasks/databricks_dlt_task.py b/dagger/pipeline/tasks/databricks_dlt_task.py new file mode 100644 index 0000000..4f0b113 --- /dev/null +++ b/dagger/pipeline/tasks/databricks_dlt_task.py @@ -0,0 +1,164 @@ +"""Task configuration for Databricks DLT (Delta Live Tables) pipelines.""" + +from typing import Any, Optional + +from dagger.pipeline.task import Task +from dagger.utilities.config_validator import Attribute + + +class DatabricksDLTTask(Task): + """Task configuration for triggering Databricks DLT pipelines via Jobs. 
+ + This task type uses DatabricksRunNowOperator to trigger a Databricks Job + that wraps the DLT pipeline. The job is identified by name and must be + defined in the Databricks Asset Bundle. + + Attributes: + ref_name: Reference name used by TaskFactory to instantiate this task type. + job_name: Databricks Job name that triggers the DLT pipeline. + databricks_conn_id: Airflow connection ID for Databricks. + wait_for_completion: Whether to wait for job completion. + poll_interval_seconds: Polling interval in seconds. + timeout_seconds: Timeout in seconds. + cancel_on_kill: Whether to cancel Databricks job if Airflow task is killed. + + Example YAML configuration: + type: databricks_dlt + description: Run DLT pipeline users + inputs: + - type: athena + schema: ddb_changelogs + table: order_preference + follow_external_dependency: true + outputs: + - type: databricks + catalog: ${ENV_MARTS} + schema: dlt_users + table: silver_order_preference + task_parameters: + job_name: dlt-users + databricks_conn_id: databricks_default + wait_for_completion: true + poll_interval_seconds: 30 + timeout_seconds: 3600 + """ + + ref_name: str = "databricks_dlt" + + @classmethod + def init_attributes(cls, orig_cls: type) -> None: + """Initialize configuration attributes for YAML parsing. + + Registers all task_parameters attributes that can be specified in the + YAML configuration file. Called by the Task metaclass during class creation. + + Args: + orig_cls: The original class being initialized (used for attribute registration). + """ + cls.add_config_attributes( + [ + Attribute( + attribute_name="job_name", + parent_fields=["task_parameters"], + comment="Databricks Job name that triggers the DLT pipeline", + ), + Attribute( + attribute_name="databricks_conn_id", + parent_fields=["task_parameters"], + required=False, + comment="Airflow connection ID for Databricks (default: databricks_default)", + ), + Attribute( + attribute_name="wait_for_completion", + parent_fields=["task_parameters"], + required=False, + validator=bool, + comment="Wait for job to complete (default: true)", + ), + Attribute( + attribute_name="poll_interval_seconds", + parent_fields=["task_parameters"], + required=False, + validator=int, + comment="Polling interval in seconds (default: 30)", + ), + Attribute( + attribute_name="timeout_seconds", + parent_fields=["task_parameters"], + required=False, + validator=int, + comment="Timeout in seconds (default: 3600)", + ), + Attribute( + attribute_name="cancel_on_kill", + parent_fields=["task_parameters"], + required=False, + validator=bool, + comment="Cancel Databricks job if Airflow task is killed (default: true)", + ), + ] + ) + + def __init__( + self, + name: str, + pipeline_name: str, + pipeline: Any, + job_config: dict[str, Any], + ) -> None: + """Initialize a DatabricksDLTTask instance. + + Args: + name: The task name (used as task_id in Airflow). + pipeline_name: Name of the Dagger pipeline this task belongs to. + pipeline: The parent Pipeline object. + job_config: Dictionary containing the task configuration from YAML. 
+ """ + super().__init__(name, pipeline_name, pipeline, job_config) + + self._job_name: str = self.parse_attribute("job_name") + self._databricks_conn_id: str = ( + self.parse_attribute("databricks_conn_id") or "databricks_default" + ) + wait_for_completion: Optional[bool] = self.parse_attribute("wait_for_completion") + self._wait_for_completion: bool = ( + wait_for_completion if wait_for_completion is not None else True + ) + self._poll_interval_seconds: int = ( + self.parse_attribute("poll_interval_seconds") or 30 + ) + self._timeout_seconds: int = self.parse_attribute("timeout_seconds") or 3600 + cancel_on_kill: Optional[bool] = self.parse_attribute("cancel_on_kill") + self._cancel_on_kill: bool = ( + cancel_on_kill if cancel_on_kill is not None else True + ) + + @property + def job_name(self) -> str: + """Databricks Job name that triggers the DLT pipeline.""" + return self._job_name + + @property + def databricks_conn_id(self) -> str: + """Airflow connection ID for Databricks.""" + return self._databricks_conn_id + + @property + def wait_for_completion(self) -> bool: + """Whether to wait for job completion.""" + return self._wait_for_completion + + @property + def poll_interval_seconds(self) -> int: + """Polling interval in seconds.""" + return self._poll_interval_seconds + + @property + def timeout_seconds(self) -> int: + """Timeout in seconds.""" + return self._timeout_seconds + + @property + def cancel_on_kill(self) -> bool: + """Whether to cancel Databricks job if Airflow task is killed.""" + return self._cancel_on_kill diff --git a/dagger/plugins/__init__.py b/dagger/plugins/__init__.py new file mode 100644 index 0000000..26acb8c --- /dev/null +++ b/dagger/plugins/__init__.py @@ -0,0 +1 @@ +"""Dagger plugins for task generation.""" diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 7f33690..a12c25f 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -51,7 +51,9 @@ def read_task_config(self, task): return content @staticmethod - def load_plugins_to_jinja_environment(environment: jinja2.Environment) -> jinja2.Environment: + def load_plugins_to_jinja_environment( + environment: jinja2.Environment, + ) -> jinja2.Environment: """ Dynamically load all classes(plugins) from the folders defined in the conf.PLUGIN_DIRS variable. The folder contains all plugins that are part of the project. 
@@ -60,12 +62,20 @@ def load_plugins_to_jinja_environment(environment: jinja2.Environment) -> jinja2 """ for plugin_path in conf.PLUGIN_DIRS: for root, dirs, files in os.walk(plugin_path): - dirs[:] = [directory for directory in dirs if not directory.lower().startswith("test")] + dirs[:] = [ + directory + for directory in dirs + if not directory.lower().startswith("test") + ] for plugin_file in files: - if plugin_file.endswith(".py") and not (plugin_file.startswith("__") or plugin_file.startswith("test")): + if plugin_file.endswith(".py") and not ( + plugin_file.startswith("__") or plugin_file.startswith("test") + ): module_name = plugin_file.replace(".py", "") module_path = os.path.join(root, plugin_file) - spec = importlib.util.spec_from_file_location(module_name, module_path) + spec = importlib.util.spec_from_file_location( + module_name, module_path + ) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) @@ -84,8 +94,7 @@ def replace_template_parameters(_task_str, _template_parameters): return ( rendered_task # TODO Remove this hack and use Jinja escaping instead of special expression in template files - .replace("__CBS__", "{") - .replace("__CBE__", "}") + .replace("__CBS__", "{").replace("__CBE__", "}") ) @staticmethod @@ -102,7 +111,7 @@ def generate_task_configs(self): template_parameters = {} template_parameters.update(self._default_parameters or {}) template_parameters.update(attrs) - template_parameters['branch_name'] = branch_name + template_parameters["branch_name"] = branch_name template_parameters.update(self._jinja_parameters) for task, task_yaml in self._tasks.items(): diff --git a/dockers/airflow/Dockerfile b/dockers/airflow/Dockerfile index 2bd40d5..71e73d7 100644 --- a/dockers/airflow/Dockerfile +++ b/dockers/airflow/Dockerfile @@ -52,7 +52,7 @@ RUN curl -Ls "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awsc RUN pip install -U --progress-bar off --no-cache-dir pip setuptools wheel COPY requirements.txt requirements.txt -RUN pip install --progress-bar off --no-cache-dir apache-airflow[amazon,postgres,s3,statsd]==$AIRFLOW_VERSION --constraint $AIRFLOW_CONSTRAINTS && \ +RUN pip install --progress-bar off --no-cache-dir apache-airflow[amazon,databricks,postgres,s3,statsd]==$AIRFLOW_VERSION --constraint $AIRFLOW_CONSTRAINTS && \ pip install --progress-bar off --no-cache-dir -r requirements.txt && \ apt-get purge --auto-remove -yq $BUILD_DEPS && \ apt-get autoremove --purge -yq && \ diff --git a/reqs/dev.txt b/reqs/dev.txt index b39d00a..cb96f9e 100644 --- a/reqs/dev.txt +++ b/reqs/dev.txt @@ -1,5 +1,5 @@ pip==24.0 -apache-airflow[amazon,postgres,s3,statsd]==2.11.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.11.0/constraints-3.12.txt" +apache-airflow[amazon,databricks,postgres,s3,statsd]==2.11.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.11.0/constraints-3.12.txt" black==22.10.0 bumpversion==0.6.0 coverage==7.4.4 diff --git a/reqs/test.txt b/reqs/test.txt index 6bb6c2e..3b97347 100644 --- a/reqs/test.txt +++ b/reqs/test.txt @@ -1,4 +1,4 @@ -apache-airflow[amazon,postgres,s3,statsd]==2.11.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.11.0/constraints-3.12.txt" +apache-airflow[amazon,databricks,postgres,s3,statsd]==2.11.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.11.0/constraints-3.12.txt" pytest-cov==4.0.0 pytest==7.2.0 graphviz diff --git a/tests/dag_creator/airflow/operator_creators/__init__.py 
b/tests/dag_creator/airflow/operator_creators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/dag_creator/airflow/operator_creators/test_databricks_dlt_creator.py b/tests/dag_creator/airflow/operator_creators/test_databricks_dlt_creator.py new file mode 100644 index 0000000..39de91b --- /dev/null +++ b/tests/dag_creator/airflow/operator_creators/test_databricks_dlt_creator.py @@ -0,0 +1,276 @@ +"""Unit tests for DatabricksDLTCreator.""" + +import sys +import unittest +from datetime import timedelta +from unittest.mock import MagicMock, patch + +from dagger.dag_creator.airflow.operator_creators.databricks_dlt_creator import ( + DatabricksDLTCreator, + _cancel_databricks_run, +) + + +class TestDatabricksDLTCreator(unittest.TestCase): + """Test cases for DatabricksDLTCreator.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.mock_task = MagicMock() + self.mock_task.name = "test_dlt_task" + self.mock_task.job_name = "test-dlt-job" + self.mock_task.databricks_conn_id = "databricks_default" + self.mock_task.wait_for_completion = True + self.mock_task.poll_interval_seconds = 30 + self.mock_task.timeout_seconds = 3600 + self.mock_task.cancel_on_kill = True + + self.mock_dag = MagicMock() + + # Set up mock for DatabricksRunNowOperator + self.mock_operator = MagicMock() + self.mock_operator_class = MagicMock(return_value=self.mock_operator) + self.mock_databricks_module = MagicMock() + self.mock_databricks_module.DatabricksRunNowOperator = self.mock_operator_class + + def test_ref_name(self) -> None: + """Test that ref_name is correctly set.""" + self.assertEqual(DatabricksDLTCreator.ref_name, "databricks_dlt") + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator(self) -> None: + """Test operator creation returns an operator instance.""" + mock_operator = MagicMock() + mock_operator_class = MagicMock(return_value=mock_operator) + sys.modules[ + "airflow.providers.databricks.operators.databricks" + ].DatabricksRunNowOperator = mock_operator_class + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + operator = creator._create_operator() + + mock_operator_class.assert_called_once() + self.assertEqual(operator, mock_operator) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_maps_task_properties(self) -> None: + """Test that task properties are correctly mapped to operator.""" + mock_operator_class = MagicMock() + sys.modules[ + "airflow.providers.databricks.operators.databricks" + ].DatabricksRunNowOperator = mock_operator_class + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator._create_operator() + + call_kwargs = mock_operator_class.call_args[1] + + self.assertEqual(call_kwargs["dag"], self.mock_dag) + self.assertEqual(call_kwargs["task_id"], "test_dlt_task") + self.assertEqual(call_kwargs["databricks_conn_id"], "databricks_default") + self.assertEqual(call_kwargs["job_name"], "test-dlt-job") + self.assertEqual(call_kwargs["wait_for_termination"], True) + self.assertEqual(call_kwargs["polling_period_seconds"], 30) + self.assertEqual(call_kwargs["execution_timeout"], timedelta(seconds=3600)) + self.assertTrue(call_kwargs["do_xcom_push"]) + self.assertEqual(call_kwargs["on_failure_callback"], _cancel_databricks_run) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def 
test_create_operator_with_custom_values(self) -> None: + """Test operator creation with non-default values.""" + self.mock_task.databricks_conn_id = "custom_conn" + self.mock_task.wait_for_completion = False + self.mock_task.poll_interval_seconds = 60 + self.mock_task.timeout_seconds = 7200 + + mock_operator_class = MagicMock() + sys.modules[ + "airflow.providers.databricks.operators.databricks" + ].DatabricksRunNowOperator = mock_operator_class + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator._create_operator() + + call_kwargs = mock_operator_class.call_args[1] + + self.assertEqual(call_kwargs["databricks_conn_id"], "custom_conn") + self.assertEqual(call_kwargs["wait_for_termination"], False) + self.assertEqual(call_kwargs["polling_period_seconds"], 60) + self.assertEqual(call_kwargs["execution_timeout"], timedelta(seconds=7200)) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_empty_job_name_raises_error(self) -> None: + """Test that empty job_name raises ValueError.""" + self.mock_task.job_name = "" + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + + with self.assertRaises(ValueError) as context: + creator._create_operator() + + self.assertIn("job_name is required", str(context.exception)) + self.assertIn("test_dlt_task", str(context.exception)) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_none_job_name_raises_error(self) -> None: + """Test that None job_name raises ValueError.""" + self.mock_task.job_name = None + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + + with self.assertRaises(ValueError) as context: + creator._create_operator() + + self.assertIn("job_name is required", str(context.exception)) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_passes_kwargs(self) -> None: + """Test that additional kwargs are passed to operator.""" + mock_operator_class = MagicMock() + sys.modules[ + "airflow.providers.databricks.operators.databricks" + ].DatabricksRunNowOperator = mock_operator_class + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator._create_operator(retries=3, retry_delay=60) + + call_kwargs = mock_operator_class.call_args[1] + + self.assertEqual(call_kwargs["retries"], 3) + self.assertEqual(call_kwargs["retry_delay"], 60) + + +class TestCancelDatabricksRun(unittest.TestCase): + """Test cases for _cancel_databricks_run callback.""" + + def test_cancel_run_no_task_instance(self) -> None: + """Test callback handles missing task instance gracefully.""" + context: dict = {} + + # Should not raise, just log warning + _cancel_databricks_run(context) + + def test_cancel_run_no_run_id(self) -> None: + """Test callback handles missing run_id gracefully.""" + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = None + + context = {"task_instance": mock_ti} + + # Should not raise, just log warning + _cancel_databricks_run(context) + + mock_ti.xcom_pull.assert_called_once_with(task_ids="test_task", key="run_id") + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.hooks.databricks": MagicMock()}, + ) + def test_cancel_run_success(self) -> None: + """Test successful cancellation of Databricks run.""" + mock_hook = MagicMock() + mock_hook_class = MagicMock(return_value=mock_hook) + sys.modules[ + 
"airflow.providers.databricks.hooks.databricks" + ].DatabricksHook = mock_hook_class + + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = "run_12345" + mock_ti.task.databricks_conn_id = "databricks_default" + + context = {"task_instance": mock_ti} + + _cancel_databricks_run(context) + + mock_hook_class.assert_called_once_with(databricks_conn_id="databricks_default") + mock_hook.cancel_run.assert_called_once_with("run_12345") + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.hooks.databricks": MagicMock()}, + ) + def test_cancel_run_handles_exception(self) -> None: + """Test callback handles cancellation errors gracefully.""" + mock_hook = MagicMock() + mock_hook.cancel_run.side_effect = Exception("API Error") + mock_hook_class = MagicMock(return_value=mock_hook) + sys.modules[ + "airflow.providers.databricks.hooks.databricks" + ].DatabricksHook = mock_hook_class + + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = "run_12345" + mock_ti.task.databricks_conn_id = "databricks_default" + + context = {"task_instance": mock_ti} + + # Should not raise, just log error + _cancel_databricks_run(context) + + mock_hook.cancel_run.assert_called_once_with("run_12345") + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.hooks.databricks": MagicMock()}, + ) + def test_cancel_run_with_custom_conn_id(self) -> None: + """Test cancellation uses correct connection ID.""" + mock_hook = MagicMock() + mock_hook_class = MagicMock(return_value=mock_hook) + sys.modules[ + "airflow.providers.databricks.hooks.databricks" + ].DatabricksHook = mock_hook_class + + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = "run_67890" + mock_ti.task.databricks_conn_id = "custom_databricks_conn" + + context = {"task_instance": mock_ti} + + _cancel_databricks_run(context) + + mock_hook_class.assert_called_once_with( + databricks_conn_id="custom_databricks_conn" + ) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.hooks.databricks": None}, + ) + def test_cancel_run_handles_import_error(self) -> None: + """Test callback handles missing databricks provider gracefully.""" + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = "run_12345" + mock_ti.task.databricks_conn_id = "databricks_default" + + context = {"task_instance": mock_ti} + + # Should not raise, just log error + _cancel_databricks_run(context) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml b/tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml new file mode 100644 index 0000000..1902cf0 --- /dev/null +++ b/tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml @@ -0,0 +1,22 @@ +type: databricks_dlt +description: Test DLT pipeline task +inputs: + - type: athena + name: input_table + schema: test_schema + table: input_table +outputs: + - type: databricks + name: output_table + catalog: test_catalog + schema: test_schema + table: output_table +airflow_task_parameters: +template_parameters: +task_parameters: + job_name: test-dlt-job + databricks_conn_id: databricks_test + wait_for_completion: true + poll_interval_seconds: 60 + timeout_seconds: 7200 + cancel_on_kill: true diff --git a/tests/pipeline/ios/test_databricks_io.py b/tests/pipeline/ios/test_databricks_io.py index b1d0c45..e4a9456 100644 --- a/tests/pipeline/ios/test_databricks_io.py +++ b/tests/pipeline/ios/test_databricks_io.py @@ -1,17 +1,225 @@ 
+"""Unit tests for DatabricksIO.""" + import unittest -from dagger.pipeline.io_factory import databricks_io import yaml +from dagger.pipeline.ios import databricks_io +from dagger.utilities.exceptions import DaggerMissingFieldException + + +class TestDatabricksIO(unittest.TestCase): + """Test cases for DatabricksIO.""" -class DbIOTest(unittest.TestCase): def setUp(self) -> None: - with open('tests/fixtures/pipeline/ios/databricks_io.yaml', "r") as stream: + """Set up test fixtures.""" + with open("tests/fixtures/pipeline/ios/databricks_io.yaml", "r") as stream: config = yaml.safe_load(stream) self.db_io = databricks_io.DatabricksIO(config, "/") - def test_properties(self): - self.assertEqual(self.db_io.alias(), "databricks://test_catalog/test_schema/test_table") - self.assertEqual(self.db_io.rendered_name, "test_catalog.test_schema.test_table") - self.assertEqual(self.db_io.airflow_name, "databricks-test_catalog-test_schema-test_table") + def test_ref_name(self) -> None: + """Test that ref_name is correctly set.""" + self.assertEqual(databricks_io.DatabricksIO.ref_name, "databricks") + + def test_catalog(self) -> None: + """Test catalog property.""" + self.assertEqual(self.db_io.catalog, "test_catalog") + + def test_schema(self) -> None: + """Test schema property.""" + self.assertEqual(self.db_io.schema, "test_schema") + + def test_table(self) -> None: + """Test table property.""" + self.assertEqual(self.db_io.table, "test_table") + + def test_alias(self) -> None: + """Test alias method returns databricks:// URI format.""" + self.assertEqual( + self.db_io.alias(), "databricks://test_catalog/test_schema/test_table" + ) + + def test_rendered_name(self) -> None: + """Test rendered_name returns dot-separated format.""" + self.assertEqual( + self.db_io.rendered_name, "test_catalog.test_schema.test_table" + ) + + def test_airflow_name(self) -> None: + """Test airflow_name returns hyphen-separated format.""" + self.assertEqual( + self.db_io.airflow_name, "databricks-test_catalog-test_schema-test_table" + ) + + def test_name(self) -> None: + """Test name property from base IO class.""" + self.assertEqual(self.db_io.name, "test") + + def test_has_dependency_default(self) -> None: + """Test that has_dependency defaults to True.""" + self.assertTrue(self.db_io.has_dependency) + + +class TestDatabricksIOInlineConfig(unittest.TestCase): + """Test cases for DatabricksIO with inline configuration.""" + + def test_with_minimal_config(self) -> None: + """Test DatabricksIO with minimal required configuration.""" + config = { + "type": "databricks", + "name": "minimal_table", + "catalog": "my_catalog", + "schema": "my_schema", + "table": "my_table", + } + + db_io = databricks_io.DatabricksIO(config, "/test/path") + + self.assertEqual(db_io.catalog, "my_catalog") + self.assertEqual(db_io.schema, "my_schema") + self.assertEqual(db_io.table, "my_table") + self.assertEqual(db_io.name, "minimal_table") + + def test_alias_format_with_special_characters(self) -> None: + """Test alias format with underscores and numbers.""" + config = { + "type": "databricks", + "name": "output_123", + "catalog": "prod_catalog_v2", + "schema": "analytics_schema", + "table": "user_events_2024", + } + + db_io = databricks_io.DatabricksIO(config, "/") + + self.assertEqual( + db_io.alias(), + "databricks://prod_catalog_v2/analytics_schema/user_events_2024", + ) + self.assertEqual( + db_io.rendered_name, "prod_catalog_v2.analytics_schema.user_events_2024" + ) + self.assertEqual( + db_io.airflow_name, + 
"databricks-prod_catalog_v2-analytics_schema-user_events_2024", + ) + + def test_has_dependency_false(self) -> None: + """Test that has_dependency can be set to False.""" + config = { + "type": "databricks", + "name": "no_dep_table", + "catalog": "cat", + "schema": "sch", + "table": "tbl", + "has_dependency": False, + } + + db_io = databricks_io.DatabricksIO(config, "/") + + self.assertFalse(db_io.has_dependency) + + +class TestDatabricksIOMissingFields(unittest.TestCase): + """Test cases for DatabricksIO error handling.""" + + def test_missing_catalog_raises_exception(self) -> None: + """Test that missing catalog raises DaggerMissingFieldException.""" + config = { + "type": "databricks", + "name": "test_table", + "schema": "test_schema", + "table": "test_table", + } + + with self.assertRaises(DaggerMissingFieldException): + databricks_io.DatabricksIO(config, "/test/config.yaml") + + def test_missing_schema_raises_exception(self) -> None: + """Test that missing schema raises DaggerMissingFieldException.""" + config = { + "type": "databricks", + "name": "test_table", + "catalog": "test_catalog", + "table": "test_table", + } + + with self.assertRaises(DaggerMissingFieldException): + databricks_io.DatabricksIO(config, "/test/config.yaml") + + def test_missing_table_raises_exception(self) -> None: + """Test that missing table raises DaggerMissingFieldException.""" + config = { + "type": "databricks", + "name": "test_table", + "catalog": "test_catalog", + "schema": "test_schema", + } + + with self.assertRaises(DaggerMissingFieldException): + databricks_io.DatabricksIO(config, "/test/config.yaml") + + def test_missing_name_raises_exception(self) -> None: + """Test that missing name raises DaggerMissingFieldException.""" + config = { + "type": "databricks", + "catalog": "test_catalog", + "schema": "test_schema", + "table": "test_table", + } + + with self.assertRaises(DaggerMissingFieldException): + databricks_io.DatabricksIO(config, "/test/config.yaml") + + +class TestDatabricksIOEquality(unittest.TestCase): + """Test cases for DatabricksIO equality comparison.""" + + def test_equal_ios_are_equal(self) -> None: + """Test that two IOs with same alias are equal.""" + config1 = { + "type": "databricks", + "name": "table1", + "catalog": "cat", + "schema": "sch", + "table": "tbl", + } + config2 = { + "type": "databricks", + "name": "table2", # Different name, same catalog.schema.table + "catalog": "cat", + "schema": "sch", + "table": "tbl", + } + + io1 = databricks_io.DatabricksIO(config1, "/") + io2 = databricks_io.DatabricksIO(config2, "/") + + self.assertEqual(io1, io2) + + def test_different_ios_are_not_equal(self) -> None: + """Test that two IOs with different aliases are not equal.""" + config1 = { + "type": "databricks", + "name": "table1", + "catalog": "cat1", + "schema": "sch", + "table": "tbl", + } + config2 = { + "type": "databricks", + "name": "table2", + "catalog": "cat2", + "schema": "sch", + "table": "tbl", + } + + io1 = databricks_io.DatabricksIO(config1, "/") + io2 = databricks_io.DatabricksIO(config2, "/") + + self.assertNotEqual(io1, io2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipeline/tasks/test_databricks_dlt_task.py b/tests/pipeline/tasks/test_databricks_dlt_task.py new file mode 100644 index 0000000..a222148 --- /dev/null +++ b/tests/pipeline/tasks/test_databricks_dlt_task.py @@ -0,0 +1,176 @@ +"""Unit tests for DatabricksDLTTask.""" + +import unittest +from unittest.mock import MagicMock + +import yaml + +from 
dagger.pipeline.tasks.databricks_dlt_task import DatabricksDLTTask + + +class TestDatabricksDLTTask(unittest.TestCase): + """Test cases for DatabricksDLTTask.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + with open( + "tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml", "r" + ) as stream: + self.config = yaml.safe_load(stream) + + # Create a mock pipeline object + self.mock_pipeline = MagicMock() + self.mock_pipeline.directory = "tests/fixtures/pipeline/tasks" + + self.task = DatabricksDLTTask( + name="test_dlt_task", + pipeline_name="test_pipeline", + pipeline=self.mock_pipeline, + job_config=self.config, + ) + + def test_ref_name(self) -> None: + """Test that ref_name is correctly set.""" + self.assertEqual(DatabricksDLTTask.ref_name, "databricks_dlt") + + def test_job_name(self) -> None: + """Test job_name property.""" + self.assertEqual(self.task.job_name, "test-dlt-job") + + def test_databricks_conn_id(self) -> None: + """Test databricks_conn_id property.""" + self.assertEqual(self.task.databricks_conn_id, "databricks_test") + + def test_wait_for_completion(self) -> None: + """Test wait_for_completion property.""" + self.assertTrue(self.task.wait_for_completion) + + def test_poll_interval_seconds(self) -> None: + """Test poll_interval_seconds property.""" + self.assertEqual(self.task.poll_interval_seconds, 60) + + def test_timeout_seconds(self) -> None: + """Test timeout_seconds property.""" + self.assertEqual(self.task.timeout_seconds, 7200) + + def test_cancel_on_kill(self) -> None: + """Test cancel_on_kill property.""" + self.assertTrue(self.task.cancel_on_kill) + + def test_task_name(self) -> None: + """Test that task name is correctly set.""" + self.assertEqual(self.task.name, "test_dlt_task") + + def test_pipeline_name(self) -> None: + """Test that pipeline_name is correctly set.""" + self.assertEqual(self.task.pipeline_name, "test_pipeline") + + +class TestDatabricksDLTTaskDefaults(unittest.TestCase): + """Test cases for DatabricksDLTTask default values.""" + + def setUp(self) -> None: + """Set up test fixtures with minimal config.""" + self.config = { + "type": "databricks_dlt", + "description": "Test DLT task with defaults", + "inputs": [], + "outputs": [], + "airflow_task_parameters": None, + "template_parameters": None, + "task_parameters": { + "job_name": "minimal-dlt-job", + }, + } + + self.mock_pipeline = MagicMock() + self.mock_pipeline.directory = "tests/fixtures/pipeline/tasks" + + self.task = DatabricksDLTTask( + name="minimal_dlt_task", + pipeline_name="test_pipeline", + pipeline=self.mock_pipeline, + job_config=self.config, + ) + + def test_default_databricks_conn_id(self) -> None: + """Test default databricks_conn_id value.""" + self.assertEqual(self.task.databricks_conn_id, "databricks_default") + + def test_default_wait_for_completion(self) -> None: + """Test default wait_for_completion value.""" + self.assertTrue(self.task.wait_for_completion) + + def test_default_poll_interval_seconds(self) -> None: + """Test default poll_interval_seconds value.""" + self.assertEqual(self.task.poll_interval_seconds, 30) + + def test_default_timeout_seconds(self) -> None: + """Test default timeout_seconds value.""" + self.assertEqual(self.task.timeout_seconds, 3600) + + def test_default_cancel_on_kill(self) -> None: + """Test default cancel_on_kill value.""" + self.assertTrue(self.task.cancel_on_kill) + + +class TestDatabricksDLTTaskBooleanHandling(unittest.TestCase): + """Test cases for boolean parameter handling edge cases.""" + + def 
test_wait_for_completion_false(self) -> None: + """Test that wait_for_completion=false is correctly handled.""" + config = { + "type": "databricks_dlt", + "description": "Test", + "inputs": [], + "outputs": [], + "airflow_task_parameters": None, + "template_parameters": None, + "task_parameters": { + "job_name": "test-job", + "wait_for_completion": False, + }, + } + + mock_pipeline = MagicMock() + mock_pipeline.directory = "tests/fixtures/pipeline/tasks" + + task = DatabricksDLTTask( + name="test_task", + pipeline_name="test_pipeline", + pipeline=mock_pipeline, + job_config=config, + ) + + self.assertFalse(task.wait_for_completion) + + def test_cancel_on_kill_false(self) -> None: + """Test that cancel_on_kill=false is correctly handled.""" + config = { + "type": "databricks_dlt", + "description": "Test", + "inputs": [], + "outputs": [], + "airflow_task_parameters": None, + "template_parameters": None, + "task_parameters": { + "job_name": "test-job", + "cancel_on_kill": False, + }, + } + + mock_pipeline = MagicMock() + mock_pipeline.directory = "tests/fixtures/pipeline/tasks" + + task = DatabricksDLTTask( + name="test_task", + pipeline_name="test_pipeline", + pipeline=mock_pipeline, + job_config=config, + ) + + self.assertFalse(task.cancel_on_kill) + + +if __name__ == "__main__": + unittest.main()