|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from pathlib import Path |
| 4 | + |
| 5 | +import pytest |
| 6 | + |
| 7 | +from neps.optimizers import OptimizerInfo |
| 8 | +from neps.optimizers.algorithms import random_search |
| 9 | +from neps.runtime import ( |
| 10 | + DefaultReportValues, |
| 11 | + DefaultWorker, |
| 12 | + OnErrorPossibilities, |
| 13 | + WorkerSettings, |
| 14 | +) |
| 15 | +from neps.space import Float, SearchSpace |
| 16 | +from neps.state import NePSState, OptimizationState, SeedSnapshot |
| 17 | + |
| 18 | + |
| 19 | +@pytest.fixture |
| 20 | +def neps_state(tmp_path: Path) -> NePSState: |
| 21 | + return NePSState.create_or_load( |
| 22 | + path=tmp_path / "neps_state", |
| 23 | + optimizer_info=OptimizerInfo(name="blah", info={"nothing": "here"}), |
| 24 | + optimizer_state=OptimizationState( |
| 25 | + budget=None, seed_snapshot=SeedSnapshot.new_capture(), shared_state={} |
| 26 | + ), |
| 27 | + ) |
| 28 | + |
| 29 | + |
| 30 | +def test_create_worker_manual_id(neps_state: NePSState) -> None: |
| 31 | + settings = WorkerSettings( |
| 32 | + on_error=OnErrorPossibilities.IGNORE, |
| 33 | + default_report_values=DefaultReportValues(), |
| 34 | + evaluations_to_spend=1, |
| 35 | + include_in_progress_evaluations_towards_maximum=True, |
| 36 | + cost_to_spend=None, |
| 37 | + max_evaluations_for_worker=None, |
| 38 | + max_evaluation_time_total_seconds=None, |
| 39 | + max_wallclock_time_for_worker_seconds=None, |
| 40 | + max_evaluation_time_for_worker_seconds=None, |
| 41 | + max_cost_for_worker=None, |
| 42 | + batch_size=None, |
| 43 | + fidelities_to_spend=None, |
| 44 | + ) |
| 45 | + |
| 46 | + def eval_fn(config: dict) -> float: |
| 47 | + return 1.0 |
| 48 | + |
| 49 | + test_worker_id = "my_worker_123" |
| 50 | + optimizer = random_search(SearchSpace({"a": Float(0, 1)})) |
| 51 | + |
| 52 | + worker = DefaultWorker.new( |
| 53 | + state=neps_state, |
| 54 | + settings=settings, |
| 55 | + optimizer=optimizer, |
| 56 | + evaluation_fn=eval_fn, |
| 57 | + worker_id=test_worker_id, |
| 58 | + ) |
| 59 | + |
| 60 | + assert worker.worker_id == test_worker_id |
| 61 | + assert neps_state.lock_and_get_optimizer_state().worker_ids == [test_worker_id] |
| 62 | + |
| 63 | + |
| 64 | +def test_create_worker_auto_id(neps_state: NePSState) -> None: |
| 65 | + settings = WorkerSettings( |
| 66 | + on_error=OnErrorPossibilities.IGNORE, |
| 67 | + default_report_values=DefaultReportValues(), |
| 68 | + evaluations_to_spend=1, |
| 69 | + include_in_progress_evaluations_towards_maximum=True, |
| 70 | + cost_to_spend=None, |
| 71 | + max_evaluations_for_worker=None, |
| 72 | + max_evaluation_time_total_seconds=None, |
| 73 | + max_wallclock_time_for_worker_seconds=None, |
| 74 | + max_evaluation_time_for_worker_seconds=None, |
| 75 | + max_cost_for_worker=None, |
| 76 | + batch_size=None, |
| 77 | + fidelities_to_spend=None, |
| 78 | + ) |
| 79 | + |
| 80 | + def eval_fn(config: dict) -> float: |
| 81 | + return 1.0 |
| 82 | + |
| 83 | + optimizer = random_search(SearchSpace({"a": Float(0, 1)})) |
| 84 | + |
| 85 | + worker = DefaultWorker.new( |
| 86 | + state=neps_state, |
| 87 | + settings=settings, |
| 88 | + optimizer=optimizer, |
| 89 | + evaluation_fn=eval_fn, |
| 90 | + ) |
| 91 | + |
| 92 | + assert worker.worker_id == "worker_0" |
| 93 | + assert neps_state.lock_and_get_optimizer_state().worker_ids == [worker.worker_id] |
0 commit comments