From 6c5304ab12a0a45cfc3378450e470608d5746b75 Mon Sep 17 00:00:00 2001
From: Xuehai Pan
Date: Tue, 21 May 2024 14:32:52 +0000
Subject: [PATCH 1/2] chore(optim): wrap `torch.autograd.grad()` with
 `torch.enable_grad()` context

---
 torchopt/optim/func/base.py |  6 ++++--
 torchopt/optim/meta/base.py | 24 ++++++++++--------------
 2 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py
index 7bb27877..fede4c2c 100644
--- a/torchopt/optim/func/base.py
+++ b/torchopt/optim/func/base.py
@@ -87,8 +87,10 @@ def step(
         if inplace is None:
             inplace = self.inplace
 
-        # Step parameter only
-        grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
+        with torch.enable_grad():
+            # Step parameters only
+            grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
+
         updates, self.optim_state = self.impl.update(
             grads,
             self.optim_state,

diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py
index 73ecdde7..331431e9 100644
--- a/torchopt/optim/meta/base.py
+++ b/torchopt/optim/meta/base.py
@@ -66,32 +66,28 @@ def step(self, loss: torch.Tensor) -> None:  # pylint: disable=too-many-locals
             loss (torch.Tensor): The loss that is used to compute the gradients to the network
                 parameters.
         """
-        # Step parameter only
         for i, (param_container, state) in enumerate(
             zip(self.param_containers_groups, self.state_groups),
         ):
             flat_params: TupleOfTensors
             flat_params, container_treespec = pytree.tree_flatten_as_tuple(param_container)  # type: ignore[arg-type]
+
             if isinstance(state, UninitializedState):
                 state = self.impl.init(flat_params)
-            grads = torch.autograd.grad(
-                loss,
-                flat_params,
-                create_graph=True,
-                allow_unused=True,
-            )
-            updates, new_state = self.impl.update(
-                grads,
-                state,
-                params=flat_params,
-                inplace=False,
-            )
-            self.state_groups[i] = new_state
+
+            with torch.enable_grad():
+                # Step parameters only
+                grads = torch.autograd.grad(loss, flat_params, create_graph=True, allow_unused=True)
+
+            updates, new_state = self.impl.update(grads, state, params=flat_params, inplace=False)
+
             flat_new_params = apply_updates(flat_params, updates, inplace=False)
             new_params: ModuleTensorContainers = pytree.tree_unflatten(  # type: ignore[assignment]
                 container_treespec,
                 flat_new_params,
             )
+
+            self.state_groups[i] = new_state
             for container, new_param in zip(param_container, new_params):
                 container.update(new_param)
 

From 5f9b9181d064439c557ff72039aa014baeaeff70 Mon Sep 17 00:00:00 2001
From: Xuehai Pan
Date: Tue, 21 May 2024 18:04:51 +0000
Subject: [PATCH 2/2] fix(optim): move `apply_updates` into grad context

---
 torchopt/optim/func/base.py | 15 +++++++--------
 torchopt/optim/meta/base.py | 10 +++++++---
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/torchopt/optim/func/base.py b/torchopt/optim/func/base.py
index fede4c2c..7d2b4224 100644
--- a/torchopt/optim/func/base.py
+++ b/torchopt/optim/func/base.py
@@ -90,14 +90,13 @@ def step(
         with torch.enable_grad():
             # Step parameters only
             grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
-
-        updates, self.optim_state = self.impl.update(
-            grads,
-            self.optim_state,
-            params=params,
-            inplace=inplace,
-        )
-        return apply_updates(params, updates, inplace=inplace)
+            updates, self.optim_state = self.impl.update(
+                grads,
+                self.optim_state,
+                params=params,
+                inplace=inplace,
+            )
+            return apply_updates(params, updates, inplace=inplace)
 
     def state_dict(self) -> OptState:
         """Extract the references of the optimizer states.
diff --git a/torchopt/optim/meta/base.py b/torchopt/optim/meta/base.py
index 331431e9..e23d51d6 100644
--- a/torchopt/optim/meta/base.py
+++ b/torchopt/optim/meta/base.py
@@ -78,10 +78,14 @@ def step(self, loss: torch.Tensor) -> None:  # pylint: disable=too-many-locals
             with torch.enable_grad():
                 # Step parameters only
                 grads = torch.autograd.grad(loss, flat_params, create_graph=True, allow_unused=True)
+                updates, new_state = self.impl.update(
+                    grads,
+                    state,
+                    params=flat_params,
+                    inplace=False,
+                )
+                flat_new_params = apply_updates(flat_params, updates, inplace=False)
 
-            updates, new_state = self.impl.update(grads, state, params=flat_params, inplace=False)
-
-            flat_new_params = apply_updates(flat_params, updates, inplace=False)
             new_params: ModuleTensorContainers = pytree.tree_unflatten(  # type: ignore[assignment]
                 container_treespec,
                 flat_new_params,
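Taken together, the two patches make `step()` compute gradients, transform
them into updates, and apply them all inside a `torch.enable_grad()` block,
so a differentiable optimization step keeps its higher-order autograd graph
even when the caller has globally disabled gradient tracking. Below is a
minimal sketch of that scenario, assuming torchopt's public `MetaSGD` API;
the `torch.no_grad()` call site and the toy model are illustrations, not
taken from the patches:

    import torch
    import torchopt

    # Inner model whose parameters are updated differentiably.
    net = torch.nn.Linear(4, 1)
    optim = torchopt.MetaSGD(net, lr=0.1)

    x = torch.randn(8, 4)
    loss = net(x).square().mean()  # graph built with grad mode on

    # Before these patches, stepping under `torch.no_grad()` could leave
    # the new parameters detached, since the update transform and
    # `apply_updates` ran with grad mode off; after them, `step()` forces
    # grad mode on internally, so the surrounding mode no longer matters.
    with torch.no_grad():
        optim.step(loss)

    # The updated weight is a non-leaf tensor carrying a grad_fn, i.e. it
    # is still connected to the graph for higher-order differentiation.
    assert net.weight.grad_fn is not None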