Skip to content

Commit

Permalink
fix: [pre-commit.ci] auto fixes [...]
Browse files Browse the repository at this point in the history
  • Loading branch information
pre-commit-ci[bot] committed Sep 2, 2024
1 parent 42e54aa commit b0cbdcd
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 24 deletions.
2 changes: 1 addition & 1 deletion torchopt/nn/stateless.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def reparametrize(
module: nn.Module,
named_tensors: dict[str, torch.Tensor] | Iterable[tuple[str, torch.Tensor]],
allow_missing: bool = False,
) -> Generator[nn.Module, None, None]:
) -> Generator[nn.Module]:
"""Reparameterize the module parameters and/or buffers."""
if not isinstance(named_tensors, dict):
named_tensors = dict(named_tensors)
Expand Down
7 changes: 4 additions & 3 deletions tutorials/2_Visualization.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cells": [
"cells": [
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -181,8 +181,9 @@
"# Draw computation graph\n",
"display(\n",
" torchopt.visual.make_dot(\n",
" loss, [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}]\n",
" )\n",
" loss,\n",
" [net_state_0, net_state_1, {'meta_param': meta_param, 'loss': loss}],\n",
" ),\n",
")"
]
}
Expand Down
44 changes: 31 additions & 13 deletions tutorials/3_Meta_Optimizer.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cells": [
"cells": [
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -200,8 +200,9 @@
"outer_loss = F.mse_loss(net(x), y)\n",
"display(\n",
" torchopt.visual.make_dot(\n",
" outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n",
" )\n",
" outer_loss,\n",
" params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n",
" ),\n",
")"
]
},
Expand Down Expand Up @@ -247,8 +248,9 @@
"outer_loss = F.mse_loss(net(x), y)\n",
"display(\n",
" torchopt.visual.make_dot(\n",
" outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n",
" )\n",
" outer_loss,\n",
" params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n",
" ),\n",
")"
]
},
Expand Down Expand Up @@ -513,21 +515,30 @@
"source": [
"functional_adam = torchopt.adam(\n",
" lr=torchopt.schedule.linear_schedule(\n",
" init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n",
" )\n",
" init_value=1e-3,\n",
" end_value=1e-4,\n",
" transition_steps=10000,\n",
" transition_begin=2000,\n",
" ),\n",
")\n",
"\n",
"adam = torchopt.Adam(\n",
" net.parameters(),\n",
" lr=torchopt.schedule.linear_schedule(\n",
" init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n",
" init_value=1e-3,\n",
" end_value=1e-4,\n",
" transition_steps=10000,\n",
" transition_begin=2000,\n",
" ),\n",
")\n",
"\n",
"meta_adam = torchopt.MetaAdam(\n",
" net,\n",
" lr=torchopt.schedule.linear_schedule(\n",
" init_value=1e-3, end_value=1e-4, transition_steps=10000, transition_begin=2000\n",
" init_value=1e-3,\n",
" end_value=1e-4,\n",
" transition_steps=10000,\n",
" transition_begin=2000,\n",
" ),\n",
")"
]
Expand Down Expand Up @@ -610,19 +621,26 @@
"optim = torchopt.MetaAdam(net, lr=1.0, moment_requires_grad=True, use_accelerated_op=True)\n",
"\n",
"net_state_0 = torchopt.extract_state_dict(\n",
" net, by='reference', enable_visual=True, visual_prefix='step0.'\n",
" net,\n",
" by='reference',\n",
" enable_visual=True,\n",
" visual_prefix='step0.',\n",
")\n",
"inner_loss = F.mse_loss(net(x), y)\n",
"optim.step(inner_loss)\n",
"net_state_1 = torchopt.extract_state_dict(\n",
" net, by='reference', enable_visual=True, visual_prefix='step1.'\n",
" net,\n",
" by='reference',\n",
" enable_visual=True,\n",
" visual_prefix='step1.',\n",
")\n",
"\n",
"outer_loss = F.mse_loss(net(x), y)\n",
"display(\n",
" torchopt.visual.make_dot(\n",
" outer_loss, params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}]\n",
" )\n",
" outer_loss,\n",
" params=[net_state_0, net_state_1, {'x': x, 'outer_loss': outer_loss}],\n",
" ),\n",
")"
]
},
Expand Down
12 changes: 7 additions & 5 deletions tutorials/4_Stop_Gradient.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cells": [
"cells": [
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -192,7 +192,7 @@
" one_step_net_state,\n",
" {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n",
" ),\n",
" )\n",
" ),\n",
")"
]
},
Expand Down Expand Up @@ -393,7 +393,7 @@
" one_step_net_state,\n",
" {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n",
" ),\n",
" )\n",
" ),\n",
")\n",
"\n",
"# Outer update\n",
Expand Down Expand Up @@ -457,7 +457,9 @@
"torchopt.stop_gradient(net)\n",
"torchopt.stop_gradient(optim)\n",
"one_step_net_state_detached = torchopt.extract_state_dict(\n",
" net, enable_visual=True, visual_prefix='step1.detached.'\n",
" net,\n",
" enable_visual=True,\n",
" visual_prefix='step1.detached.',\n",
")\n",
"\n",
"# Inner update\n",
Expand All @@ -480,7 +482,7 @@
" one_step_net_state_detached,\n",
" {'meta_parameter': meta_parameter, 'outer_loss': outer_loss},\n",
" ),\n",
" )\n",
" ),\n",
")"
]
},
Expand Down
8 changes: 6 additions & 2 deletions tutorials/6_Zero_Order_Differentiation.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cells": [
"cells": [
{
"cell_type": "markdown",
"id": "8850c832-3b54-4971-8ee0-2cd64b585ea8",
Expand Down Expand Up @@ -175,7 +175,11 @@
"\n",
"\n",
"@torchopt.diff.zero_order(\n",
" distribution=distribution, method='forward', argnums=0, num_samples=100, sigma=0.01\n",
" distribution=distribution,\n",
" method='forward',\n",
" argnums=0,\n",
" num_samples=100,\n",
" sigma=0.01,\n",
")\n",
"def forward_process(params, fn, x, y):\n",
" y_pred = fn(params, x)\n",
Expand Down

0 comments on commit b0cbdcd

Please sign in to comment.