Skip to content

Commit 22da959

Browse files
Add bundle name arg to list dags cli command (#45779)
Co-authored-by: Jed Cunningham <[email protected]>
1 parent 28c93d9 commit 22da959

File tree

5 files changed

+81
-17
lines changed

5 files changed

+81
-17
lines changed

airflow/api_connexion/schemas/dag_schema.py

+2
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,8 @@ class Meta:
5151

5252
dag_id = auto_field(dump_only=True)
5353
dag_display_name = fields.String(attribute="dag_display_name", dump_only=True)
54+
bundle_name = auto_field(dump_only=True)
55+
bundle_version = auto_field(dump_only=True)
5456
is_paused = auto_field()
5557
is_active = auto_field(dump_only=True)
5658
last_parsed_time = auto_field(dump_only=True)

airflow/cli/cli_config.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -172,6 +172,7 @@ def string_lower_type(val):
172172
"--bundle-name",
173173
),
174174
help=("The name of the DAG bundle to use; may be provided more than once"),
175+
type=str,
175176
default=None,
176177
action="append",
177178
)
@@ -880,7 +881,7 @@ def string_lower_type(val):
880881
("--columns",),
881882
type=string_list_type,
882883
help="List of columns to render. (default: ['dag_id', 'fileloc', 'owner', 'is_paused'])",
883-
default=("dag_id", "fileloc", "owners", "is_paused"),
884+
default=("dag_id", "fileloc", "owners", "is_paused", "bundle_name", "bundle_version"),
884885
)
885886

886887
ARG_ASSET_LIST_COLUMNS = Arg(
@@ -978,7 +979,7 @@ class GroupCommand(NamedTuple):
978979
name="list",
979980
help="List all the DAGs",
980981
func=lazy_load_command("airflow.cli.commands.remote_commands.dag_command.dag_list_dags"),
981-
args=(ARG_SUBDIR, ARG_OUTPUT, ARG_VERBOSE, ARG_DAG_LIST_COLUMNS),
982+
args=(ARG_OUTPUT, ARG_VERBOSE, ARG_DAG_LIST_COLUMNS, ARG_BUNDLE_NAME),
982983
),
983984
ActionCommand(
984985
name="list-import-errors",

airflow/cli/commands/remote_commands/dag_command.py

+32-6
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
1718
"""Dag sub-commands."""
1819

1920
from __future__ import annotations
@@ -28,7 +29,7 @@
2829
from typing import TYPE_CHECKING
2930

3031
import re2
31-
from sqlalchemy import select
32+
from sqlalchemy import func, select
3233

3334
from airflow.api.client import get_current_api_client
3435
from airflow.api_connexion.schemas.dag_schema import dag_schema
@@ -38,6 +39,7 @@
3839
from airflow.exceptions import AirflowException
3940
from airflow.jobs.job import Job
4041
from airflow.models import DagBag, DagModel, DagRun, TaskInstance
42+
from airflow.models.errors import ParseImportError
4143
from airflow.models.serialized_dag import SerializedDagModel
4244
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
4345
from airflow.utils import cli as cli_utils, timezone
@@ -224,6 +226,8 @@ def _get_dagbag_dag_details(dag: DAG) -> dict:
224226
return {
225227
"dag_id": dag.dag_id,
226228
"dag_display_name": dag.dag_display_name,
229+
"bundle_name": dag.get_bundle_name(),
230+
"bundle_version": dag.get_bundle_version(),
227231
"is_paused": dag.get_is_paused(),
228232
"is_active": dag.get_is_active(),
229233
"last_parsed_time": None,
@@ -322,11 +326,12 @@ def print_execution_interval(interval: DataInterval | None):
322326
@suppress_logs_and_warning
323327
@providers_configuration_loaded
324328
@provide_session
325-
def dag_list_dags(args, session=NEW_SESSION) -> None:
329+
def dag_list_dags(args, session: Session = NEW_SESSION) -> None:
326330
"""Display dags with or without stats at the command line."""
327331
cols = args.columns if args.columns else []
328332
invalid_cols = [c for c in cols if c not in dag_schema.fields]
329333
valid_cols = [c for c in cols if c in dag_schema.fields]
334+
330335
if invalid_cols:
331336
from rich import print as rich_print
332337

@@ -335,8 +340,18 @@ def dag_list_dags(args, session=NEW_SESSION) -> None:
335340
f"List of valid columns: {list(dag_schema.fields.keys())}",
336341
file=sys.stderr,
337342
)
338-
dagbag = DagBag(process_subdir(args.subdir))
339-
if dagbag.import_errors:
343+
344+
dagbag = DagBag(read_dags_from_db=True)
345+
dagbag.collect_dags_from_db()
346+
347+
# Get import errors from the DB
348+
query = select(func.count()).select_from(ParseImportError)
349+
if args.bundle_name:
350+
query = query.where(ParseImportError.bundle_name.in_(args.bundle_name))
351+
352+
dagbag_import_errors = session.scalar(query)
353+
354+
if dagbag_import_errors > 0:
340355
from rich import print as rich_print
341356

342357
rich_print(
@@ -353,8 +368,19 @@ def get_dag_detail(dag: DAG) -> dict:
353368
dag_detail = _get_dagbag_dag_details(dag)
354369
return {col: dag_detail[col] for col in valid_cols}
355370

371+
def filter_dags_by_bundle(dags: list[DAG], bundle_names: list[str] | None) -> list[DAG]:
    """
    Restrict ``dags`` to those that belong to one of the requested bundles.

    When ``bundle_names`` is ``None`` or empty, ``dags`` is returned unchanged
    and no validation takes place. Otherwise the names are first handed to
    ``validate_dag_bundle_arg`` and then only the DAGs whose
    ``get_bundle_name()`` is one of them are kept, in their original order.
    """
    if bundle_names:
        # Validate before filtering so the check runs ahead of any matching.
        validate_dag_bundle_arg(bundle_names)
        return [candidate for candidate in dags if candidate.get_bundle_name() in bundle_names]

    return dags
378+
356379
AirflowConsole().print_as(
357-
data=sorted(dagbag.dags.values(), key=operator.attrgetter("dag_id")),
380+
data=sorted(
381+
filter_dags_by_bundle(list(dagbag.dags.values()), args.bundle_name),
382+
key=operator.attrgetter("dag_id"),
383+
),
358384
output=args.output,
359385
mapper=get_dag_detail,
360386
)
@@ -364,7 +390,7 @@ def get_dag_detail(dag: DAG) -> dict:
364390
@suppress_logs_and_warning
365391
@providers_configuration_loaded
366392
@provide_session
367-
def dag_details(args, session=NEW_SESSION):
393+
def dag_details(args, session: Session = NEW_SESSION):
368394
"""Get DAG details given a DAG id."""
369395
dag = DagModel.get_dagmodel(args.dag_id, session=session)
370396
if not dag:

tests/cli/commands/remote_commands/test_dag_command.py

+25-9
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,12 @@
5050

5151
from tests.models import TEST_DAGS_FOLDER
5252
from tests_common.test_utils.config import conf_vars
53-
from tests_common.test_utils.db import clear_db_dags, clear_db_runs, parse_and_sync_to_db
53+
from tests_common.test_utils.db import (
54+
clear_db_dags,
55+
clear_db_import_errors,
56+
clear_db_runs,
57+
parse_and_sync_to_db,
58+
)
5459

5560
DEFAULT_DATE = timezone.make_aware(datetime(2015, 1, 1), timezone=timezone.utc)
5661
if pendulum.__version__.startswith("3"):
@@ -77,7 +82,11 @@ def teardown_class(cls) -> None:
7782
clear_db_dags()
7883

7984
def setup_method(self):
80-
clear_db_runs() # clean-up all dag run before start each test
85+
clear_db_runs()
86+
clear_db_import_errors()
87+
88+
def teardown_method(self):
89+
clear_db_import_errors()
8190

8291
def test_show_dag_dependencies_print(self):
8392
with contextlib.redirect_stdout(StringIO()) as temp_stdout:
@@ -274,12 +283,17 @@ def test_cli_list_dags_invalid_cols(self):
274283
assert "Ignoring the following invalid columns: ['invalid_col']" in out
275284

276285
@conf_vars({("core", "load_examples"): "false"})
277-
def test_cli_list_dags_prints_import_errors(self):
278-
dag_path = os.path.join(TEST_DAGS_FOLDER, "test_invalid_cron.py")
279-
args = self.parser.parse_args(["dags", "list", "--output", "yaml", "--subdir", dag_path])
280-
with contextlib.redirect_stderr(StringIO()) as temp_stderr:
281-
dag_command.dag_list_dags(args)
282-
out = temp_stderr.getvalue()
286+
def test_cli_list_dags_prints_import_errors(self, configure_testing_dag_bundle, get_test_dag):
287+
path_to_parse = TEST_DAGS_FOLDER / "test_invalid_cron.py"
288+
get_test_dag("test_invalid_cron")
289+
290+
args = self.parser.parse_args(["dags", "list", "--output", "yaml", "--bundle-name", "testing"])
291+
292+
with configure_testing_dag_bundle(path_to_parse):
293+
with contextlib.redirect_stderr(StringIO()) as temp_stderr:
294+
dag_command.dag_list_dags(args)
295+
out = temp_stderr.getvalue()
296+
283297
assert "Failed to load all files." in out
284298

285299
@conf_vars({("core", "load_examples"): "true"})
@@ -305,7 +319,9 @@ def test_dagbag_dag_col(self):
305319
@conf_vars({("core", "load_examples"): "false"})
306320
def test_cli_list_import_errors(self):
307321
dag_path = os.path.join(TEST_DAGS_FOLDER, "test_invalid_cron.py")
308-
args = self.parser.parse_args(["dags", "list", "--output", "yaml", "--subdir", dag_path])
322+
args = self.parser.parse_args(
323+
["dags", "list-import-errors", "--output", "yaml", "--subdir", dag_path]
324+
)
309325
with contextlib.redirect_stdout(StringIO()) as temp_stdout:
310326
with pytest.raises(SystemExit) as err_ctx:
311327
dag_command.dag_list_import_errors(args)

tests_common/pytest_plugin.py

+19
Original file line number | Diff line number | Diff line change
@@ -1423,6 +1423,25 @@ def _get(dag_id: str):
14231423
dagbag = DagBag(dag_folder=dag_file, include_examples=False)
14241424

14251425
dag = dagbag.get_dag(dag_id)
1426+
1427+
if dagbag.import_errors:
1428+
session = settings.Session()
1429+
from airflow.models.errors import ParseImportError
1430+
from airflow.utils import timezone
1431+
1432+
# Add the new import errors
1433+
for _filename, stacktrace in dagbag.import_errors.items():
1434+
session.add(
1435+
ParseImportError(
1436+
filename=str(dag_file),
1437+
bundle_name="testing",
1438+
timestamp=timezone.utcnow(),
1439+
stacktrace=stacktrace,
1440+
)
1441+
)
1442+
1443+
return
1444+
14261445
if AIRFLOW_V_3_0_PLUS:
14271446
session = settings.Session()
14281447
from airflow.models.dagbundle import DagBundleModel

0 commit comments

Comments
 (0)