-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_validate.py
82 lines (51 loc) · 2.47 KB
/
train_validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import numpy as np
from sklearn.metrics import f1_score
import pandas as pd
from sklearn.metrics import r2_score
import utils
def train_model(model, train_dl, optimizer, loss_func, device, opt):
    """Run one training pass over ``train_dl`` and report mean loss and score.

    Args:
        model: Network to train. It is switched to training mode and called
            with ``(input_ids, attn_mask, lengths)``.
            NOTE(review): assumed to be a ``torch.nn.Module`` -- confirm.
        train_dl: Iterable of batches. Each batch is indexed with the keys
            ``'input_ids'``, ``'attn_mask'``, ``'lengths'`` and ``'labels'``.
        optimizer: Optimizer; ``zero_grad()`` and ``step()`` are called once
            per batch.
        loss_func: Forwarded unchanged to ``utils.get_loss``.
        device: Target device for ``input_ids``, ``attn_mask`` and ``labels``.
        opt: Options object forwarded unchanged to the ``utils`` helpers.

    Returns:
        A ``(mean_loss, score)`` tuple: the summed per-batch ``loss.item()``
        divided by the number of batches, and the value returned by
        ``utils.get_score`` over all accumulated labels and predictions.

    Raises:
        ValueError: If ``train_dl`` yields no batches.
    """
    all_labels, all_predictions = [], []
    model.train()
    total_train_loss = 0.0
    num_batches = 0

    for batch in train_dl:
        num_batches += 1
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attn_mask = batch['attn_mask'].to(device)
        # `lengths` is handed to the model as-is (not moved to `device`).
        lengths = batch['lengths']
        labels = batch['labels'].to(device)

        # Call the module itself rather than `.forward()` so that any
        # registered forward hooks are honoured.
        logits = model(input_ids, attn_mask, lengths)

        loss = utils.get_loss(logits, labels, loss_func, opt)
        loss.backward()
        optimizer.step()

        # Bookkeeping only: no gradients are needed to derive predictions.
        with torch.no_grad():
            # presumably `get_predictions` returns a flat list -- it is
            # concatenated onto a list here; verify against utils.
            all_predictions += utils.get_predictions(logits, opt)

        # Flatten the label tensor into a plain Python list.
        all_labels += labels.detach().cpu().numpy().reshape(-1).tolist()
        total_train_loss += loss.item()

    if num_batches == 0:
        # Without this guard the average below has no batch count to divide by.
        raise ValueError("train_dl yielded no batches; cannot compute loss or score")

    score = utils.get_score(all_labels, all_predictions, opt)
    return total_train_loss / num_batches, score
def validate_model(model, dl, loss_func, device, opt):
    """Evaluate ``model`` on ``dl`` without gradient tracking.

    Args:
        model: Network to evaluate. It is switched to eval mode and called
            with ``(input_ids, attn_mask, lengths)``.
            NOTE(review): assumed to be a ``torch.nn.Module`` -- confirm.
        dl: Iterable of batches. Each batch is indexed with the keys
            ``'input_ids'``, ``'attn_mask'``, ``'lengths'``, ``'labels'``
            and ``'session_ids'``.
        loss_func: Forwarded unchanged to ``utils.get_loss``.
        device: Target device for ``input_ids``, ``attn_mask`` and ``labels``.
        opt: Options object forwarded unchanged to the ``utils`` helpers.

    Returns:
        A ``(mean_loss, score, results_df)`` tuple: the summed per-batch
        ``loss.item()`` divided by the number of batches, the value returned
        by ``utils.get_score``, and a ``pandas.DataFrame`` with columns
        ``'labels'``, ``'predictions'`` and ``'ID'`` (the session ids).

    Raises:
        ValueError: If ``dl`` yields no batches.
    """
    all_labels, all_predictions, all_sess_ids = [], [], []
    total_loss = 0.0
    num_batches = 0
    model.eval()

    with torch.no_grad():
        for batch in dl:
            num_batches += 1

            labels = batch['labels'].to(device)
            input_ids = batch['input_ids'].to(device)
            attn_mask = batch['attn_mask'].to(device)
            # `lengths` is handed to the model as-is (not moved to `device`).
            lengths = batch['lengths']

            # Call the module itself rather than `.forward()` so that any
            # registered forward hooks are honoured.
            logits = model(input_ids, attn_mask, lengths)

            loss = utils.get_loss(logits, labels, loss_func, opt)
            total_loss += loss.item()

            # Flatten the label tensor into a plain Python list.
            all_labels += labels.detach().cpu().numpy().reshape(-1).tolist()
            # presumably `get_predictions` returns a flat list -- it is
            # concatenated onto a list here; verify against utils.
            all_predictions += utils.get_predictions(logits, opt)
            all_sess_ids += batch['session_ids']

    if num_batches == 0:
        # Without this guard the average below has no batch count to divide by.
        raise ValueError("dl yielded no batches; cannot compute loss or score")

    score = utils.get_score(all_labels, all_predictions, opt)
    results_df = pd.DataFrame({
        'labels': all_labels,
        'predictions': all_predictions,
        'ID': all_sess_ids,
    })
    return total_loss / num_batches, score, results_df