Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions src/blop/ax/agent.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import importlib.util
import logging
from collections.abc import Sequence
from typing import Any, TypeGuard, cast
from typing import Any, cast

from ax import Client
from ax.analysis import ContourPlot
from ax.core.types import TParamValue

# ===============================
# TODO: Remove when Python 3.10 is no longer supported
Expand Down Expand Up @@ -34,14 +35,6 @@
logger = logging.getLogger(__name__)


def _has_str_keys(d: dict[DOF, Any] | dict[str, Any]) -> TypeGuard[dict[str, Any]]:
return all(isinstance(key, str) for key in d.keys())


def _has_dof_keys(d: dict[DOF, Any] | dict[str, Any]) -> TypeGuard[dict[DOF, Any]]:
return all(isinstance(key, DOF) for key in d.keys())


class _AxAgentMixin:
"""
Mixin providing Ax-related functionality shared by agents.
Expand All @@ -63,7 +56,7 @@ def fixed_dofs(self) -> dict[str, Any] | None:
return self._optimizer.fixed_parameters

@fixed_dofs.setter
def fixed_dofs(self, fixed_dofs: dict[DOF, Any] | dict[str, Any] | None) -> None:
def fixed_dofs(self, fixed_dofs: dict[DOF, Any] | None) -> None:
"""
Fix degrees of freedom to a certain value for future optimizations.

Expand All @@ -77,14 +70,7 @@ def fixed_dofs(self, fixed_dofs: dict[DOF, Any] | dict[str, Any] | None) -> None
self._optimizer.fixed_parameters = None
return

if _has_str_keys(fixed_dofs):
self._optimizer.fixed_parameters = fixed_dofs
elif _has_dof_keys(fixed_dofs):
self._optimizer.fixed_parameters = {dof.parameter_name: value for dof, value in fixed_dofs.items()}
else:
raise ValueError(
f"Keys must all be either {type(DOF)} or {type(str)}, but got {type(list(fixed_dofs.keys())[0])}"
)
self._optimizer.fixed_parameters = {dof.parameter_name: value for dof, value in fixed_dofs.items()}

def suggest(self, num_points: int = 1) -> list[dict]:
"""
Expand Down Expand Up @@ -181,6 +167,18 @@ def checkpoint(self) -> None:
"""
self._optimizer.checkpoint()

def reconfigure_search_space(self, dof_mappings: dict[DOF, tuple[float, float] | list[TParamValue]]) -> None:
"""
Update bounds or values of existing DOFs for future optimizations.

Parameters
----------
dof_mappings : dict[DOF, tuple[float, float] | list[float] | list[int] | list[str] | list[bool]]
Mapping of DOFs to their new search space.
"""

self._optimizer.reconfigure_search_space({dof.parameter_name: update for dof, update in dof_mappings.items()})


class Agent(_AxAgentMixin):
"""
Expand Down
72 changes: 70 additions & 2 deletions src/blop/ax/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import Any

from ax import ChoiceParameterConfig, Client, RangeParameterConfig
from ax.core.parameter import ChoiceParameter, RangeParameter
from ax.core.types import TParamValue

from ..protocols import ID_KEY, CanRegisterSuggestions, Checkpointable, Optimizer

Expand Down Expand Up @@ -100,10 +102,11 @@ def fixed_parameters(self, fixed_parameters: dict[str, Any] | None) -> None:
if not fixed_parameters:
self._fixed_parameters = None
return
unknown_parameter_names = set(fixed_parameters) - set(self._parameter_names)

unknown_parameter_names = fixed_parameters.keys() - set(self._parameter_names)
if unknown_parameter_names:
raise KeyError(
f"Unknown fixed parameter(s): {sorted(unknown_parameter_names)}, expected: {sorted(self._parameter_names)}"
f"Unknown parameter(s): {sorted(unknown_parameter_names)}, expected: {sorted(self._parameter_names)}"
)
self._fixed_parameters = dict(fixed_parameters)

Expand Down Expand Up @@ -214,3 +217,68 @@ def checkpoint(self) -> None:
if not self.checkpoint_path:
raise ValueError("Checkpoint path is not set. Please set a checkpoint path when initializing the optimizer.")
self._client.save_to_json_file(self.checkpoint_path)

def _apply_parameter_update(
    self,
    parameter_name: str,
    value: tuple[float, float] | list[TParamValue],
    original_range_values: dict[str, tuple[float, float]],
    original_choice_values: dict[str, list[TParamValue]],
) -> None:
    """
    Validate and apply a single parameter update, storing the original value for rollback in case of failure.

    Parameters
    ----------
    parameter_name : str
        Name of an existing parameter in the underlying experiment.
    value : tuple[float, float] | list[TParamValue]
        New (lower, upper) bounds for a RangeParameter, or new values for a ChoiceParameter.
    original_range_values : dict[str, tuple[float, float]]
        Out-parameter: receives the previous bounds of the updated RangeParameter.
    original_choice_values : dict[str, list[TParamValue]]
        Out-parameter: receives the previous values of the updated ChoiceParameter.

    Raises
    ------
    TypeError
        If the provided value does not match the expected type for the parameter.
    """
    parameter = self._client._experiment.parameters[parameter_name]
    if isinstance(parameter, RangeParameter):
        # Enforce the "length 2" promise the error message makes; a longer tuple
        # would otherwise slip past and fail later inside update_range.
        if not isinstance(value, tuple) or len(value) != 2:
            raise TypeError(f"{RangeParameter.__name__} only accepts tuples of length 2, but got: {value}")
        original_range_values[parameter_name] = (parameter.lower, parameter.upper)
        parameter.update_range(*value)
    elif isinstance(parameter, ChoiceParameter):
        if not isinstance(value, list):
            raise TypeError(f"{ChoiceParameter.__name__} only accepts list of items, but got: {value}")
        # Copy the snapshot so an in-place mutation by set_values cannot corrupt
        # the rollback state.
        original_choice_values[parameter_name] = list(parameter.values)
        parameter.set_values(value)
    else:
        raise TypeError(f"Expected RangeParameter or ChoiceParameter, but got {parameter}")

def _rollback_parameter_updates(
self,
original_range_values: dict[str, tuple[float, float]],
original_choice_values: dict[str, list[TParamValue]],
) -> None:
"""
Rollback original parameter state after a failed update
"""
for parameter_name, value in original_range_values.items():
parameter = self._client._experiment.parameters[parameter_name]
if isinstance(parameter, RangeParameter):
parameter.update_range(*value)
for parameter_name, value in original_choice_values.items():
parameter = self._client._experiment.parameters[parameter_name]
if isinstance(parameter, ChoiceParameter):
parameter.set_values(value)

def reconfigure_search_space(self, parameter_mappings: dict[str, tuple[float, float] | list[TParamValue]]) -> None:
    """
    Update the bounds or values of existing parameters in the underlying experiment.

    Updates are applied one at a time; if any update fails, all previously
    applied updates are rolled back before the error propagates, so the
    search space is never left partially updated.

    Parameters
    ----------
    parameter_mappings : dict[str, tuple[float, float] | list[TParamValue]]
        Mapping of parameter names to (lower, upper) bounds or a list of values depending on the parameter type.

    Raises
    ------
    KeyError
        If a parameter name is not part of the experiment.
    TypeError
        If a value does not match the expected type for its parameter.
    """
    # Snapshots of the pre-update state, used for rollback on failure.
    original_range_values: dict[str, tuple[float, float]] = {}
    original_choice_values: dict[str, list[TParamValue]] = {}
    try:
        for parameter_name, value in parameter_mappings.items():
            self._apply_parameter_update(parameter_name, value, original_range_values, original_choice_values)
    except Exception:
        self._rollback_parameter_updates(original_range_values, original_choice_values)
        # Bare raise preserves the original exception and traceback
        # (unlike `raise e`, which restarts the traceback here).
        raise
30 changes: 28 additions & 2 deletions src/blop/tests/ax/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,11 @@ def test_agent_suggest_fixed_dofs(mock_evaluation_function):
objectives=[objective],
evaluation_function=mock_evaluation_function,
)
with pytest.raises(ValueError):
agent.fixed_dofs = {"test_movable1": 3, dof2: 4}
# Keys must be DOF objects, not parameter names
with pytest.raises(AttributeError):
agent.fixed_dofs = {"test_movable1": 3}

# Valid updates should fix the DOF
agent.fixed_dofs = {dof2: 4}
parameterizations = agent.suggest(5)
for i in range(5):
Expand Down Expand Up @@ -219,6 +222,29 @@ def test_ingest_baseline(mock_evaluation_function):
assert summary_df["arm_name"].values[0] == "baseline"


def test_reconfigure_search_space(mock_evaluation_function):
    """Reconfiguring the search space requires DOF keys and narrows future suggestions."""
    signal1 = MovableSignal(name="test_movable1")
    signal2 = MovableSignal(name="test_movable2")
    dof1 = RangeDOF(actuator=signal1, bounds=(0, 10), parameter_type="float")
    dof2 = RangeDOF(actuator=signal2, bounds=(0, 10), parameter_type="float")
    agent = Agent(
        sensors=[],
        dofs=[dof1, dof2],
        objectives=[Objective(name="test_objective", minimize=False)],
        evaluation_function=mock_evaluation_function,
    )
    # String keys are rejected: the mapping must be keyed by DOF objects.
    with pytest.raises(AttributeError):
        agent.reconfigure_search_space({"test_movable1": (3, 6)})

    # A valid update restricts the search space for subsequent suggestions.
    agent.reconfigure_search_space({dof1: (3, 6)})
    for parameterization in agent.suggest(10):
        assert 3 <= parameterization["test_movable1"] <= 6


def test_agent_init_actuator_string_raises(mock_evaluation_function):
dof1 = RangeDOF(actuator="test_movable1", bounds=(0, 10), parameter_type="float")
dof2 = RangeDOF(actuator="test_movable2", bounds=(0, 10), parameter_type="float")
Expand Down
68 changes: 66 additions & 2 deletions src/blop/tests/ax/test_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

import numpy as np
import pytest
from ax import ChoiceParameterConfig, RangeParameterConfig
Expand Down Expand Up @@ -35,11 +37,14 @@ def test_ax_fixed_parameters():
parameter_constraints=["x1 + x2 <= 10"],
outcome_constraints=["y1 >= 0", "y2 <= 0"],
)
optimizer.fixed_parameters = {"x3": 3}
assert optimizer.fixed_parameters == {"x3": 3}
# Unknown parameter name
with pytest.raises(KeyError):
optimizer.fixed_parameters = {"x4": 3}

# Optimizer state should reflect setter call
optimizer.fixed_parameters = {"x3": 3}
assert optimizer.fixed_parameters == {"x3": 3}


def test_ax_optimizer_suggest():
optimizer = AxOptimizer(
Expand Down Expand Up @@ -177,3 +182,62 @@ def test_ax_optimizer_checkpoint_no_path():

with pytest.raises(ValueError):
optimizer.checkpoint()


def test_ax_optimizer_reconfigurable_search_space():
    """Validation and state changes when reconfiguring range and choice parameters."""
    optimizer = AxOptimizer(
        parameters=[
            RangeParameterConfig(name="x1", bounds=(-5.0, 5.0), parameter_type="float"),
            RangeParameterConfig(name="x2", bounds=(-5.0, 5.0), parameter_type="float"),
            ChoiceParameterConfig(name="x3", values=[0, 1, 2, 3, 4, 5], parameter_type="int", is_ordered=True),
        ],
        objective="y1,-y2",
        parameter_constraints=["x1 + x2 <= 10"],
        outcome_constraints=["y1 >= 0", "y2 <= 0"],
    )
    # Unknown parameter name is rejected
    with pytest.raises(KeyError):
        optimizer.reconfigure_search_space({"x4": (-4, 4)})
    # ChoiceParameter expects a list
    with pytest.raises(TypeError):
        optimizer.reconfigure_search_space({"x3": 3})
    # ChoiceParameter expects a list of types castable to parameter_type
    with pytest.raises(ValueError):
        optimizer.reconfigure_search_space({"x3": ["2", "Hello", 3.6]})
    # RangeParameter expects a tuple
    with pytest.raises(TypeError):
        optimizer.reconfigure_search_space({"x1": 3})

    # Changing the search space should be reflected in the parameter state
    optimizer.reconfigure_search_space({"x1": (-4, 4), "x3": [6, 7, 8]})
    param_x1 = optimizer._client._experiment.parameters["x1"]
    param_x3 = optimizer._client._experiment.parameters["x3"]
    assert (param_x1.lower, param_x1.upper) == (-4, 4)
    assert param_x3.values == [6, 7, 8]


def test_ax_optimizer_reconfigurable_search_space_rollback():
    """A failing reconfiguration must leave the search space unchanged."""
    optimizer = AxOptimizer(
        parameters=[
            RangeParameterConfig(name="x1", bounds=(-5.0, 5.0), parameter_type="float"),
            RangeParameterConfig(name="x2", bounds=(-5.0, 5.0), parameter_type="float"),
            ChoiceParameterConfig(name="x3", values=[0, 1, 2, 3, 4, 5], parameter_type="int", is_ordered=True),
        ],
        objective="y1,-y2",
        parameter_constraints=["x1 + x2 <= 10"],
        outcome_constraints=["y1 >= 0", "y2 <= 0"],
    )
    param_x1 = optimizer._client._experiment.parameters["x1"]
    bounds_before = (param_x1.lower, param_x1.upper)

    # Force update_range on x1 to fail mid-reconfiguration.
    with patch.object(param_x1, "update_range", side_effect=RuntimeError("boom")):
        with pytest.raises(RuntimeError):
            optimizer.reconfigure_search_space(
                {
                    "x1": (3, 6),
                    "x2": 5,
                }
            )

    # The failed call must leave the original bounds intact.
    assert (param_x1.lower, param_x1.upper) == bounds_before
Loading