pinn.py

from collections import OrderedDict
from typing import Callable

import torch
from torch import nn
from torch.func import functional_call, grad, vmap


class LinearNN(nn.Module):
    def __init__(
        self,
        num_inputs: int = 1,
        num_layers: int = 1,
        num_neurons: int = 5,
        act: nn.Module = nn.Tanh(),
    ) -> None:
        """Basic neural network architecture with linear layers

        Args:
            num_inputs (int, optional): the dimensionality of the input tensor
            num_layers (int, optional): the number of hidden layers
            num_neurons (int, optional): the number of neurons for each hidden layer
            act (nn.Module, optional): the non-linear activation function used to stitch
                linear layers together
        """
        super().__init__()

        self.num_inputs = num_inputs
        self.num_neurons = num_neurons
        self.num_layers = num_layers

        layers = []

        # input layer
        layers.append(nn.Linear(self.num_inputs, num_neurons))

        # hidden layers with linear layer and activation
        for _ in range(num_layers):
            layers.extend([nn.Linear(num_neurons, num_neurons), act])

        # output layer
        layers.append(nn.Linear(num_neurons, 1))

        # build the network
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x.reshape(-1, 1)).squeeze()
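

# A minimal usage sketch (values are illustrative, not part of the original file):
# a network with two hidden layers maps a batch of scalar inputs to a batch of
# scalar predictions thanks to the reshape/squeeze pair in `forward`.
def _example_linear_nn() -> None:
    model = LinearNN(num_layers=2, num_neurons=16)
    x = torch.randn(8)  # batch of 8 scalar inputs

    y = model(x)  # `forward` reshapes to (8, 1) and squeezes back to (8,)
    assert y.shape == (8,)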


def make_forward_fn(
    model: nn.Module,
    derivative_order: int = 1,
) -> list[Callable]:
    """Make a functional forward pass and gradient functions given an input model

    This function creates a set of functional calls of the input model.
    It returns a list of composable v-mapped versions of the forward pass
    and of its higher-order derivatives with respect to the inputs, up to
    the order specified by the input argument `derivative_order`

    Args:
        model (nn.Module): the model to make the functional calls for. It can be any
            subclass of nn.Module
        derivative_order (int, optional): the highest order up to which to return
            functions computing the derivative of the model with respect to the inputs

    Returns:
        list[Callable]: A list of functions where each element corresponds to
            a v-mapped version of the model forward pass and its derivatives. The
            0-th element is always the forward pass and, depending on the value of
            the `derivative_order` argument, the following elements correspond to
            the i-th order derivative function with respect to the model inputs. The
            vmap ensures efficient support for batched inputs
    """

    # notice that `functional_call` supports batched input by default,
    # so there is no need to call vmap on it, as is instead the case
    # for the derivative calls
    def f(
        x: torch.Tensor,
        params: dict[str, torch.nn.Parameter] | tuple[torch.nn.Parameter, ...],
    ) -> torch.Tensor:
        # the functional optimizer works with parameters represented as a tuple instead
        # of the dictionary form required by the `functional_call` API
        # here we perform the conversion from tuple to dictionary
        if isinstance(params, tuple):
            params_dict = tuple_to_dict_parameters(model, params)
        else:
            params_dict = params

        return functional_call(model, params_dict, (x, ))

    fns = []
    fns.append(f)

    dfunc = f
    for _ in range(derivative_order):
        # first compute the derivative function
        dfunc = grad(dfunc)

        # then use vmap to support batching
        dfunc_vmap = vmap(dfunc, in_dims=(0, None))
        fns.append(dfunc_vmap)

    return fns
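

# A hedged sketch (not part of the original file): how the callables returned by
# `make_forward_fn` can be composed into a PINN-style residual. The equation
# u'(x) = u(x) below is a hypothetical choice made purely for illustration.
def _example_residual(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    f, dfdx = make_forward_fn(model, derivative_order=1)
    params = dict(model.named_parameters())

    # residual of u' - u = 0 evaluated at the batch of collocation points `x`
    return dfdx(x, params) - f(x, params)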


def tuple_to_dict_parameters(
    model: nn.Module, params: tuple[torch.nn.Parameter, ...]
) -> OrderedDict[str, torch.nn.Parameter]:
    """Convert a set of parameters stored as a tuple into a dictionary form

    This conversion is required to call the `functional_call` API, which expects
    parameters in dictionary form, from the result of a functional optimization
    step, which returns the parameters as a tuple

    Args:
        model (nn.Module): the model to make the functional calls for. It can be any
            subclass of nn.Module
        params (tuple[Parameter, ...]): the model parameters stored as a tuple

    Returns:
        An OrderedDict instance with the parameters stored as an ordered dictionary
    """
    keys = list(dict(model.named_parameters()).keys())
    values = list(params)

    return OrderedDict(zip(keys, values))
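

# A minimal round-trip sketch (illustrative only, not in the original file):
# parameters flattened to a tuple, as a functional optimizer would hold them,
# convert back to the ordered mapping that `functional_call` expects.
def _example_tuple_round_trip(model: nn.Module) -> None:
    params_tuple = tuple(model.parameters())
    params_dict = tuple_to_dict_parameters(model, params_tuple)

    # `parameters()` and `named_parameters()` iterate in the same order,
    # so names and tensors line up one-to-one
    assert list(params_dict) == [name for name, _ in model.named_parameters()]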


if __name__ == "__main__":

    # TODO: turn this into a unit test
    model = LinearNN(num_layers=2)
    fns = make_forward_fn(model, derivative_order=2)

    batch_size = 10
    x = torch.randn(batch_size)
    params = dict(model.named_parameters())

    fn_x = fns[0](x, params)
    assert fn_x.shape[0] == batch_size

    dfn_x = fns[1](x, params)
    assert dfn_x.shape[0] == batch_size

    ddfn_x = fns[2](x, params)
    assert ddfn_x.shape[0] == batch_size
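
    # a hedged extra check (not in the original file): the tuple-parameter code
    # path handled inside `f` should match the dictionary path exactly, since
    # `tuple_to_dict_parameters` only re-labels the same tensors
    params_tuple = tuple(model.parameters())
    fn_x_tuple = fns[0](x, params_tuple)
    assert torch.allclose(fn_x, fn_x_tuple)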