35 changes: 22 additions & 13 deletions src/blop/ax/agent.py
@@ -34,12 +34,8 @@
logger = logging.getLogger(__name__)


def _has_str_keys(d: dict[DOF, Any] | dict[str, Any]) -> TypeGuard[dict[str, Any]]:
return all(isinstance(key, str) for key in d.keys())


def _has_dof_keys(d: dict[DOF, Any] | dict[str, Any]) -> TypeGuard[dict[DOF, Any]]:
return all(isinstance(key, DOF) for key in d.keys())
def _has_dof_keys(d: dict[Any, Any]) -> TypeGuard[dict[DOF, Any]]:
return all(isinstance(key, DOF) for key in d)


class _AxAgentMixin:
@@ -63,7 +59,7 @@ def fixed_dofs(self) -> dict[str, Any] | None:
return self._optimizer.fixed_parameters

@fixed_dofs.setter
def fixed_dofs(self, fixed_dofs: dict[DOF, Any] | dict[str, Any] | None) -> None:
def fixed_dofs(self, fixed_dofs: dict[DOF, Any] | None) -> None:
"""
Fix degrees of freedom to a certain value for future optimizations.

@@ -77,14 +73,10 @@ def fixed_dofs(self, fixed_dofs: dict[DOF, Any] | dict[str, Any] | None) -> None
self._optimizer.fixed_parameters = None
return

if _has_str_keys(fixed_dofs):
self._optimizer.fixed_parameters = fixed_dofs
elif _has_dof_keys(fixed_dofs):
if _has_dof_keys(fixed_dofs):
self._optimizer.fixed_parameters = {dof.parameter_name: value for dof, value in fixed_dofs.items()}
else:
raise ValueError(
f"Keys must all be either {type(DOF)} or {type(str)}, but got {type(list(fixed_dofs.keys())[0])}"
)
raise TypeError("Keys must be DOF objects")
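A minimal usage sketch of the stricter setter (agent and dof2 stand in for the objects constructed in the tests below):

    agent.fixed_dofs = {dof2: 4}             # every suggested point pins dof2's parameter to 4
    agent.fixed_dofs = None                  # clears all fixed DOFs
    agent.fixed_dofs = {"test_movable1": 3}  # string keys now raise TypeError instead of being accepted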

def suggest(self, num_points: int = 1) -> list[dict]:
"""
@@ -181,6 +173,23 @@ def checkpoint(self) -> None:
"""
self._optimizer.checkpoint()

def reconfigure_search_space(
self, dof_mappings: dict[DOF, tuple[float, float] | list[float] | list[int] | list[str] | list[bool]]
) -> None:
"""
Update bounds or values of existing DOFs for future optimizations.

Parameters
----------
dof_mappings : dict[DOF, tuple[float, float] | list[float] | list[int] | list[str] | list[bool]]
Mapping of DOFs to their new search space.
"""

if _has_dof_keys(dof_mappings):
self._optimizer._reconfigure_search_space({dof.parameter_name: update for dof, update in dof_mappings.items()})
else:
raise TypeError("Keys must all be type DOF")
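A hedged sketch of the new agent-level API, mirroring the test added below; dof1 is a range-typed DOF as in the tests, and choice_dof is a hypothetical choice-typed DOF:

    agent.reconfigure_search_space({dof1: (3.0, 6.0)})         # narrow a range DOF's bounds
    agent.reconfigure_search_space({choice_dof: [6, 7, 8]})    # replace a choice DOF's values
    agent.reconfigure_search_space({"test_movable1": (3, 6)})  # raises TypeError: keys must be DOFs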


class Agent(_AxAgentMixin):
"""
102 changes: 96 additions & 6 deletions src/blop/ax/optimizer.py
@@ -1,11 +1,17 @@
from collections.abc import Sequence
from typing import Any
from typing import Any, TypeGuard

from ax import ChoiceParameterConfig, Client, RangeParameterConfig
from ax.core.parameter import PARAMETER_PYTHON_TYPE_MAP, ChoiceParameter, RangeParameter
from ax.core.types import TParamValue

from ..protocols import ID_KEY, CanRegisterSuggestions, Checkpointable, Optimizer


def _is_tparamvalue_list(values, values_type) -> TypeGuard[list[TParamValue]]:
return issubclass(values_type, TParamValue) and all(isinstance(x, values_type) for x in values)
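For context on the TypeGuard pattern used above: when a function annotated -> TypeGuard[T] returns True, the type checker narrows the function's first argument to T at the call site. A self-contained sketch with hypothetical names, independent of Ax:

    from typing import Any, TypeGuard

    def _is_float_list(values: list[Any]) -> TypeGuard[list[float]]:
        # Runtime check that the type checker can use for narrowing.
        return all(isinstance(x, float) for x in values)

    def total(values: list[Any]) -> float:
        if _is_float_list(values):
            return sum(values)  # `values` is list[float] in this branch
        raise TypeError("Expected a list of floats")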


class AxOptimizer(Optimizer, Checkpointable, CanRegisterSuggestions):
"""
An optimizer that uses Ax as the backend for optimization and experiment tracking.
@@ -100,11 +106,7 @@ def fixed_parameters(self, fixed_parameters: dict[str, Any] | None) -> None:
if not fixed_parameters:
self._fixed_parameters = None
return
unknown_parameter_names = set(fixed_parameters) - set(self._parameter_names)
if unknown_parameter_names:
raise KeyError(
f"Unknown fixed parameter(s): {sorted(unknown_parameter_names)}, expected: {sorted(self._parameter_names)}"
)
self._verify_parameter_names(set(fixed_parameters))
self._fixed_parameters = dict(fixed_parameters)

def suggest(self, num_points: int | None = None) -> list[dict]:
@@ -214,3 +216,91 @@ def checkpoint(self) -> None:
if not self.checkpoint_path:
raise ValueError("Checkpoint path is not set. Please set a checkpoint path when initializing the optimizer.")
self._client.save_to_json_file(self.checkpoint_path)

def _verify_parameter_names(self, parameter_names: set[str]) -> None:
"""
Ensure all parameter names exist in the experiment.

Raises
------
KeyError
If any parameter name is unknown.
"""
unknown_parameter_names = parameter_names - set(self._parameter_names)
if unknown_parameter_names:
raise KeyError(
f"Unknown parameter(s): {sorted(unknown_parameter_names)}, expected: {sorted(self._parameter_names)}"
)

def _apply_parameter_update(
self,
parameter_name: str,
value: tuple[float, float] | list[float] | list[int] | list[str] | list[bool],
original_range_values: dict[str, tuple[float, float]],
original_choice_values: dict[str, list[TParamValue]],
) -> None:
"""
Validate and apply a single parameter update, storing the original value for rollback in case of failure.

Raises
------
TypeError
If the provided value does not match the expected type for the parameter.
"""
parameter = self._client._experiment.parameters[parameter_name]
if isinstance(parameter, RangeParameter):
if isinstance(value, tuple) and len(value) == 2 and all(isinstance(x, float | int) for x in value):
original_range_values[parameter_name] = (parameter.lower, parameter.upper)
parameter.update_range(*value)
else:
raise TypeError(f"Expected range to be a tuple of two floats, but got {value}")
elif isinstance(parameter, ChoiceParameter):
if isinstance(value, list) and _is_tparamvalue_list(value, PARAMETER_PYTHON_TYPE_MAP[parameter.parameter_type]):
original_choice_values[parameter_name] = parameter.values
parameter.set_values(value)
else:
raise TypeError(
f"Expected choice(s) to be a list of a single type (float, int, str, or bool), but got {value}"
)
else:
raise TypeError(f"Expected RangeParameter or ChoiceParameter, but got {parameter}")

def _rollback_parameter_updates(
self,
original_range_values: dict[str, tuple[float, float]],
original_choice_values: dict[str, list[TParamValue]],
) -> None:
"""
Roll back to the original parameter state after a failed update.
"""
for parameter_name, value in original_range_values.items():
parameter = self._client._experiment.parameters[parameter_name]
if isinstance(parameter, RangeParameter):
parameter.update_range(*value)
for parameter_name, value in original_choice_values.items():
parameter = self._client._experiment.parameters[parameter_name]
if isinstance(parameter, ChoiceParameter):
parameter.set_values(value)

def _reconfigure_search_space(
self, parameter_mappings: dict[str, tuple[float, float] | list[float] | list[int] | list[str] | list[bool]]
) -> None:
"""
Update the bounds or values of existing parameters in the underlying experiment.

Parameters
----------
parameter_mappings : dict[str, tuple[float, float] | list[float] | list[int] | list[str] | list[bool]]
Mapping of parameter names to (lower, upper) bounds or a list of values depending on the parameter type.

"""
self._verify_parameter_names(set(parameter_mappings))

original_range_values = {}
original_choice_values = {}
try:
for parameter_name, value in parameter_mappings.items():
self._apply_parameter_update(parameter_name, value, original_range_values, original_choice_values)
except Exception:
self._rollback_parameter_updates(original_range_values, original_choice_values)
raise
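The try/except above makes _reconfigure_search_space all-or-nothing: originals are recorded before each mutation, and any failure restores them before re-raising. A self-contained sketch of the same pattern with hypothetical names, independent of Ax:

    from typing import Any

    def update_all_or_nothing(store: dict[str, Any], updates: dict[str, Any]) -> None:
        # Record each original value before overwriting it, so a failure
        # partway through can be rolled back to the pre-call state.
        originals: dict[str, Any] = {}
        try:
            for key, value in updates.items():
                if key not in store:
                    raise KeyError(f"Unknown key: {key!r}")
                originals[key] = store[key]
                store[key] = value
        except Exception:
            store.update(originals)  # restore everything that was changed
            raise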
30 changes: 28 additions & 2 deletions src/blop/tests/ax/test_agent.py
@@ -153,8 +153,11 @@ def test_agent_suggest_fixed_dofs(mock_evaluation_function):
objectives=[objective],
evaluation_function=mock_evaluation_function,
)
with pytest.raises(ValueError):
agent.fixed_dofs = {"test_movable1": 3, dof2: 4}
# Keys must be DOF objects
with pytest.raises(TypeError):
agent.fixed_dofs = {"test_movable1": 3}

# Valid updates should fix the DOF
agent.fixed_dofs = {dof2: 4}
parameterizations = agent.suggest(5)
for i in range(5):
@@ -219,6 +222,29 @@ def test_ingest_baseline(mock_evaluation_function):
assert summary_df["arm_name"].values[0] == "baseline"


def test_reconfigure_search_space(mock_evaluation_function):
movable1 = MovableSignal(name="test_movable1")
movable2 = MovableSignal(name="test_movable2")
dof1 = RangeDOF(actuator=movable1, bounds=(0, 10), parameter_type="float")
dof2 = RangeDOF(actuator=movable2, bounds=(0, 10), parameter_type="float")
objective = Objective(name="test_objective", minimize=False)
agent = Agent(
sensors=[],
dofs=[dof1, dof2],
objectives=[objective],
evaluation_function=mock_evaluation_function,
)
# Keys must be DOF objects, not parameter names
with pytest.raises(TypeError):
agent.reconfigure_search_space({"test_movable1": (3, 6)})

# Valid update should restrict the search space
agent.reconfigure_search_space({dof1: (3, 6)})
parameterizations = agent.suggest(10)
for i in range(10):
assert 3 <= parameterizations[i]["test_movable1"] <= 6


def test_agent_init_actuator_string_raises(mock_evaluation_function):
dof1 = RangeDOF(actuator="test_movable1", bounds=(0, 10), parameter_type="float")
dof2 = RangeDOF(actuator="test_movable2", bounds=(0, 10), parameter_type="float")
68 changes: 66 additions & 2 deletions src/blop/tests/ax/test_optimizer.py
@@ -1,3 +1,5 @@
from unittest.mock import patch

import numpy as np
import pytest
from ax import ChoiceParameterConfig, RangeParameterConfig
@@ -35,11 +37,14 @@ def test_ax_fixed_parameters():
parameter_constraints=["x1 + x2 <= 10"],
outcome_constraints=["y1 >= 0", "y2 <= 0"],
)
optimizer.fixed_parameters = {"x3": 3}
assert optimizer.fixed_parameters == {"x3": 3}
# Unknown parameter name
with pytest.raises(KeyError):
optimizer.fixed_parameters = {"x4": 3}

# Optimizer state should reflect the setter call
optimizer.fixed_parameters = {"x3": 3}
assert optimizer.fixed_parameters == {"x3": 3}


def test_ax_optimizer_suggest():
optimizer = AxOptimizer(
@@ -177,3 +182,62 @@ def test_ax_optimizer_checkpoint_no_path():

with pytest.raises(ValueError):
optimizer.checkpoint()


def test_ax_optimizer_reconfigurable_search_space():
optimizer = AxOptimizer(
parameters=[
RangeParameterConfig(name="x1", bounds=(-5.0, 5.0), parameter_type="float"),
RangeParameterConfig(name="x2", bounds=(-5.0, 5.0), parameter_type="float"),
ChoiceParameterConfig(name="x3", values=[0, 1, 2, 3, 4, 5], parameter_type="int", is_ordered=True),
],
objective="y1,-y2",
parameter_constraints=["x1 + x2 <= 10"],
outcome_constraints=["y1 >= 0", "y2 <= 0"],
)
# Unknown parameter name
with pytest.raises(KeyError):
optimizer._reconfigure_search_space({"x4": (-4, 4)})
# ChoiceParameter expects a list
with pytest.raises(TypeError):
optimizer._reconfigure_search_space({"x3": 3})
# ChoiceParameter expects a list of single type
with pytest.raises(TypeError):
optimizer._reconfigure_search_space({"x3": ["2", 5, 3.6]})
# RangeParameter expects a tuple
with pytest.raises(TypeError):
optimizer._reconfigure_search_space({"x1": 3})

# Changing the search space should be reflected in the parameter state
optimizer._reconfigure_search_space({"x1": (-4, 4), "x3": [6, 7, 8]})
param_x1 = optimizer._client._experiment.parameters["x1"]
param_x3 = optimizer._client._experiment.parameters["x3"]
assert (param_x1.lower, param_x1.upper) == (-4, 4)
assert param_x3.values == [6, 7, 8]


def test_ax_optimizer_reconfigurable_search_space_rollback():
optimizer = AxOptimizer(
parameters=[
RangeParameterConfig(name="x1", bounds=(-5.0, 5.0), parameter_type="float"),
RangeParameterConfig(name="x2", bounds=(-5.0, 5.0), parameter_type="float"),
ChoiceParameterConfig(name="x3", values=[0, 1, 2, 3, 4, 5], parameter_type="int", is_ordered=True),
],
objective="y1,-y2",
parameter_constraints=["x1 + x2 <= 10"],
outcome_constraints=["y1 >= 0", "y2 <= 0"],
)
p1 = optimizer._client._experiment.parameters["x1"]
original_p1 = (p1.lower, p1.upper)
# Make the second state change fail after the first succeeds
with patch.object(p1, "update_range", side_effect=RuntimeError("boom")):
with pytest.raises(RuntimeError):
optimizer._reconfigure_search_space(
{
"x1": (3, 6),
"x2": 5,
}
)

# Rollback should restore the original state
assert (p1.lower, p1.upper) == original_p1