Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit fa77c1a

Browse files
committedMay 21, 2024·
chore(optim): wrap torch.autograd.grad() with torch.enable_grad() context
1 parent b3f570c commit fa77c1a

File tree

2 files changed

+13
-15
lines changed

2 files changed

+13
-15
lines changed
 

‎torchopt/optim/func/base.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,10 @@ def step(
8787
if inplace is None:
8888
inplace = self.inplace
8989

90-
# Step parameter only
91-
grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
90+
with torch.enable_grad():
91+
# Step parameters only
92+
grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
93+
9294
updates, self.optim_state = self.impl.update(
9395
grads,
9496
self.optim_state,

‎torchopt/optim/meta/base.py

+9-13
Original file line numberDiff line numberDiff line change
@@ -72,26 +72,22 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals
7272
):
7373
flat_params: TupleOfTensors
7474
flat_params, container_treespec = pytree.tree_flatten_as_tuple(param_container) # type: ignore[arg-type]
75+
7576
if isinstance(state, UninitializedState):
7677
state = self.impl.init(flat_params)
77-
grads = torch.autograd.grad(
78-
loss,
79-
flat_params,
80-
create_graph=True,
81-
allow_unused=True,
82-
)
83-
updates, new_state = self.impl.update(
84-
grads,
85-
state,
86-
params=flat_params,
87-
inplace=False,
88-
)
89-
self.state_groups[i] = new_state
78+
79+
with torch.enable_grad():
80+
grads = torch.autograd.grad(loss, flat_params, create_graph=True, allow_unused=True)
81+
82+
updates, new_state = self.impl.update(grads, state, params=flat_params, inplace=False)
83+
9084
flat_new_params = apply_updates(flat_params, updates, inplace=False)
9185
new_params: ModuleTensorContainers = pytree.tree_unflatten( # type: ignore[assignment]
9286
container_treespec,
9387
flat_new_params,
9488
)
89+
90+
self.state_groups[i] = new_state
9591
for container, new_param in zip(param_container, new_params):
9692
container.update(new_param)
9793

0 commit comments

Comments
 (0)
Please sign in to comment.