-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathffmodel.py
109 lines (91 loc) · 3.18 KB
/
ffmodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# coding: utf-8
"""
Implement of Forward-Forward Algorithm (FF).
Author: Mimi
Date: 2023-07-15
"""
from typing import List, Tuple
import torch
import torch.nn as nn
class FFLinear(nn.Linear):
    """A linear layer trained locally with the Forward-Forward (FF) algorithm.

    Each layer owns its own AdamW optimizer and updates its weights from a
    goodness-based loss, so no global end-to-end backward pass is required.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: str = "cpu",
        lr: float = 0.01,
        goodness_threshold: float = 2.0,
    ) -> None:
        """
        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: Whether to learn an additive bias term.
            device: Device the parameters are allocated on.
            lr: Learning rate for this layer's local AdamW optimizer.
            goodness_threshold: Margin that positive-sample goodness should
                exceed and negative-sample goodness should stay below.
        """
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
        )
        self.goodness_threshold = goodness_threshold
        # Each FF layer is optimized independently (local learning).
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=lr)
        self.relu = torch.nn.ReLU()

    def linear_transform(self, inputs: torch.Tensor) -> torch.Tensor:
        """L2-normalize each input row, apply the affine map, then ReLU.

        Normalization strips the magnitude (the previous layer's goodness)
        from the input so this layer must use relative activity patterns.
        """
        # Epsilon keeps the division well-defined for all-zero rows.
        inputs_l2_norm = inputs.norm(2, 1, keepdim=True) + 1e-4
        # BUGFIX: out-of-place division — the original `inputs /= ...`
        # silently mutated the caller's tensor in place.
        normalized = inputs / inputs_l2_norm
        # BUGFIX: nn.functional.linear handles bias=None, whereas the
        # original `self.bias.unsqueeze(0)` crashed when bias=False.
        outputs = nn.functional.linear(normalized, self.weight, self.bias)
        return self.relu(outputs)

    def forward(
        self,
        pos_inputs: torch.Tensor,
        neg_inputs: torch.Tensor,
        train_mode: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one FF step on a positive and a negative batch.

        Args:
            pos_inputs: Batch of "real" samples, shape (batch, in_features).
            neg_inputs: Batch of "fake" samples, same shape.
            train_mode: When True, update this layer's weights in place.

        Returns:
            Detached (pos_outputs, neg_outputs, loss). Detaching blocks
            gradient flow between layers, as FF requires.
        """
        # Goodness = mean squared activation per sample.
        pos_outputs = self.linear_transform(pos_inputs)
        neg_outputs = self.linear_transform(neg_inputs)
        pos_goodness = pos_outputs.pow(2).mean(1)
        neg_goodness = neg_outputs.pow(2).mean(1)
        # Clear stale gradients before this layer's local step.
        self.optimizer.zero_grad()
        # Softplus margin loss: push positive goodness above the threshold
        # and negative goodness below it.
        pos_loss = self.goodness_threshold - pos_goodness
        neg_loss = neg_goodness - self.goodness_threshold
        loss = torch.log(1 + torch.exp(torch.cat([pos_loss, neg_loss]))).mean()
        # Update the weights & bias of this layer only.
        if train_mode:
            loss.backward()
            self.optimizer.step()
        return pos_outputs.detach(), neg_outputs.detach(), loss.detach()
class FFClassifier(torch.nn.Module):
    """A stack of FFLinear layers trained greedily, layer by layer."""

    def __init__(self, dims: List[int], device: str) -> None:
        """
        Args:
            dims: Layer widths, e.g. [784, 500, 500]; each consecutive
                pair defines one FFLinear layer.
            device: Device the layer parameters are allocated on.
        """
        super().__init__()
        # BUGFIX: use nn.ModuleList so the layers are registered as
        # submodules (state_dict, .to(), .train()/.eval() now see them).
        # A plain Python list left them invisible to nn.Module.
        self.layers = torch.nn.ModuleList(
            FFLinear(
                in_features=dims[i],
                out_features=dims[i + 1],
                lr=0.01,
                device=device,
            )
            for i in range(len(dims) - 1)
        )
        self.dropout = torch.nn.Dropout(p=0.3)

    def forward(
        self,
        pos_inputs: torch.Tensor,
        neg_inputs: torch.Tensor,
        train_mode: bool = True,
    ) -> float:
        """Run one FF pass (training each layer locally when train_mode).

        Returns:
            The sum of the per-layer FF losses as a Python float.
            (BUGFIX: annotation corrected from torch.Tensor — the method
            accumulates `loss.item()` and has always returned a float.)
        """
        total_loss = 0.0
        for layer in self.layers:
            pos_inputs, neg_inputs, loss = layer(pos_inputs, neg_inputs, train_mode)
            # BUGFIX: only apply dropout while training; the original
            # dropped activations during evaluation as well.
            if train_mode:
                pos_inputs = self.dropout(pos_inputs)
                neg_inputs = self.dropout(neg_inputs)
            total_loss += loss.item()
        return total_loss

    @torch.no_grad()
    def predict(self, inputs: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
        """Return the row index of `inputs` with the highest total goodness.

        Presumably the caller passes one label-overlaid copy of a sample
        per candidate class, so the winning row index is the predicted
        class — TODO confirm against the caller.

        NOTE(review): `num_classes` is unused; kept only for backward
        compatibility with existing callers.
        """
        goodness = 0
        for idx, layer in enumerate(self.layers):
            inputs = layer.linear_transform(inputs)
            # NOTE(review): the first layer's goodness is deliberately
            # excluded from the score — matches common FF reference
            # implementations; confirm this is intended here.
            if idx > 0:
                goodness += inputs.pow(2).mean(1)
        return torch.argmax(goodness)