Skip to content

Commit b2b8c11

Browse files
committed
Switch to Ax Client API for simplicity
1 parent 2021281 commit b2b8c11

2 files changed

Lines changed: 121 additions & 226 deletions

File tree

src/blop/integrations/ax.py

Lines changed: 82 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -1,152 +1,86 @@
1-
from abc import ABC, abstractmethod
2-
from collections.abc import Callable, Sequence
3-
from typing import Any, cast
1+
import copy
2+
from collections.abc import Callable
3+
from typing import Any
44

55
import pandas as pd
6-
from ax import Arm, Experiment, Runner, Trial
7-
from ax.core.base_trial import BaseTrial
8-
from ax.utils.common.result import Ok, Result
9-
from ax.core.map_metric import MapMetric
10-
from ax.core.map_data import MapData
11-
from bluesky import RunEngine
6+
from ax.service.ax_client import AxClient
127
from bluesky.plans import list_scan
13-
from bluesky.protocols import HasName, Movable, NamedMovable, Readable
14-
from databroker import Broker
15-
from tiled.client.container import Container
16-
from numpy.typing import NDArray
17-
18-
19-
class BlopExperiment(Experiment):
20-
def __init__(self, RE: RunEngine, readables: Sequence[Readable], movables: Sequence[NamedMovable], *args, **kwargs):
21-
super().__init__(*args, runner=BlopRunner(RE, readables, movables), **kwargs)
22-
self._validate_search_space(movables)
23-
self._validate_optimization_config(readables, movables)
24-
25-
def _validate_search_space(self, movables: Sequence[NamedMovable]):
26-
"""Validates that the parameters are compatible with the `Movable`s."""
27-
parameter_names = set(self.search_space.parameters.keys())
28-
for m, p in zip(movables, self.search_space.parameters.values(), strict=False):
29-
if m.name != p.name:
30-
if m.name not in parameter_names:
31-
raise ValueError(f"The movable name {m.name} is not a parameter in the search space.")
32-
raise ValueError(
33-
f"The moveable name {m.name} is in the search space, but the order is not correct. "
34-
"The order of movables must match the order of the parameters in the search space "
35-
"so we can unpack the arm correctly."
36-
)
37-
38-
def _validate_optimization_config(self, readables: Sequence[Readable], movables: Sequence[NamedMovable]):
39-
"""Validates that the objectives are compatible with the `Readable`s."""
40-
# Check that each metric is a BlopMetric
41-
metrics = self.optimization_config.objective.metrics
42-
if any(not isinstance(m, BlopMetric) for m in metrics):
43-
non_blop_metrics = "\n".join([f"{m.name}: {type(m)}" for m in metrics if not isinstance(m, BlopMetric)])
44-
raise ValueError(f"All objectives must inherit from `BlopMetric`, but found:\n{non_blop_metrics}")
45-
46-
# Check that each metric's parameters reference a `Readable` or `Movable`
47-
metric_param_names = {p for m in cast(Sequence[BlopMetric], metrics) for p in m.param_names}
48-
unmatched_parameters = {
49-
p
50-
for p in metric_param_names
51-
if not any(r.name in p for r in readables) and not any(m.name in p for m in movables)
52-
}
53-
if unmatched_parameters:
54-
raise ValueError(
55-
f"The following parameters are not referenced in any `Readable` or `Movable`: {unmatched_parameters}"
56-
)
57-
58-
59-
class BlopRunner(Runner):
60-
def __init__(self, RE: RunEngine, readables: Sequence[Readable], movables: Sequence[Movable], *args, **kwargs):
61-
super().__init__(*args, **kwargs)
62-
self._RE = RE
63-
self._readables = readables
64-
self._movables = movables
65-
66-
def _unpack_arm(self, arm: Arm) -> list[Movable | list[Any]]:
67-
"""Unpacks the arm's parameters into the format of the `list_scan` plan."""
8+
from bluesky.protocols import NamedMovable, Readable
9+
10+
11+
def create_blop_experiment(ax_client: AxClient, parameters: list[dict[str, Any]], *args, **kwargs) -> None:
12+
# Check that a movable key is present
13+
if not all("movable" in p for p in parameters):
14+
raise ValueError("All parameters must have a 'movable' key.")
15+
16+
# Check that a name attribute is present
17+
if not all(hasattr(p["movable"], "name") for p in parameters):
18+
raise ValueError("All 'movable' values must have a 'name' attribute.")
19+
20+
ax_parameters = copy.copy(parameters)
21+
for p in ax_parameters:
22+
p["name"] = p["movable"].name
23+
del p["movable"]
24+
25+
ax_client.create_experiment(*args, parameters=ax_parameters, **kwargs)
26+
27+
28+
def create_bluesky_evaluator(
29+
RE,
30+
db,
31+
readables: list[Readable],
32+
movables: list[NamedMovable],
33+
evaluation_function: Callable[[pd.DataFrame], dict[str, tuple[float, float]]],
34+
plan: Callable | None = None,
35+
) -> Callable:
36+
"""
37+
Create an evaluation function that runs a Bluesky plan and evaluates objectives.
38+
39+
Parameters:
40+
-----------
41+
RE : RunEngine
42+
The Bluesky RunEngine
43+
db : databroker
44+
The databroker/tiled instance
45+
movables : List
46+
List of Bluesky motors/devices to optimize
47+
detectors : List
48+
List of Bluesky detectors to read
49+
evaluation_function : Callable[[pd.DataFrame], Dict[str, Tuple[float, float]]]
50+
Function that takes a dataframe from databroker and returns
51+
a dictionary mapping objective names to (mean, sem) tuples
52+
plan : Callable, optional
53+
Custom Bluesky plan to use. If None, uses list_scan
54+
55+
Returns:
56+
--------
57+
Callable
58+
Function that takes an Ax parameterization and returns objective values
59+
"""
60+
plan_function = plan or list_scan
61+
62+
def evaluate(parameterization: dict[str, float] | dict[str, list[float]]) -> dict[str, tuple[float, float]]:
63+
# Prepare the parameters for the plan
6864
unpacked = []
69-
for m, p in zip(self._movables, arm.parameters.values(), strict=True):
70-
unpacked.append(m)
71-
unpacked.append([p])
72-
return unpacked
73-
74-
def run(self, trial: Trial, **kwargs):
75-
# TODO: Can probably do a yield from here instead and move the RunEngine call
76-
# to the outermost part of execution.
77-
# RE(trial.run()) or something like that.
78-
uid = self._RE(list_scan(self._readables, *self._unpack_arm(trial.arm)))
79-
return {"uid": uid}
80-
81-
def clone(self) -> "BlopRunner":
82-
"""Create a copy of this Runner."""
83-
return BlopRunner(RE=self._RE, readables=self._readables, movables=self._movables)
84-
85-
86-
class BlopMetric(MapMetric, ABC):
87-
def __init__(self, *args, **kwargs):
88-
super().__init__(*args, **kwargs)
89-
90-
@abstractmethod
91-
def unpack_trial(self, trial: BaseTrial) -> pd.DataFrame:
92-
"""Unpacks the trial data into a DataFrame where each row is the result of a single evaluation of an arm."""
93-
...
94-
95-
def fetch_trial_data(self, trial: BaseTrial, **kwargs) -> Result[MapData, Exception]:
96-
# Unpack the trial data into a dataframe where each row is
97-
# the result of a single evaluation of an arm.
98-
df = self.unpack_trial(trial)
99-
100-
# Create a dataframe that includes the arm name, metric name, and trial index
101-
df["arm_name"] = [arm_name for arm_name in trial.arms_by_name.keys()]
102-
df["metric_name"] = self.name
103-
df["trial_index"] = trial.index
104-
105-
return Ok(value=MapData(df=df))
106-
107-
108-
class TiledMetric(BlopMetric):
109-
def __init__(self, tiled_client: Container, *args, **kwargs):
110-
super().__init__(*args, **kwargs)
111-
self._tiled_client = tiled_client
112-
# Need to save these so we can clone the metric easily
113-
self._args = args
114-
self._kwargs = kwargs
115-
116-
def unpack_trial(self, trial: BaseTrial) -> list[Any]:
117-
# TODO: Implement this
118-
# uid = trial.run_metadata["uid"]
119-
raise NotImplementedError("TiledMetric is not implemented yet.")
120-
121-
def clone(self) -> "TiledMetric":
122-
return self.__class__(self._tiled_client, *self._args, **self._kwargs)
123-
124-
125-
class DatabrokerMetric(BlopMetric):
126-
def __init__(self, broker: Broker, *args, **kwargs):
127-
super().__init__(*args, **kwargs)
128-
self._broker = broker
129-
# Need to save these so we can clone the metric easily
130-
self._args = args
131-
self._kwargs = kwargs
132-
133-
def unpack_trial(self, trial: BaseTrial) -> pd.DataFrame:
134-
"""Unpacks the trial using the databroker client.
135-
136-
Parameters
137-
----------
138-
trial: BaseTrial
139-
The trial to unpack.
140-
141-
Returns
142-
-------
143-
pd.DataFrame
144-
The trial data.
145-
"""
146-
uid = trial.run_metadata["uid"]
147-
# TODO: Why is [0] needed here?
148-
df: pd.DataFrame = self._broker[uid][0].table(fill=True)
149-
return df
150-
151-
def clone(self) -> "DatabrokerMetric":
152-
return self.__class__(self._broker, *self._args, **self._kwargs)
65+
for m in movables:
66+
if m.name in parameterization:
67+
unpacked.append(m)
68+
if isinstance(parameterization[m.name], float):
69+
unpacked.append([parameterization[m.name]])
70+
elif isinstance(parameterization[m.name], list):
71+
unpacked.append(parameterization[m.name])
72+
else:
73+
raise ValueError(f"Parameter {m.name} must be a float or list of floats.")
74+
else:
75+
raise ValueError(f"Parameter {m.name} not found in parameterization. Parameterization: {parameterization}")
76+
77+
# Run the plan
78+
uid = RE(plan_function(readables, *unpacked))
79+
80+
# Fetch the data
81+
results_df = db[uid][0].table(fill=True)
82+
83+
# Evaluate the data
84+
return evaluation_function(results_df)
85+
86+
return evaluate
Lines changed: 39 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,53 @@
1-
from ax import (
2-
ComparisonOp,
3-
Models,
4-
Objective,
5-
OptimizationConfig,
6-
OutcomeConstraint,
7-
ParameterType,
8-
RangeParameter,
9-
SearchSpace,
10-
)
1+
import pandas as pd
2+
from ax.service.ax_client import AxClient
3+
from ax.service.utils.instantiation import ObjectiveProperties
114

12-
from blop.integrations.ax import BlopExperiment, DatabrokerMetric
5+
from blop.integrations.ax import create_blop_experiment, create_bluesky_evaluator
136
from blop.sim import Beamline
147
from blop.utils import get_beam_stats
158

169

17-
def test_ax_experiment(RE, db):
10+
def test_ax_client_experiment(RE, db):
1811
beamline = Beamline(name="bl")
1912
beamline.det.noise.put(False)
2013

21-
parameters = [
22-
RangeParameter(name="bl_kbv_dsv", lower=-5.0, upper=5.0, parameter_type=ParameterType.FLOAT),
23-
RangeParameter(name="bl_kbv_usv", lower=-5.0, upper=5.0, parameter_type=ParameterType.FLOAT),
24-
RangeParameter(name="bl_kbh_dsh", lower=-5.0, upper=5.0, parameter_type=ParameterType.FLOAT),
25-
RangeParameter(name="bl_kbh_ush", lower=-5.0, upper=5.0, parameter_type=ParameterType.FLOAT),
26-
]
27-
search_space = SearchSpace(parameters=parameters)
28-
29-
image_sum_metric = DatabrokerMetric(
30-
broker=db,
31-
name="bl_det_sum",
32-
param_names=["bl_det_image"],
33-
compute_fn=lambda df: (df["bl_det_image"].sum().mean(), 0.0),
34-
)
35-
optimization_config = OptimizationConfig(
36-
objective=Objective(metric=image_sum_metric, minimize=False),
37-
outcome_constraints=[
38-
OutcomeConstraint(
39-
metric=DatabrokerMetric(
40-
broker=db,
41-
param_names=["bl_det_image"],
42-
name="bl_det_wid_x",
43-
compute_fn=lambda df: (get_beam_stats(df["bl_det_image"].iloc[0], threshold=0.0)["wid_x"], 0.0),
44-
),
45-
op=ComparisonOp.LEQ,
46-
bound=10,
47-
relative=False,
48-
),
49-
OutcomeConstraint(
50-
metric=DatabrokerMetric(
51-
broker=db,
52-
param_names=["bl_det_image"],
53-
name="bl_det_wid_y",
54-
compute_fn=lambda df: (get_beam_stats(df["bl_det_image"].iloc[0], threshold=0.0)["wid_y"], 0.0),
55-
),
56-
op=ComparisonOp.LEQ,
57-
bound=10,
58-
relative=False,
59-
),
14+
ax_client = AxClient()
15+
create_blop_experiment(
16+
ax_client,
17+
parameters=[
18+
{
19+
"movable": beamline.kbv_dsv,
20+
"type": "range",
21+
"bounds": [-5.0, 5.0],
22+
},
23+
{
24+
"movable": beamline.kbv_usv,
25+
"type": "range",
26+
"bounds": [-5.0, 5.0],
27+
},
28+
{
29+
"movable": beamline.kbh_dsh,
30+
"type": "range",
31+
"bounds": [-5.0, 5.0],
32+
},
33+
{
34+
"movable": beamline.kbh_ush,
35+
"type": "range",
36+
"bounds": [-5.0, 5.0],
37+
},
6038
],
39+
objectives={"beam_intensity": ObjectiveProperties(minimize=False), "beam_area": ObjectiveProperties(minimize=True)},
6140
)
6241

63-
readables = [beamline.det]
64-
movables = [beamline.kbv_dsv, beamline.kbv_usv, beamline.kbh_dsh, beamline.kbh_ush]
42+
def evaluate(results_df: pd.DataFrame) -> dict[str, tuple[float, float]]:
43+
stats = get_beam_stats(results_df["bl_det_image"].iloc[0])
44+
return {"beam_intensity": (stats["sum"], None), "beam_area": (stats["area"], None)}
6545

66-
experiment = BlopExperiment(
67-
RE=RE,
68-
readables=readables,
69-
movables=movables,
70-
name="test_ax_experiment",
71-
search_space=search_space,
72-
optimization_config=optimization_config,
46+
evaluator = create_bluesky_evaluator(
47+
RE, db, [beamline.det], [beamline.kbv_dsv, beamline.kbv_usv, beamline.kbh_dsh, beamline.kbh_ush], evaluate
7348
)
49+
for _ in range(25):
50+
parameterization, trial_index = ax_client.get_next_trial()
51+
ax_client.complete_trial(trial_index=trial_index, raw_data=evaluator(parameterization))
7452

75-
sobol = Models.SOBOL(experiment.search_space)
76-
for _ in range(5):
77-
trial = experiment.new_trial(generator_run=sobol.gen(1))
78-
# TODO: Try RE(trial.run())
79-
trial.run()
80-
trial.mark_completed()
81-
82-
best_arm = None
83-
for _ in range(5):
84-
gpei = Models.BOTORCH_MODULAR(experiment=experiment, data=experiment.fetch_data())
85-
generator_run = gpei.gen(1)
86-
best_arm, _ = generator_run.best_arm_predictions
87-
trial = experiment.new_trial(generator_run=generator_run)
88-
trial.run()
89-
trial.mark_completed()
90-
91-
experiment.fetch_data()
92-
assert best_arm is not None
53+
print(ax_client.generation_strategy.trials_as_df)

0 commit comments

Comments
 (0)