Skip to content

Commit ed80862

Browse files
authored
feat: Change Worker ID (#235)
2 parents c663171 + 07f1053 commit ed80862

File tree

7 files changed

+165
-9
lines changed

7 files changed

+165
-9
lines changed

docs/getting_started.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pip install neural-pipeline-search
1313

1414
## The 3 Main Components
1515

16-
1. **Establish a [`pipeline_space=`](reference/pipeline_space.md)**:
16+
1. **Establish a [`pipeline_space`](reference/pipeline_space.md)**:
1717

1818
```python
1919
pipeline_space={

neps/api.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def run( # noqa: C901, D417, PLR0913
4646
objective_value_on_error: float | None = None,
4747
cost_value_on_error: float | None = None,
4848
sample_batch_size: int | None = None,
49+
worker_id: str | None = None,
4950
optimizer: (
5051
OptimizerChoice
5152
| Mapping[str, Any]
@@ -249,6 +250,22 @@ def evaluate_pipeline(some_parameter: float) -> float:
249250
evaluations, even if they were to come in relatively
250251
quickly.
251252
253+
worker_id: An optional string to identify the worker (run instance).
254+
If not provided, a `worker_id` will be automatically generated using the pattern:
255+
`worker_<N>`, where `<N>` is a unique integer for each worker and increments with each new worker.
256+
A list of all workers created so far is stored in
257+
`root_directory/optimizer_state.pkl` under the attribute `worker_ids`.
258+
259+
??? tip "Why specify a `worker_id`?"
260+
Specifying a `worker_id` is useful for tracking which worker performed specific tasks
261+
in the results. For example, when debugging or running on a cluster, you can include
262+
the process ID and machine name in the `worker_id` for better traceability.
263+
264+
??? warning "Duplication of `worker_id`"
265+
Ensure that each worker has a unique `worker_id`. If a duplicate `worker_id` is detected,
266+
the optimization process will be stopped with an error to prevent overwriting the results
267+
of other workers.
268+
252269
optimizer: Which optimizer to use.
253270
254271
Not sure which to use? Leave this at `"auto"` and neps will
@@ -513,6 +530,7 @@ def __call__(
513530
overwrite_optimization_dir=overwrite_root_directory,
514531
sample_batch_size=sample_batch_size,
515532
write_summary_to_disk=write_summary_to_disk,
533+
worker_id=worker_id,
516534
)
517535

518536
post_run_csv(root_directory)

neps/runtime.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import datetime
65
import logging
76
import os
87
import shutil
@@ -52,11 +51,6 @@
5251
logger = logging.getLogger(__name__)
5352

5453

55-
def _default_worker_name() -> str:
56-
isoformat = datetime.datetime.now(datetime.timezone.utc).isoformat()
57-
return f"{os.getpid()}-{isoformat}"
58-
59-
6054
_DDP_ENV_VAR_NAME = "NEPS_DDP_TRIAL_ID"
6155

6256

@@ -197,12 +191,13 @@ def new(
197191
worker_id: str | None = None,
198192
) -> DefaultWorker:
199193
"""Create a new worker."""
194+
worker_id = state.lock_and_set_new_worker_id(worker_id)
200195
return DefaultWorker(
201196
state=state,
202197
optimizer=optimizer,
203198
settings=settings,
204199
evaluation_fn=evaluation_fn,
205-
worker_id=worker_id if worker_id is not None else _default_worker_name(),
200+
worker_id=worker_id,
206201
)
207202

208203
def _check_worker_local_settings(
@@ -882,6 +877,7 @@ def _launch_runtime( # noqa: PLR0913
882877
max_evaluations_for_worker: int | None,
883878
sample_batch_size: int | None,
884879
write_summary_to_disk: bool = True,
880+
worker_id: str | None = None,
885881
) -> None:
886882
default_report_values = DefaultReportValues(
887883
objective_value_on_error=objective_value_on_error,
@@ -926,6 +922,7 @@ def _launch_runtime( # noqa: PLR0913
926922
)
927923
),
928924
shared_state=None, # TODO: Unused for the time being...
925+
worker_ids=None,
929926
),
930927
)
931928
break
@@ -990,5 +987,6 @@ def _launch_runtime( # noqa: PLR0913
990987
optimizer=optimizer,
991988
evaluation_fn=evaluation_fn,
992989
settings=settings,
990+
worker_id=worker_id,
993991
)
994992
worker.run()

neps/state/neps_state.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,45 @@ class NePSState:
257257
all_best_configs: list = field(default_factory=list)
258258
"""Trajectory to the newest incbumbent"""
259259

260+
def lock_and_set_new_worker_id(self, worker_id: str | None = None) -> str:
261+
"""Acquire the state lock and set a new worker id in the optimizer state.
262+
263+
Args:
264+
worker_id: The worker id to set. If `None`, a new worker id will be generated.
265+
266+
Returns:
267+
The worker id that was set.
268+
269+
Raises:
270+
NePSError: If the worker id already exists.
271+
"""
272+
with self._optimizer_lock.lock():
273+
with self._optimizer_state_path.open("rb") as f:
274+
opt_state: OptimizationState = pickle.load(f) # noqa: S301
275+
assert isinstance(opt_state, OptimizationState)
276+
worker_id = (
277+
worker_id
278+
if worker_id is not None
279+
else _get_worker_name(
280+
len(opt_state.worker_ids)
281+
if opt_state.worker_ids is not None
282+
else 0
283+
)
284+
)
285+
if opt_state.worker_ids and worker_id in opt_state.worker_ids:
286+
raise NePSError(
287+
f"Worker id '{worker_id}' already exists, \
288+
reserved worker ids: {opt_state.worker_ids}"
289+
)
290+
if opt_state.worker_ids is None:
291+
opt_state.worker_ids = []
292+
293+
opt_state.worker_ids.append(worker_id)
294+
bytes_ = pickle.dumps(opt_state, protocol=pickle.HIGHEST_PROTOCOL)
295+
with atomic_write(self._optimizer_state_path, "wb") as f:
296+
f.write(bytes_)
297+
return worker_id
298+
260299
def lock_and_read_trials(self) -> dict[str, Trial]:
261300
"""Acquire the state lock and read the trials."""
262301
with self._trial_lock.lock():
@@ -683,3 +722,7 @@ def _deserialize_optimizer_info(path: Path) -> OptimizerInfo:
683722
f" {path}. Expected a `dict` or `None`."
684723
)
685724
return OptimizerInfo(name=name, info=info or {})
725+
726+
727+
def _get_worker_name(idx: int) -> str:
728+
return f"worker_{idx}"

neps/state/optimizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,5 @@ class OptimizationState:
4747
Please reach out to @eddiebergman if you have a use case for this so we can make
4848
it more robust.
4949
"""
50+
worker_ids: list[str] | None = None
51+
"""The list of workers that have been created so far."""

neps_examples/basic_usage/hyperparameters.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import logging
22
import numpy as np
33
import neps
4-
4+
import socket
5+
import os
56
# This example demonstrates how to use NePS to optimize hyperparameters
67
# of a pipeline. The pipeline is a simple function that takes in
78
# five hyperparameters and returns their sum.
@@ -28,4 +29,5 @@ def evaluate_pipeline(float1, float2, categorical, integer1, integer2):
2829
pipeline_space=pipeline_space,
2930
root_directory="results/hyperparameters_example",
3031
evaluations_to_spend=30,
32+
worker_id=f"worker_1-{socket.gethostname()}-{os.getpid()}",
3133
)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
5+
import pytest
6+
7+
from neps.optimizers import OptimizerInfo
8+
from neps.optimizers.algorithms import random_search
9+
from neps.runtime import (
10+
DefaultReportValues,
11+
DefaultWorker,
12+
OnErrorPossibilities,
13+
WorkerSettings,
14+
)
15+
from neps.space import Float, SearchSpace
16+
from neps.state import NePSState, OptimizationState, SeedSnapshot
17+
18+
19+
@pytest.fixture
20+
def neps_state(tmp_path: Path) -> NePSState:
21+
return NePSState.create_or_load(
22+
path=tmp_path / "neps_state",
23+
optimizer_info=OptimizerInfo(name="blah", info={"nothing": "here"}),
24+
optimizer_state=OptimizationState(
25+
budget=None, seed_snapshot=SeedSnapshot.new_capture(), shared_state={}
26+
),
27+
)
28+
29+
30+
def test_create_worker_manual_id(neps_state: NePSState) -> None:
31+
settings = WorkerSettings(
32+
on_error=OnErrorPossibilities.IGNORE,
33+
default_report_values=DefaultReportValues(),
34+
evaluations_to_spend=1,
35+
include_in_progress_evaluations_towards_maximum=True,
36+
cost_to_spend=None,
37+
max_evaluations_for_worker=None,
38+
max_evaluation_time_total_seconds=None,
39+
max_wallclock_time_for_worker_seconds=None,
40+
max_evaluation_time_for_worker_seconds=None,
41+
max_cost_for_worker=None,
42+
batch_size=None,
43+
fidelities_to_spend=None,
44+
)
45+
46+
def eval_fn(config: dict) -> float:
47+
return 1.0
48+
49+
test_worker_id = "my_worker_123"
50+
optimizer = random_search(SearchSpace({"a": Float(0, 1)}))
51+
52+
worker = DefaultWorker.new(
53+
state=neps_state,
54+
settings=settings,
55+
optimizer=optimizer,
56+
evaluation_fn=eval_fn,
57+
worker_id=test_worker_id,
58+
)
59+
60+
assert worker.worker_id == test_worker_id
61+
assert neps_state.lock_and_get_optimizer_state().worker_ids == [test_worker_id]
62+
63+
64+
def test_create_worker_auto_id(neps_state: NePSState) -> None:
65+
settings = WorkerSettings(
66+
on_error=OnErrorPossibilities.IGNORE,
67+
default_report_values=DefaultReportValues(),
68+
evaluations_to_spend=1,
69+
include_in_progress_evaluations_towards_maximum=True,
70+
cost_to_spend=None,
71+
max_evaluations_for_worker=None,
72+
max_evaluation_time_total_seconds=None,
73+
max_wallclock_time_for_worker_seconds=None,
74+
max_evaluation_time_for_worker_seconds=None,
75+
max_cost_for_worker=None,
76+
batch_size=None,
77+
fidelities_to_spend=None,
78+
)
79+
80+
def eval_fn(config: dict) -> float:
81+
return 1.0
82+
83+
optimizer = random_search(SearchSpace({"a": Float(0, 1)}))
84+
85+
worker = DefaultWorker.new(
86+
state=neps_state,
87+
settings=settings,
88+
optimizer=optimizer,
89+
evaluation_fn=eval_fn,
90+
)
91+
92+
assert worker.worker_id == "worker_0"
93+
assert neps_state.lock_and_get_optimizer_state().worker_ids == [worker.worker_id]

0 commit comments

Comments
 (0)