diff --git a/amlb/benchmarks/file.py b/amlb/benchmarks/file.py index 034af74cc..3ff80e874 100644 --- a/amlb/benchmarks/file.py +++ b/amlb/benchmarks/file.py @@ -33,7 +33,7 @@ def load_file_benchmark( """Loads benchmark from a local file.""" benchmark_file = _find_local_benchmark_definition(name, benchmark_definition_dirs) log.info("Loading benchmark definitions from %s.", benchmark_file) - tasks = config_load(benchmark_file) + tasks = config_load(benchmark_file, strict=True) benchmark_name, _ = os.path.splitext(os.path.basename(benchmark_file)) for task in tasks: if task["openml_task_id"] is not None and not isinstance( diff --git a/amlb/frameworks/definitions.py b/amlb/frameworks/definitions.py index 2c2c68bc7..f30d4ae2a 100644 --- a/amlb/frameworks/definitions.py +++ b/amlb/frameworks/definitions.py @@ -44,7 +44,13 @@ def _load_and_merge_framework_definitions( definitions_by_tag = Namespace() for tag in [default_tag] + config.frameworks.tags: definitions_by_file = [ - config_load(_definition_file(file, tag)) for file in frameworks_file + config_load( + _definition_file(file, tag), + strict=( + tag == default_tag + ), # Only strict for base files, not tagged variants + ) + for file in frameworks_file ] if not config.frameworks.allow_duplicates: for d1, d2 in itertools.combinations( diff --git a/amlb/resources.py b/amlb/resources.py index 8a37977c4..05ce506ad 100644 --- a/amlb/resources.py +++ b/amlb/resources.py @@ -194,8 +194,10 @@ def _constraints(self): constraints_file = [constraints_file] constraints = Namespace() - for ef in constraints_file: - constraints += config_load(ef) + for i, ef in enumerate(constraints_file): + # First constraint file should exist (usually the default one) + # Additional files (e.g., user overrides) are optional + constraints += config_load(ef, strict=(i == 0)) for name, c in constraints: c.name = str_sanitize(name) diff --git a/amlb/utils/config.py b/amlb/utils/config.py index a95ac1b51..b1c97d824 100644 --- a/amlb/utils/config.py 
+++ b/amlb/utils/config.py @@ -47,9 +47,22 @@ def yaml_load(*_, **__): # type: ignore[misc] ) -def config_load(path, verbose=False): +def config_load(path, verbose=False, strict=False): + """Load a configuration file. + + :param path: Path to the configuration file. + :param verbose: If True, log at INFO level instead of DEBUG when loading. + :param strict: If True, raise an error when the file doesn't exist. + If False, log a warning/debug message and return empty Namespace. + :return: Namespace containing the configuration. + """ path = normalize_path(path) if not os.path.isfile(path): + if strict: + raise FileNotFoundError( + f"Configuration file not found: `{path}`. " + f"Please check the path for typos or verify the file exists." + ) log.log( logging.WARNING if verbose else logging.DEBUG, "No config file at `%s`, ignoring it.", diff --git a/amlb/utils/os.py b/amlb/utils/os.py index 131b559c0..36fd680cf 100644 --- a/amlb/utils/os.py +++ b/amlb/utils/os.py @@ -56,10 +56,11 @@ def dir_of(caller_file, rel_to_project_root=False): return abs_path -def list_all_files(paths, filter_=None): +def list_all_files(paths, filter_=None, strict=False): """ :param paths: the directories to look into. :param filter_: None, or a predicate function returning True iff the file should be listed. + :param strict: If True, raise an error when a path doesn't exist. """ filter_ = filter_ or (lambda _: True) all_files = [] @@ -76,6 +77,11 @@ def list_all_files(paths, filter_=None): if filter_(path): all_files.append(path) else: + if strict: + raise FileNotFoundError( + f"Path not found: `{path}`. " + f"Please check the path for typos or verify it exists." 
+ ) log.warning("Skipping file `%s` as it doesn't exist.", path) return all_files diff --git a/recover_results.py b/recover_results.py index 62ca00e45..35dfe7e0b 100644 --- a/recover_results.py +++ b/recover_results.py @@ -29,7 +29,7 @@ amlb.logger.setup(root_level="DEBUG", console_level="INFO") root_dir = os.path.dirname(__file__) -config = config_load(os.path.join(root_dir, "resources", "config.yaml")) +config = config_load(os.path.join(root_dir, "resources", "config.yaml"), strict=True) config_args = ns.parse( root_dir=root_dir, script=os.path.basename(__file__), diff --git a/runbenchmark.py b/runbenchmark.py index 431e865ef..3adef3085 100644 --- a/runbenchmark.py +++ b/runbenchmark.py @@ -277,14 +277,16 @@ log.debug("Script args: %s.", args) config_default = config_load( - os.path.join(default_dirs.root_dir, "resources", "config.yaml") + os.path.join(default_dirs.root_dir, "resources", "config.yaml"), + strict=True, # Default config must exist ) config_default_dirs = default_dirs # allowing config override from user_dir: useful to define custom benchmarks and frameworks for example. 
config_user = config_load( extras.get( "config", os.path.join(args.userdir or default_dirs.user_dir, "config.yaml") - ) + ), + strict=False, # User config is optional ) # config listing properties set by command line config_args = ns.parse( diff --git a/tests/conftest.py b/tests/conftest.py index 3f93aea90..211baa78b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,8 @@ @pytest.fixture def load_default_resources(tmp_path): config_default = config_load( - os.path.join(default_dirs.root_dir, "resources", "config.yaml") + os.path.join(default_dirs.root_dir, "resources", "config.yaml"), + strict=True, # Default config must exist ) config_default_dirs = default_dirs config_test = Namespace( diff --git a/tests/unit/amlb/benchmarks/test_strict_loading.py b/tests/unit/amlb/benchmarks/test_strict_loading.py new file mode 100644 index 000000000..cfcd8f597 --- /dev/null +++ b/tests/unit/amlb/benchmarks/test_strict_loading.py @@ -0,0 +1,23 @@ +import pytest +from amlb.benchmarks.file import load_file_benchmark + + +def test_missing_benchmark_file_raises_error(): + """Missing benchmark file should raise ValueError while it is being located, before the strict load is ever attempted.""" + with pytest.raises(ValueError) as exc_info: + load_file_benchmark("nonexistent_benchmark", ["/tmp/nonexistent_dir"]) + + error_message = str(exc_info.value) + assert "incorrect benchmark" in error_message.lower() + + +def test_benchmark_file_with_typo_in_path_raises_error(): + """If a full path to benchmark is provided but has a typo, it should raise ValueError.""" + with pytest.raises(ValueError) as exc_info: + load_file_benchmark( + "/tmp/nonexistent_benchmark_with_typo.yaml", + ["/tmp"], # directory exists, but file doesn't + ) + + error_message = str(exc_info.value) + assert "incorrect benchmark" in error_message.lower() diff --git a/tests/unit/amlb/frameworks/definitions/test_strict_loading.py b/tests/unit/amlb/frameworks/definitions/test_strict_loading.py new file mode 100644 index 
000000000..1a03e27a3 --- /dev/null +++ b/tests/unit/amlb/frameworks/definitions/test_strict_loading.py @@ -0,0 +1,57 @@ +import os +import pytest +import tempfile +from amlb.frameworks.definitions import load_framework_definitions + +here = os.path.realpath(os.path.dirname(__file__)) +res = os.path.join(here, "resources") + + +@pytest.mark.use_disk +def test_missing_framework_file_raises_error(simple_resource): + """Missing base framework file should raise FileNotFoundError.""" + with pytest.raises(FileNotFoundError) as exc_info: + load_framework_definitions( + "/tmp/nonexistent_frameworks.yaml", simple_resource.config + ) + + error_message = str(exc_info.value) + assert "not found" in error_message.lower() + assert "typo" in error_message.lower() + + +@pytest.mark.use_disk +def test_missing_tagged_framework_file_is_tolerated(simple_resource): + """Missing tagged variant files (e.g., frameworks_stable.yaml) should be tolerated.""" + # Create a minimal framework file + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write(""" +TestFramework: + version: "1.0" +""") + temp_file = f.name + + try: + # This should work even if tagged variants don't exist + # (e.g., frameworks_stable.yaml, frameworks_latest.yaml, etc.) 
+ definitions_by_tag = load_framework_definitions( + temp_file, simple_resource.config + ) + + # Should have loaded at least the default tag + assert len(definitions_by_tag) >= 1 + finally: + os.unlink(temp_file) + + +@pytest.mark.use_disk +def test_multiple_framework_files_first_must_exist(simple_resource): + """When loading multiple files, all base files must exist.""" + existing_file = f"{res}/frameworks_inheritance.yaml" + nonexistent_file = "/tmp/nonexistent_frameworks.yaml" + + # Both files in the list must exist at the base level + with pytest.raises(FileNotFoundError): + load_framework_definitions( + [existing_file, nonexistent_file], simple_resource.config + ) diff --git a/tests/unit/amlb/utils/config/test_config_load_strict.py b/tests/unit/amlb/utils/config/test_config_load_strict.py new file mode 100644 index 000000000..47957a35f --- /dev/null +++ b/tests/unit/amlb/utils/config/test_config_load_strict.py @@ -0,0 +1,57 @@ +import os +import tempfile +import pytest +from amlb.utils.config import config_load +from amlb.utils import Namespace + + +def test_config_load_nonexistent_file_strict_false(): + """Non-existent file with strict=False should return empty Namespace.""" + result = config_load("/tmp/definitely_nonexistent_file_12345.yaml", strict=False) + assert isinstance(result, Namespace) + assert len(dir(result)) == 0 # Empty namespace + + +def test_config_load_nonexistent_file_strict_true(): + """Non-existent file with strict=True should raise FileNotFoundError.""" + with pytest.raises(FileNotFoundError) as exc_info: + config_load("/tmp/definitely_nonexistent_file_12345.yaml", strict=True) + + error_message = str(exc_info.value) + assert "not found" in error_message.lower() + assert "typo" in error_message.lower() # Should mention checking for typos + + +def test_config_load_existing_file_strict_false(): + """Existing file with strict=False should load normally.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + 
f.write("test_key: test_value\n") + temp_file = f.name + + try: + result = config_load(temp_file, strict=False) + assert isinstance(result, Namespace) + assert result.test_key == "test_value" + finally: + os.unlink(temp_file) + + +def test_config_load_existing_file_strict_true(): + """Existing file with strict=True should load normally.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + f.write("test_key: test_value\n") + temp_file = f.name + + try: + result = config_load(temp_file, strict=True) + assert isinstance(result, Namespace) + assert result.test_key == "test_value" + finally: + os.unlink(temp_file) + + +def test_config_load_default_is_strict_false(): + """Default behavior should be strict=False for backward compatibility.""" + result = config_load("/tmp/definitely_nonexistent_file_12345.yaml") + assert isinstance(result, Namespace) + assert len(dir(result)) == 0 # Empty namespace diff --git a/upload_results.py b/upload_results.py index 310f020b8..ae3b53173 100644 --- a/upload_results.py +++ b/upload_results.py @@ -88,7 +88,7 @@ def parse_args(): def find_most_recent_result_folder() -> pathlib.Path: root_dir = pathlib.Path(__file__).parent - config = config_load(root_dir / "resources" / "config.yaml") + config = config_load(root_dir / "resources" / "config.yaml", strict=True) output_dir = pathlib.Path(config.output_dir or default_dirs.output_dir) def dirname_to_datetime(dirname: str) -> datetime: