Skip to content

Commit e5ad72d

Browse files
authored
Merge branch 'master' into txt-files-fidelity-stopping-crit
2 parents 7d983fe + 8bd516e commit e5ad72d

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

neps/optimizers/utils/brackets.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -422,20 +422,24 @@ def __post_init__(self) -> None:
422422
def next(self) -> BracketAction:
423423
# Starting from the highest rung going down, check if any configs to promote
424424
for lower, upper in reversed(list(pairwise(self.rungs))):
425-
k = len(lower) // self.eta
426-
if k == 0:
427-
continue # Not enough configs to promote yet
425+
import copy
428426

429-
if self.is_multi_objective:
430-
best_k = lower.mo_selector(selector=self.mo_selector, k=k)
431-
else:
432-
best_k = lower.top_k(k)
433-
candidates = best_k.drop(
427+
lower_dropped = copy.deepcopy(lower)
428+
lower_dropped.table = lower_dropped.table.drop(
434429
upper.config_ids,
435430
axis="index",
436431
level="id",
437432
errors="ignore",
438433
)
434+
k = len(lower_dropped) // self.eta
435+
if k == 0:
436+
continue # Not enough configs to promote yet
437+
438+
if self.is_multi_objective:
439+
best_k = lower_dropped.mo_selector(selector=self.mo_selector, k=k)
440+
else:
441+
best_k = lower_dropped.top_k(k)
442+
candidates = best_k.copy(deep=True)
439443
if candidates.empty:
440444
continue # No configs that aren't already promoted
441445

@@ -616,11 +620,11 @@ def priority(x: BracketAction) -> tuple[int, int]:
616620
case PromoteAction(new_rung=new_rung):
617621
return 0, new_rung
618622
case SampleAction(sample_at_rung):
619-
return 0, sample_at_rung
623+
return 1, sample_at_rung
620624
case "pending":
621-
return 1, 0
622-
case "done":
623625
return 2, 0
626+
case "done":
627+
return 3, 0
624628
case _:
625629
raise RuntimeError("This is a bug!")
626630

neps/space/domain.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -282,14 +282,10 @@ def from_unit(self, x: Tensor, *, dtype: torch.dtype | None = None) -> Tensor:
282282
x = torch.round(x)
283283

284284
if (x > upper).any():
285-
import warnings
286-
287-
warnings.warn( # noqa: B028
288-
"Decoded value is above the upper bound of the domain. "
289-
"Clipping to the upper bound. "
290-
"This is likely due floating point precision in `torch.exp(x)` "
291-
"with torch.float64."
292-
)
285+
# Decoded value is above the upper bound of the domain.
286+
# Clipping to the upper bound.
287+
# This is likely due to floating point precision in `torch.exp(x)`
288+
# with torch.float64.
293289
x = torch.clip(x, max=self.upper)
294290

295291
return x.type(dtype)

0 commit comments

Comments
 (0)