|
1 | 1 | {
|
2 |
| - "cells": [ |
| 2 | + "cells": [ |
3 | 3 | {
|
4 | 4 | "cell_type": "markdown",
|
5 | 5 | "metadata": {},
|
|
200 | 200 | "outer_loss = F.mse_loss(net(x), y)\n",
|
201 | 201 | "display(\n",
|
202 | 202 | " torchopt.visual.make_dot(\n",
|
203 |
| - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", |
204 |
| - " )\n", |
| 203 | + " outer_loss,\n", |
| 204 | + " params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n", |
| 205 | + " ),\n", |
205 | 206 | ")"
|
206 | 207 | ]
|
207 | 208 | },
|
|
247 | 248 | "outer_loss = F.mse_loss(net(x), y)\n",
|
248 | 249 | "display(\n",
|
249 | 250 | " torchopt.visual.make_dot(\n",
|
250 |
| - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", |
251 |
| - " )\n", |
| 251 | + " outer_loss,\n", |
| 252 | + " params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n", |
| 253 | + " ),\n", |
252 | 254 | ")"
|
253 | 255 | ]
|
254 | 256 | },
|
|
513 | 515 | "source": [
|
514 | 516 | "functional_adam = torchopt.adam(\n",
|
515 | 517 | " lr=torchopt.schedule.linear_schedule(\n",
|
516 |
| - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", |
517 |
| - " )\n", |
| 518 | + " init_value=1e-3,\n", |
| 519 | + " end_value=1e-4,\n", |
| 520 | + " transition_steps=10000,\n", |
| 521 | + " transition_begin=2000,\n", |
| 522 | + " ),\n", |
518 | 523 | ")\n",
|
519 | 524 | "\n",
|
520 | 525 | "adam = torchopt.Adam(\n",
|
521 | 526 | " net.parameters(),\n",
|
522 | 527 | " lr=torchopt.schedule.linear_schedule(\n",
|
523 |
| - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", |
| 528 | + " init_value=1e-3,\n", |
| 529 | + " end_value=1e-4,\n", |
| 530 | + " transition_steps=10000,\n", |
| 531 | + " transition_begin=2000,\n", |
524 | 532 | " ),\n",
|
525 | 533 | ")\n",
|
526 | 534 | "\n",
|
527 | 535 | "meta_adam = torchopt.MetaAdam(\n",
|
528 | 536 | " net,\n",
|
529 | 537 | " lr=torchopt.schedule.linear_schedule(\n",
|
530 |
| - " init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n", |
| 538 | + " init_value=1e-3,\n", |
| 539 | + " end_value=1e-4,\n", |
| 540 | + " transition_steps=10000,\n", |
| 541 | + " transition_begin=2000,\n", |
531 | 542 | " ),\n",
|
532 | 543 | ")"
|
533 | 544 | ]
|
|
610 | 621 | "optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True, use_accelerated_op=True)\n",
|
611 | 622 | "\n",
|
612 | 623 | "net_state_0 = torchopt.extract_state_dict(\n",
|
613 |
| - " net, by='reference', enable_visual=True, visual_prefix='step0.'\n", |
| 624 | + " net,\n", |
| 625 | + " by='reference',\n", |
| 626 | + " enable_visual=True,\n", |
| 627 | + " visual_prefix='step0.',\n", |
614 | 628 | ")\n",
|
615 | 629 | "inner_loss = F.mse_loss(net(x), y)\n",
|
616 | 630 | "optim.step(inner_loss)\n",
|
617 | 631 | "net_state_1 = torchopt.extract_state_dict(\n",
|
618 |
| - " net, by='reference', enable_visual=True, visual_prefix='step1.'\n", |
| 632 | + " net,\n", |
| 633 | + " by='reference',\n", |
| 634 | + " enable_visual=True,\n", |
| 635 | + " visual_prefix='step1.',\n", |
619 | 636 | ")\n",
|
620 | 637 | "\n",
|
621 | 638 | "outer_loss = F.mse_loss(net(x), y)\n",
|
622 | 639 | "display(\n",
|
623 | 640 | " torchopt.visual.make_dot(\n",
|
624 |
| - " outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n", |
625 |
| - " )\n", |
| 641 | + " outer_loss,\n", |
| 642 | + " params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n", |
| 643 | + " ),\n", |
626 | 644 | ")"
|
627 | 645 | ]
|
628 | 646 | },
|
|
0 commit comments