Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions amlb/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pandas as pd

from .frameworks.definitions import load_framework_definition
from .job import Job, JobError, SimpleJobRunner, MultiThreadingJobRunner
from .datasets import DataLoader, DataSourceType
from .data import DatasetType
Expand Down Expand Up @@ -119,13 +120,8 @@ def __init__(
Benchmark.data_loader = DataLoader(rconfig())

self._job_history = self._load_job_history(job_history=job_history)

fsplits = framework_name.split(":", 1)
framework_name = fsplits[0]
tag = fsplits[1] if len(fsplits) > 1 else None
self.framework_def, self.framework_name = rget().framework_definition(
framework_name, tag
)
framework = load_framework_definition(framework_name, rget())
self.framework_def, self.framework_name = framework, framework.name
log.debug("Using framework definition: %s.", self.framework_def)

self.constraint_def, self.constraint_name = rget().constraint_definition(
Expand Down Expand Up @@ -658,6 +654,7 @@ def handle_unfulfilled(message, on_auto="warn"):


class BenchmarkTask:

def __init__(self, benchmark: Benchmark, task_def, fold):
"""

Expand Down
50 changes: 49 additions & 1 deletion amlb/frameworks/definitions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from __future__ import annotations

import copy
import itertools
import logging
import os
from typing import List, Optional, Union
from dataclasses import dataclass, field
from typing import List, Optional, Union, TYPE_CHECKING

from amlb.utils import Namespace, config_load, str_sanitize

if TYPE_CHECKING:
from amlb import Resources

log = logging.getLogger(__name__)

default_tag = "_"
Expand Down Expand Up @@ -233,3 +239,45 @@ def _remove_frameworks_with_unknown_parent(frameworks: Namespace):
"Removing framework %s as parent %s doesn't exist.", framework, parent
)
del frameworks[framework]


@dataclass
class Image:
author: str
image: str
tag: str


@dataclass
class Framework:
name: str
abstract: bool
module: str
version: str
# Image
image: Image
# Setup
_setup_cmd: str | None
setup_cmd: str | None
setup_script: str | None
setup_env: dict = field(default_factory=dict)
setup_args: list[str] = field(default_factory=list)
# more optionals
params: dict = field(default_factory=dict)
refs: list = field(default_factory=list)
description: str | None = None
project: str | None = None

def __post_init__(self):
if isinstance(self.image, dict):
self.image = Image(**self.image)
Comment on lines +271 to +273
Copy link
Collaborator Author

@PGijsbers PGijsbers Dec 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This helps where framework namespaces were initiated based on nested dicts (loaded from configuration files). E.g.

framework = Framework(**{"name": ..., "image": {"author": ...}})

correctly initializes an Image instance.



def load_framework_definition(
framework_name: str, configuration: "Resources"
) -> Framework:
tag = None
if ":" in framework_name:
framework_name, tag = framework_name.split(":", 1)
definition_ns, name = configuration.framework_definition(framework_name, tag)
return Framework(**Namespace.dict(definition_ns))
Loading