diff --git a/metaflow/decorators.py b/metaflow/decorators.py index 760508497f0..47ea16223c1 100644 --- a/metaflow/decorators.py +++ b/metaflow/decorators.py @@ -264,13 +264,31 @@ def get_top_level_options(self): # compare this to parameters.add_custom_parameters def add_decorator_options(cmd): - flow_cls = getattr(current_flow, "flow_cls", None) - if flow_cls is None: - return cmd + """ + Lazily adds flow decorator options to a Click command. + + Defers option registration until get_params() is called, ensuring + current_flow.flow_cls is set (which happens after module imports complete). + """ + _original_get_params = cmd.get_params + _options_added = [False] # Use list for mutable closure variable + + def _lazy_get_params(ctx): + if not _options_added[0]: + _options_added[0] = True + flow_cls = getattr(current_flow, "flow_cls", None) + if flow_cls is not None: + _add_flow_decorator_options(cmd, flow_cls) + return _original_get_params(ctx) + + cmd.get_params = _lazy_get_params + return cmd + +def _add_flow_decorator_options(cmd, flow_cls): + """Helper to add decorator options to a command.""" seen = {} existing_params = set(p.name.lower() for p in cmd.params) - # Add decorator options for deco in flow_decorators(flow_cls): for option, kwargs in deco.options.items(): if option in seen: @@ -290,7 +308,6 @@ def add_decorator_options(cmd): kwargs["envvar"] = "METAFLOW_FLOW_%s" % option.upper() seen[option] = deco.name cmd.params.insert(0, click.Option(("--" + option,), **kwargs)) - return cmd def flow_decorators(flow_cls): diff --git a/test/unit/test_decorator_cli_options.py b/test/unit/test_decorator_cli_options.py new file mode 100644 index 00000000000..e4af3b726e4 --- /dev/null +++ b/test/unit/test_decorator_cli_options.py @@ -0,0 +1,32 @@ +def test_project_decorator_options_with_early_cli_import(): + """ + Test that @project decorator options appear when metaflow.cli is imported early. + """ + from metaflow.cli import echo_always # noqa: F401 + from metaflow import FlowSpec, project, step, decorators + from metaflow._vendor import click + from metaflow.parameters import flow_context + + @project(name="test_project") + class TestFlow(FlowSpec): + @step + def start(self): + self.next(self.end) + + @step + def end(self): + pass + + @click.command() + def mock_cmd(): + pass + + cmd = decorators.add_decorator_options(mock_cmd) + + with flow_context(TestFlow): + ctx = click.Context(cmd) + params = cmd.get_params(ctx) + param_names = [p.name for p in params] + + assert "branch" in param_names + assert "production" in param_names