Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions src/blop/ax/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,23 @@
from ..plans import acquire_baseline, optimize, sample_suggestions
from ..protocols import AcquisitionPlan, Actuator, EvaluationFunction, OptimizationProblem, Sensor
from ..utils import InferredReadable
from .dof import DOF, DOFConstraint
from .dof import DOF, DOFConstraint, RangeDOF
from .objective import Objective, OutcomeConstraint, to_ax_objective_str
from .optimizer import AxOptimizer

logger = logging.getLogger(__name__)


def _has_str_keys(d: dict[DOF, Any] | dict[str, Any]) -> TypeGuard[dict[str, Any]]:
return all(isinstance(key, str) for key in d.keys())
def _has_str_keys(d: dict[Any, Any]) -> TypeGuard[dict[str, Any]]:
return all(isinstance(key, str) for key in d)


def _has_dof_keys(d: dict[DOF, Any] | dict[str, Any]) -> TypeGuard[dict[DOF, Any]]:
return all(isinstance(key, DOF) for key in d.keys())
def _has_dof_keys(d: dict[Any, Any]) -> TypeGuard[dict[DOF, Any]]:
return all(isinstance(key, DOF) for key in d)


def _has_range_dof_keys(d: dict[Any, Any]) -> TypeGuard[dict[RangeDOF, Any]]:
return all(isinstance(key, RangeDOF) for key in d)


class Agent:
Expand Down Expand Up @@ -412,3 +416,24 @@ def checkpoint(self) -> None:
Save the agent's state to a JSON file.
"""
self._optimizer.checkpoint()

def reconfigure_search_space(
self, dof_bounds: dict[RangeDOF, tuple[float, float]] | dict[str, tuple[float, float]]
) -> None:
"""
Update bounds of existing RangeDOFs for future optimizations.

Parameters
----------
dof_bounds : dict[RangeDOF, tuple[float, float]] | dict[str, tuple[float, float]]
Mapping of RangeDOFs or RangeDOF names to (upper, lower) bounds

"""
if _has_str_keys(dof_bounds):
self._optimizer._reconfigure_search_space(dof_bounds)
elif _has_range_dof_keys(dof_bounds):
self._optimizer._reconfigure_search_space({dof.parameter_name: bounds for dof, bounds in dof_bounds.items()})
else:
raise ValueError(
f"Keys must all be either {type(RangeDOF)} or {type(str)}, but got {type(list(dof_bounds.keys())[0])}"
)
22 changes: 22 additions & 0 deletions src/blop/ax/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any

from ax import ChoiceParameterConfig, Client, RangeParameterConfig
from ax.core import RangeParameter

from ..protocols import ID_KEY, CanRegisterSuggestions, Checkpointable, Optimizer

Expand Down Expand Up @@ -214,3 +215,24 @@ def checkpoint(self) -> None:
if not self.checkpoint_path:
raise ValueError("Checkpoint path is not set. Please set a checkpoint path when initializing the optimizer.")
self._client.save_to_json_file(self.checkpoint_path)

def _reconfigure_search_space(self, parameter_bounds: dict[str, tuple[float, float]]) -> None:
"""
Update the bounds of existing RangeParameters in the underlying experiment

Parameters
----------
parameter_bounds : dict[str, tuple[float, float]]
Mapping of parameter names to (lower, upper) bounds

"""
unknown_parameter_names = set(parameter_bounds) - set(self._parameter_names)
if unknown_parameter_names:
raise KeyError(
f"Unknown parameter(s): {sorted(unknown_parameter_names)}, expected: {sorted(self._parameter_names)}"
)

for parameter_name, bounds in parameter_bounds.items():
parameter = self._client._experiment.parameters[parameter_name]
if isinstance(parameter, RangeParameter):
parameter.update_range(*bounds)
20 changes: 20 additions & 0 deletions src/blop/tests/ax/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,23 @@ def test_ingest_baseline(mock_evaluation_function):
summary_df = agent.ax_client.summarize()
assert len(summary_df) == 1
assert summary_df["arm_name"].values[0] == "baseline"


def test_reconfigure_search_space(mock_evaluation_function):
movable1 = MovableSignal(name="test_movable1")
movable2 = MovableSignal(name="test_movable2")
dof1 = RangeDOF(actuator=movable1, bounds=(0, 10), parameter_type="float")
dof2 = RangeDOF(actuator=movable2, bounds=(0, 10), parameter_type="float")
objective = Objective(name="test_objective", minimize=False)
agent = Agent(
sensors=[],
dofs=[dof1, dof2],
objectives=[objective],
evaluation_function=mock_evaluation_function,
)
with pytest.raises(ValueError):
agent.reconfigure_search_space({"test_movable1": (3, 6), dof2: (3, 6)})
agent.reconfigure_search_space({"test_movable1": (3, 6)})
parameterizations = agent.suggest(10)
for i in range(10):
assert 3 <= parameterizations[i]["test_movable1"] <= 6
18 changes: 18 additions & 0 deletions src/blop/tests/ax/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,21 @@ def test_ax_optimizer_checkpoint_no_path():

with pytest.raises(ValueError):
optimizer.checkpoint()


def test_ax_optimizer_reconfigurable_search_space():
optimizer = AxOptimizer(
parameters=[
RangeParameterConfig(name="x1", bounds=(-5.0, 5.0), parameter_type="float"),
RangeParameterConfig(name="x2", bounds=(-5.0, 5.0), parameter_type="float"),
ChoiceParameterConfig(name="x3", values=[0, 1, 2, 3, 4, 5], parameter_type="int", is_ordered=True),
],
objective="y1,-y2",
parameter_constraints=["x1 + x2 <= 10"],
outcome_constraints=["y1 >= 0", "y2 <= 0"],
)
with pytest.raises(KeyError):
optimizer._reconfigure_search_space({"x4": (-4, 4)})
optimizer._reconfigure_search_space({"x1": (-4, 4)})
param_x1 = optimizer._client._experiment.parameters["x1"]
assert (param_x1.lower, param_x1.upper) == (-4, 4)
Loading