Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Version 1.3

* FEAT #411: Allow `Hyperparameter` to be removed from `ConfigurationSpace`.

# Version 1.2.2

* MAINT #404: Added support for Python 3.13.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ select = [

ignore = [
"T201", # TODO: Remove
"COM812", # Causes issues with ruff formatter
"D100",
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic mthod
Expand Down
6 changes: 2 additions & 4 deletions src/ConfigSpace/_condition_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import numpy as np
from more_itertools import unique_everseen

from ConfigSpace.conditions import Condition, Conjunction
from ConfigSpace.conditions import Condition, ConditionLike, Conjunction
from ConfigSpace.exceptions import (
AmbiguousConditionError,
ChildNotFoundError,
Expand All @@ -62,7 +62,6 @@
from ConfigSpace.types import f64

if TYPE_CHECKING:
from ConfigSpace.conditions import ConditionLike
from ConfigSpace.hyperparameters import Hyperparameter
from ConfigSpace.types import Array

Expand Down Expand Up @@ -782,8 +781,7 @@ def _minimum_conditions(self) -> list[ConditionNode]:
# i.e. two hyperparameters both rely on algorithm == "A"
base_conditions: dict[int, ConditionNode] = {}
for node in self.nodes.values():
# This node has no parent as is a root
if node.parent_condition is None:
if node.parent_condition is None: # This node has no parent as it is a root node
assert node.name in self.roots
continue

Expand Down
10 changes: 5 additions & 5 deletions src/ConfigSpace/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from ConfigSpace.api.types import Categorical, Float, Integer

__all__ = [
"types",
"distributions",
"Beta",
"Distribution",
"Normal",
"Uniform",
"Categorical",
"Distribution",
"Float",
"Integer",
"Normal",
"Uniform",
"distributions",
"types",
]
8 changes: 4 additions & 4 deletions src/ConfigSpace/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ def __init__(
ConfigSpace package.
"""
if (
values is not None
and vector is not None
or values is None
and vector is None
(values is not None
and vector is not None)
or (values is None
and vector is None)
):
raise ValueError(
"Specify Configuration as either a dictionary or a vector.",
Expand Down
98 changes: 97 additions & 1 deletion src/ConfigSpace/configuration_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from ConfigSpace.configuration import Configuration, NotSet
from ConfigSpace.exceptions import (
ActiveHyperparameterNotSetError,
HyperparameterNotFoundError,
ForbiddenValueError,
IllegalVectorizedValueError,
InactiveHyperparameterSetError,
Expand Down Expand Up @@ -350,6 +351,101 @@ def _put_to_list(
self._len = len(self._dag.nodes)
self._check_default_configuration()

def remove(
self,
*args: Hyperparameter,
) -> None:
"""Remove a hyperparameter from the configuration space.

If the hyperparameter has children, the children are also removed.
This includes defined conditions and conjunctions!

!!! note

If removing multiple hyperparameters, it is better to remove them all
at once with one call to `remove()`, as we rebuilt a cache after each
call to `remove()`.

Args:
args: Hyperparameter(s) to remove
"""
remove_hps = []
for arg in args:
if isinstance(arg, Hyperparameter):
if arg.name not in self._dag.nodes:
raise HyperparameterNotFoundError(
f"Hyperparameter '{arg.name}' does not exist in space.",
)
remove_hps.append(arg)
else:
raise TypeError(f"Unknown type {type(arg)}")
remove_hps_names = [hp.name for hp in remove_hps]

# Filter HPs from the DAG
hps: list[Hyperparameter] = [node.hp for node in self._dag.nodes.values() if node.hp.name not in remove_hps_names]

def remove_hyperparameter_from_conjunction(
target: Conjunction | Condition | ForbiddenRelation | ForbiddenClause,
) -> (
Conjunction
| Condition
| ForbiddenClause
| ForbiddenRelation
| ForbiddenConjunction
| None
):
if isinstance(target, ForbiddenRelation) and (
target.left.name in remove_hps_names or target.right.name in remove_hps_names
):
return None
if isinstance(target, ForbiddenClause) and target.hyperparameter.name in remove_hps_names:
return None
if isinstance(target, Condition) and (
target.parent.name in remove_hps_names or target.child.name in remove_hps_names
):
return None
if isinstance(target, (Conjunction, ForbiddenConjunction)):
new_components = []
for component in target.components:
new_component = remove_hyperparameter_from_conjunction(component)
if new_component is not None:
new_components.append(new_component)
if len(new_components) >= 2: # Can create a conjunction
return type(target)(*new_components)
if len(new_components) == 1: # Only one component remains
return new_components[0]
return None # No components remain
return target # Nothing to change

# Remove HPs from conditions
conditions = []
for condition in self._dag.conditions:
condition = remove_hyperparameter_from_conjunction(condition)
if condition is not None: # If None, the conditional clause is empty and thus not added
conditions.append(condition)

# Remove HPs from Forbiddens
forbiddens = []
for forbidden in self._dag.forbiddens:
forbidden = remove_hyperparameter_from_conjunction(forbidden)
if forbidden is not None: # If None, the forbidden clause is empty and is not added
forbiddens.append(
remove_hyperparameter_from_conjunction(forbidden)
)

# Rebuild the DAG
self._dag = DAG()
with self._dag.update():
for hp in hps:
self._dag.add(hp)
for condition in conditions:
self._dag.add_condition(condition)
for forbidden in forbiddens:
self._dag.add_forbidden(forbidden)

self._len = len(self._dag.nodes)
self._check_default_configuration()

def add_configuration_space(
self,
prefix: str,
Expand Down Expand Up @@ -864,7 +960,7 @@ def __iter__(self) -> Iterator[str]:
return iter(self._dag.nodes.keys())

def items(self) -> ItemsView[str, Hyperparameter]:
"""Return an items view of the hyperparameters, same as `dict.items()`.""" # noqa: D402
"""Return an items view of the hyperparameters, same as `dict.items()`."""
return {name: node.hp for name, node in self._dag.nodes.items()}.items()

def __len__(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion src/ConfigSpace/hyperparameters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"NormalIntegerHyperparameter",
"NumericalHyperparameter",
"OrdinalHyperparameter",
"UnParametrizedHyperparameter",
"UniformFloatHyperparameter",
"UniformIntegerHyperparameter",
"UnParametrizedHyperparameter",
]
104 changes: 104 additions & 0 deletions test/test_configuration_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,110 @@ def test_add():
cs.add(hp)


def test_remove():
cs = ConfigurationSpace()
hp = UniformIntegerHyperparameter("name", 0, 10)
hp2 = UniformFloatHyperparameter("name2", 0, 10)
hp3 = CategoricalHyperparameter(
"weather", ["dry", "rainy", "snowy"], default_value="dry"
)
cs.add(hp, hp2, hp3)
cs.remove(hp)
assert len(cs) == 2

# Test multi removal
cs.add(hp)
cs.remove(hp, hp2)
assert len(cs) == 1

# Test faulty input
with pytest.raises(TypeError):
cs.remove(object())

# Non existant HP
with pytest.raises(HyperparameterNotFoundError):
cs.remove(hp)

cs.add(hp, hp2)
# Test one correct one faulty, nothing should happen
with pytest.raises(TypeError):
cs.remove(hp, object())
assert len(cs) == 3

# Make hp2 a conditional parameter, the condition should also be removed when hp is removed
cond = EqualsCondition(hp, hp2, 1)
cs.add(cond)
cs.remove(hp)
assert len(cs) == 2
assert cs.conditional_hyperparameters == []
assert cs.conditions == []

# Set up forbidden relation, the relation should also be removed
forb = ForbiddenEqualsClause(hp3, "snowy")
cs.add(forb)
cs.remove(hp3)
assert len(cs) == 1
assert cs.forbidden_clauses == []

# And now for more complicated conditions
cs = ConfigurationSpace()
hp1 = CategoricalHyperparameter("input1", [0, 1])
cs.add(hp1)
hp2 = CategoricalHyperparameter("input2", [0, 1])
cs.add(hp2)
hp3 = CategoricalHyperparameter("input3", [0, 1])
cs.add(hp3)
hp4 = CategoricalHyperparameter("input4", [0, 1])
cs.add(hp4)
hp5 = CategoricalHyperparameter("input5", [0, 1])
cs.add(hp5)
hp6 = Constant("constant1", "True")
cs.add(hp6)

cond1 = EqualsCondition(hp6, hp1, 1)
cond2 = NotEqualsCondition(hp6, hp2, 1)
cond3 = InCondition(hp6, hp3, [1])
cond4 = EqualsCondition(hp6, hp4, 1)
cond5 = EqualsCondition(hp6, hp5, 1)

conj1 = AndConjunction(cond1, cond2)
conj2 = OrConjunction(conj1, cond3)
conj3 = AndConjunction(conj2, cond4, cond5)
cs.add(conj3)

cs.remove(hp3)
assert len(cs) == 5
# Only one part of the condition should be removed, not the entire condition
assert len(cs.conditional_hyperparameters) == 1
assert len(cs.conditions) == 1
# Test the exact value
assert (
str(cs.conditions[0])
== "((constant1 | input1 == 1 && constant1 | input2 != 1) && constant1 | input4 == 1 && constant1 | input5 == 1)"
)

# Now more complicated forbiddens
cs = ConfigurationSpace()
cs.add([hp1, hp2, hp3, hp4, hp5, hp6])
cs.add(conj3)

forb1 = ForbiddenEqualsClause(hp1, 1)
forb2 = ForbiddenAndConjunction(forb1, ForbiddenEqualsClause(hp2, 1))
forb3 = ForbiddenAndConjunction(forb2, ForbiddenEqualsClause(hp3, 1))
forb4 = ForbiddenEqualsClause(hp3, 1)
forb5 = ForbiddenEqualsClause(hp4, 1)
cs.add(forb3, forb4, forb5)

cs.remove(hp3)
assert len(cs) == 5
assert len(cs.forbidden_clauses) == 2
assert (
str(cs.forbidden_clauses[0])
== "(Forbidden: input1 == 1 && Forbidden: input2 == 1)"
)
assert str(cs.forbidden_clauses[1]) == "Forbidden: input4 == 1"


def test_add_non_hyperparameter():
cs = ConfigurationSpace()
with pytest.raises(TypeError):
Expand Down