Skip to content

Commit 1238cde

Browse files
committed
fix(optim): move apply_updates in grad context
1 parent fa77c1a commit 1238cde

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

torchopt/optim/func/base.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,13 @@ def step(
9090
with torch.enable_grad():
9191
# Step parameters only
9292
grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
93-
94-
updates, self.optim_state = self.impl.update(
95-
grads,
96-
self.optim_state,
97-
params=params,
98-
inplace=inplace,
99-
)
100-
return apply_updates(params, updates, inplace=inplace)
93+
updates, self.optim_state = self.impl.update(
94+
grads,
95+
self.optim_state,
96+
params=params,
97+
inplace=inplace,
98+
)
99+
return apply_updates(params, updates, inplace=inplace)
101100

102101
def state_dict(self) -> OptState:
103102
"""Extract the references of the optimizer states.

torchopt/optim/meta/base.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,14 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals
7878

7979
with torch.enable_grad():
8080
grads = torch.autograd.grad(loss, flat_params, create_graph=True, allow_unused=True)
81+
updates, new_state = self.impl.update(
82+
grads,
83+
state,
84+
params=flat_params,
85+
inplace=False,
86+
)
87+
flat_new_params = apply_updates(flat_params, updates, inplace=False)
8188

82-
updates, new_state = self.impl.update(grads, state, params=flat_params, inplace=False)
83-
84-
flat_new_params = apply_updates(flat_params, updates, inplace=False)
8589
new_params: ModuleTensorContainers = pytree.tree_unflatten( # type: ignore[assignment]
8690
container_treespec,
8791
flat_new_params,

0 commit comments

Comments
 (0)