diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py
index 7bb27877..7d2b4224 100644
--- a/torchopt/optim/func/base.py
+++ b/torchopt/optim/func/base.py
@@ -87,15 +87,16 @@ def step(
         if inplace is None:
             inplace = self.inplace

-        # Step parameter only
-        grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
-        updates, self.optim_state = self.impl.update(
-            grads,
-            self.optim_state,
-            params=params,
-            inplace=inplace,
-        )
-        return apply_updates(params, updates, inplace=inplace)
+        with torch.enable_grad():
+            # Step parameters only
+            grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
+            updates, self.optim_state = self.impl.update(
+                grads,
+                self.optim_state,
+                params=params,
+                inplace=inplace,
+            )
+            return apply_updates(params, updates, inplace=inplace)

     def state_dict(self) -> OptState:
         """Extract the references of the optimizer states.
diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py
index 73ecdde7..e23d51d6 100644
--- a/torchopt/optim/meta/base.py
+++ b/torchopt/optim/meta/base.py
@@ -66,32 +66,32 @@ def step(self, loss: torch.Tensor) -> None:  # pylint: disable=too-many-locals
             loss (torch.Tensor): The loss that is used to compute the gradients to the network
                 parameters.
         """
-        # Step parameter only
         for i, (param_container, state) in enumerate(
             zip(self.param_containers_groups, self.state_groups),
         ):
             flat_params: TupleOfTensors
             flat_params, container_treespec = pytree.tree_flatten_as_tuple(param_container)  # type: ignore[arg-type]
+
             if isinstance(state, UninitializedState):
                 state = self.impl.init(flat_params)
-            grads = torch.autograd.grad(
-                loss,
-                flat_params,
-                create_graph=True,
-                allow_unused=True,
-            )
-            updates, new_state = self.impl.update(
-                grads,
-                state,
-                params=flat_params,
-                inplace=False,
-            )
-            self.state_groups[i] = new_state
-            flat_new_params = apply_updates(flat_params, updates, inplace=False)
+
+            with torch.enable_grad():
+                # Step parameters only
+                grads = torch.autograd.grad(loss, flat_params, create_graph=True, allow_unused=True)
+                updates, new_state = self.impl.update(
+                    grads,
+                    state,
+                    params=flat_params,
+                    inplace=False,
+                )
+                flat_new_params = apply_updates(flat_params, updates, inplace=False)
+
             new_params: ModuleTensorContainers = pytree.tree_unflatten(  # type: ignore[assignment]
                 container_treespec,
                 flat_new_params,
             )
+
+            self.state_groups[i] = new_state
             for container, new_param in zip(param_container, new_params):
                 container.update(new_param)
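
For context, a minimal usage sketch (not part of the patch) of why the added torch.enable_grad() context matters: torch.autograd.grad(..., create_graph=True) and the subsequent differentiable update only record an autograd graph while grad mode is on, so without the wrapper a step() issued from inside a torch.no_grad() block would silently produce non-differentiable parameters. The toy network, learning rate, and the final check below are illustrative assumptions; MetaAdam and step(loss) are existing TorchOpt APIs.

# Illustrative sketch only (not part of the patch). Assumes a toy linear model.
import torch
import torch.nn as nn
import torchopt

net = nn.Linear(4, 1)
optimizer = torchopt.MetaAdam(net, lr=0.1)

x = torch.randn(8, 4)
loss = net(x).square().mean()  # forward pass runs with grad mode enabled

# Even if the caller has disabled gradients (e.g. inside an evaluation loop),
# the torch.enable_grad() context added by this patch keeps the inner update
# differentiable, so higher-order gradients remain available.
with torch.no_grad():
    optimizer.step(loss)

# The updated parameters are expected to stay attached to the autograd graph.
print(all(p.grad_fn is not None for p in net.parameters()))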