This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit c1973e1

Fix optimizer freeze support in HMC

1 parent: 90974f4

3 files changed: +75 -4 lines

fortuna/prob_model/posterior/sgmcmc/hmc/hmc_trainer.py

Lines changed: 11 additions & 3 deletions
@@ -24,6 +24,11 @@
     Array,
     Batch,
 )
+from fortuna.utils.freeze import (
+    has_multiple_opt_state,
+    get_trainable_opt_state,
+    update_trainable_opt_state,
+)


 class HMCTrainer(MAPTrainer):
@@ -46,9 +51,12 @@ def training_step(
             unravel=unravel,
             **kwargs,
         )
-        state = state.replace(
-            opt_state=state.opt_state._replace(log_prob=aux["loss"]),
-        )
+        if has_multiple_opt_state(state):
+            opt_state = get_trainable_opt_state(state)._replace(log_prob=aux["loss"])
+            state = update_trainable_opt_state(state, opt_state)
+        else:
+            opt_state = state.opt_state._replace(log_prob=aux["loss"])
+            state = state.replace(opt_state=opt_state)
         return state, aux

     def __str__(self):
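
Context for the branch above: when parameter freezing is enabled, freeze_optimizer builds the optimizer via optax.multi_transform (see fortuna/utils/freeze.py below), so state.opt_state is a MultiTransformState rather than the bare HMC integrator state, and the old state.opt_state._replace(log_prob=...) call would fail. The following is a minimal sketch of the two shapes, not Fortuna's actual classes: HMCState is a hypothetical stand-in for the integrator state, assumed (from the _replace(log_prob=...) call in the diff) to be a NamedTuple carrying a log_prob field.

# Minimal sketch (hypothetical, not Fortuna's actual classes) of the two
# opt_state shapes handled by training_step above.
from typing import NamedTuple

import optax
from optax._src.wrappers import MaskedState


class HMCState(NamedTuple):
    # Stand-in for the HMC integrator state; assumed to expose a log_prob field.
    log_prob: float


# No freezing: opt_state is the integrator state itself, so _replace works directly.
plain = HMCState(log_prob=0.0)
plain = plain._replace(log_prob=-1.23)

# With freezing: multi_transform wraps each partition's state in a MaskedState,
# so the integrator state (and its log_prob field) sits one level deeper.
frozen = optax.MultiTransformState(
    inner_states={
        "trainable": MaskedState(inner_state=HMCState(log_prob=0.0)),
        "frozen": MaskedState(inner_state=optax.EmptyState()),
    }
)
inner = frozen.inner_states["trainable"].inner_state._replace(log_prob=-1.23)
# ...which is exactly the access pattern used by get_trainable_opt_state and
# update_trainable_opt_state below.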

fortuna/utils/freeze.py

Lines changed: 63 additions & 0 deletions
@@ -18,7 +18,9 @@
 from optax import (
     multi_transform,
     set_to_zero,
+    MultiTransformState,
 )
+from optax._src.wrappers import MaskedState

 from fortuna.typing import (
     AnyKey,
@@ -27,6 +29,8 @@
     Params,
 )

+from fortuna.prob_model.posterior.state import PosteriorState
+

 def all_values_in_labels(values: Iterable, labels: Any) -> None:
     """
@@ -81,6 +85,65 @@ def freeze_optimizer(
     return multi_transform(partition_optimizers, partition_params)


+def has_multiple_opt_state(state: PosteriorState):
+    """
+    Check whether a given posterior state contains multiple optimizer states.
+
+    Parameters
+    ----------
+    state: PosteriorState
+        An instance of `PosteriorState`.
+
+    Returns
+    -------
+    bool
+    """
+    return isinstance(state.opt_state, MultiTransformState)
+
+
+def get_trainable_opt_state(state: PosteriorState):
+    """
+    Get the trainable optimizer state.
+
+    Parameters
+    ----------
+    state: PosteriorState
+        An instance of `PosteriorState`.
+
+    Returns
+    -------
+    opt_state: Any
+        An instance of the trainable optimizer state.
+    """
+    return state.opt_state.inner_states["trainable"].inner_state
+
+
+def update_trainable_opt_state(state: PosteriorState, opt_state: Any):
+    """
+    Update the trainable optimizer state.
+
+    Parameters
+    ----------
+    state: PosteriorState
+        An instance of `PosteriorState`.
+    opt_state: Any
+        An instance of the trainable optimizer state.
+
+    Returns
+    -------
+    PosteriorState
+        An updated posterior state.
+    """
+    trainable_state = MaskedState(inner_state=opt_state)
+    opt_state = MultiTransformState(
+        inner_states={
+            k: (trainable_state if k == "trainable" else v)
+            for k, v in state.opt_state.inner_states.items()
+        }
+    )
+    return state.replace(opt_state=opt_state)
+
+
 def get_trainable_paths(
     params: Params,
     freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]],
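
A hedged usage sketch of the structure these helpers navigate: the partition labels ("trainable", "frozen") mirror the keys used in update_trainable_opt_state, while the parameter tree and the inner optimizers below are placeholders rather than Fortuna's actual setup (in Fortuna the trainable partition would carry the HMC integrator state, not SGD).

# Placeholder parameters and optimizers, only to show the MultiTransformState
# layout that the new helpers assume.
import jax.numpy as jnp
import optax

params = {"backbone": jnp.zeros(3), "head": jnp.zeros(2)}
labels = {"backbone": "frozen", "head": "trainable"}

tx = optax.multi_transform(
    {"trainable": optax.sgd(1e-2), "frozen": optax.set_to_zero()},
    labels,
)
opt_state = tx.init(params)

# The frozen-aware optimizer's state is a MultiTransformState, which is what
# has_multiple_opt_state checks for...
assert isinstance(opt_state, optax.MultiTransformState)

# ...and the trainable partition is reached exactly as in get_trainable_opt_state.
trainable_inner = opt_state.inner_states["trainable"].inner_state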

tests/fortuna/prob_model/test_train.py

Lines changed: 1 addition & 1 deletion
@@ -386,7 +386,7 @@ def dryrun_task(task, method):
         )
         state = (
             prob_model.posterior.state.get()
-            if method not in ["deep_ensemble", "sghmc", "cyclical_sgld"]
+            if method not in ["deep_ensemble", "sghmc", "cyclical_sgld", "hmc"]
             else prob_model.posterior.state.get(-1)
         )
         model_editor_params = state.params["model_editor"]["params"].unfreeze()
