Skip to content

Commit b0cbdcd

Browse files
fix: [pre-commit.ci] auto fixes [...]
1 parent 42e54aa commit b0cbdcd

5 files changed

+49
-24
lines changed

torchopt/nn/stateless.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def reparametrize(
8484
module: nn.Module,
8585
named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]],
8686
allow_missing: bool = False,
87-
) -> Generator[nn.Module, None, None]:
87+
) -> Generator[nn.Module]:
8888
"""Reparameterize the module parameters and/or buffers."""
8989
if not isinstance(named_tensors, dict):
9090
named_tensors = dict(named_tensors)

tutorials/2_Visualization.ipynb

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"cells": [
2+
"cells": [
33
{
44
"cell_type": "markdown",
55
"metadata": {},
@@ -181,8 +181,9 @@
181181
"# Draw computation graph\n",
182182
"display(\n",
183183
" torchopt.visual.make_dot(\n",
184-
" loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}]\n",
185-
" )\n",
184+
" loss,\n",
185+
" [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}],\n",
186+
" ),\n",
186187
")"
187188
]
188189
}

tutorials/3_Meta_Optimizer.ipynb

+31-13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"cells": [
2+
"cells": [
33
{
44
"cell_type": "markdown",
55
"metadata": {},
@@ -200,8 +200,9 @@
200200
"outer_loss = F.mse_loss(net(x), y)\n",
201201
"display(\n",
202202
" torchopt.visual.make_dot(\n",
203-
" outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n",
204-
" )\n",
203+
" outer_loss,\n",
204+
" params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n",
205+
" ),\n",
205206
")"
206207
]
207208
},
@@ -247,8 +248,9 @@
247248
"outer_loss = F.mse_loss(net(x), y)\n",
248249
"display(\n",
249250
" torchopt.visual.make_dot(\n",
250-
" outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n",
251-
" )\n",
251+
" outer_loss,\n",
252+
" params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n",
253+
" ),\n",
252254
")"
253255
]
254256
},
@@ -513,21 +515,30 @@
513515
"source": [
514516
"functional_adam = torchopt.adam(\n",
515517
" lr=torchopt.schedule.linear_schedule(\n",
516-
" init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n",
517-
" )\n",
518+
" init_value=1e-3,\n",
519+
" end_value=1e-4,\n",
520+
" transition_steps=10000,\n",
521+
" transition_begin=2000,\n",
522+
" ),\n",
518523
")\n",
519524
"\n",
520525
"adam = torchopt.Adam(\n",
521526
" net.parameters(),\n",
522527
" lr=torchopt.schedule.linear_schedule(\n",
523-
" init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n",
528+
" init_value=1e-3,\n",
529+
" end_value=1e-4,\n",
530+
" transition_steps=10000,\n",
531+
" transition_begin=2000,\n",
524532
" ),\n",
525533
")\n",
526534
"\n",
527535
"meta_adam = torchopt.MetaAdam(\n",
528536
" net,\n",
529537
" lr=torchopt.schedule.linear_schedule(\n",
530-
" init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n",
538+
" init_value=1e-3,\n",
539+
" end_value=1e-4,\n",
540+
" transition_steps=10000,\n",
541+
" transition_begin=2000,\n",
531542
" ),\n",
532543
")"
533544
]
@@ -610,19 +621,26 @@
610621
"optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True, use_accelerated_op=True)\n",
611622
"\n",
612623
"net_state_0 = torchopt.extract_state_dict(\n",
613-
" net, by='reference', enable_visual=True, visual_prefix='step0.'\n",
624+
" net,\n",
625+
" by='reference',\n",
626+
" enable_visual=True,\n",
627+
" visual_prefix='step0.',\n",
614628
")\n",
615629
"inner_loss = F.mse_loss(net(x), y)\n",
616630
"optim.step(inner_loss)\n",
617631
"net_state_1 = torchopt.extract_state_dict(\n",
618-
" net, by='reference', enable_visual=True, visual_prefix='step1.'\n",
632+
" net,\n",
633+
" by='reference',\n",
634+
" enable_visual=True,\n",
635+
" visual_prefix='step1.',\n",
619636
")\n",
620637
"\n",
621638
"outer_loss = F.mse_loss(net(x), y)\n",
622639
"display(\n",
623640
" torchopt.visual.make_dot(\n",
624-
" outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n",
625-
" )\n",
641+
" outer_loss,\n",
642+
" params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n",
643+
" ),\n",
626644
")"
627645
]
628646
},

tutorials/4_Stop_Gradient.ipynb

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"cells": [
2+
"cells": [
33
{
44
"cell_type": "markdown",
55
"metadata": {},
@@ -192,7 +192,7 @@
192192
" one_step_net_state,\n",
193193
" {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n",
194194
" ),\n",
195-
" )\n",
195+
" ),\n",
196196
")"
197197
]
198198
},
@@ -393,7 +393,7 @@
393393
" one_step_net_state,\n",
394394
" {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n",
395395
" ),\n",
396-
" )\n",
396+
" ),\n",
397397
")\n",
398398
"\n",
399399
"# Outer update\n",
@@ -457,7 +457,9 @@
457457
"torchopt.stop_gradient(net)\n",
458458
"torchopt.stop_gradient(optim)\n",
459459
"one_step_net_state_detached = torchopt.extract_state_dict(\n",
460-
" net, enable_visual=True, visual_prefix='step1.detached.'\n",
460+
" net,\n",
461+
" enable_visual=True,\n",
462+
" visual_prefix='step1.detached.',\n",
461463
")\n",
462464
"\n",
463465
"# Inner update\n",
@@ -480,7 +482,7 @@
480482
" one_step_net_state_detached,\n",
481483
" {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n",
482484
" ),\n",
483-
" )\n",
485+
" ),\n",
484486
")"
485487
]
486488
},

tutorials/6_Zero_Order_Differentiation.ipynb

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"cells": [
2+
"cells": [
33
{
44
"cell_type": "markdown",
55
"id": "8850c832-3b54-4971-8ee0-2cd64b585ea8",
@@ -175,7 +175,11 @@
175175
"\n",
176176
"\n",
177177
"@torchopt.diff.zero_order(\n",
178-
" distribution=distribution, method='forward', argnums=0, num_samples=100, sigma=0.01\n",
178+
" distribution=distribution,\n",
179+
" method='forward',\n",
180+
" argnums=0,\n",
181+
" num_samples=100,\n",
182+
" sigma=0.01,\n",
179183
")\n",
180184
"def forward_process(params, fn, x, y):\n",
181185
" y_pred = fn(params, x)\n",

0 commit comments

Comments
 (0)