# custom_lightning_module.py

import time

import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics


def compute_accuracy(model, data_loader, device):
    """Return classification accuracy (in percent) over a data loader.

    Inputs are flattened to 784-dim vectors, so this helper assumes
    28x28 images (e.g., MNIST).
    """
    model.eval()
    correct_pred, num_examples = 0, 0
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.view(-1, 28 * 28).to(device)
            targets = targets.to(device)
            logits = model(features)
            _, predicted_labels = torch.max(logits, 1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float() / num_examples * 100


def train(num_epochs, model, optimizer, train_loader, device):
    """Plain PyTorch training loop with periodic loss and accuracy logging."""
    start_time = time.time()
    for epoch in range(num_epochs):
        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):
            features = features.view(-1, 28 * 28).to(device)
            targets = targets.to(device)

            # Forward pass and backpropagation
            logits = model(features)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            loss.backward()

            # Update model parameters
            optimizer.step()

            # Logging
            if not batch_idx % 400:
                print(
                    "Epoch: %03d/%03d | Batch %03d/%03d | Loss: %.4f"
                    % (epoch + 1, num_epochs, batch_idx, len(train_loader), loss)
                )

        with torch.no_grad():
            print(
                "Epoch: %03d/%03d training accuracy: %.2f%%"
                % (epoch + 1, num_epochs, compute_accuracy(model, train_loader, device))
            )
        print("Time elapsed: %.2f min" % ((time.time() - start_time) / 60))

    print("Total Training Time: %.2f min" % ((time.time() - start_time) / 60))


class CustomLightningModule(L.LightningModule):
    """Wraps a sequence-classification model for Lightning training.

    The wrapped model is expected to return a dict-like output with
    "loss" and "logits" entries (as Hugging Face transformers models do).
    """

    def __init__(self, model, learning_rate=5e-5):
        super().__init__()
        self.learning_rate = learning_rate
        self.model = model
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)

    def forward(self, input_ids, attention_mask, labels):
        return self.model(input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        outputs = self(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"],
        )
        self.log("train_loss", outputs["loss"])
        return outputs["loss"]  # returned loss is used by Lightning for the optimizer step

    def validation_step(self, batch, batch_idx):
        outputs = self(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"],
        )
        self.log("val_loss", outputs["loss"], prog_bar=True)

        logits = outputs["logits"]
        predicted_labels = torch.argmax(logits, 1)
        self.val_acc(predicted_labels, batch["label"])
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        outputs = self(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"],
        )
        logits = outputs["logits"]
        predicted_labels = torch.argmax(logits, 1)
        self.test_acc(predicted_labels, batch["label"])
        self.log("accuracy", self.test_acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
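

# A minimal usage sketch for CustomLightningModule (illustrative, assuming a
# Hugging Face transformers binary-classification model; the model name and
# the `train_loader` / `val_loader` DataLoaders are hypothetical):
#
# from transformers import AutoModelForSequenceClassification
#
# hf_model = AutoModelForSequenceClassification.from_pretrained(
#     "distilbert-base-uncased", num_labels=2
# )
# lightning_model = CustomLightningModule(hf_model)
# trainer = L.Trainer(max_epochs=3, accelerator="auto", devices="auto")
# trainer.fit(lightning_model, train_dataloaders=train_loader,
#             val_dataloaders=val_loader)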