CLIP_train.py
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torch.optim import AdamW  # transformers.AdamW is deprecated/removed in recent versions
from torch.utils.data import DataLoader
from transformers import CLIPModel, CLIPProcessor

from CLIP_dataset import CLIP_Dataset
from CLIP_model import Classifier
from Preprocessing_text import preprocessing_text
def set_seed(seed):
    """Seed all RNGs (Python, NumPy, PyTorch) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # also covers torch.cuda.manual_seed
def multi_label_metrics(predictions, labels, threshold=0.05):
    """Micro-averaged F1 over multi-hot labels, thresholding softmax scores."""
    # probs = torch.sigmoid(torch.tensor(np.array(predictions)))
    probs = F.softmax(torch.tensor(np.array(predictions)), dim=1)
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    y_true = np.array(labels)
    f1 = f1_score(y_true, y_pred, average='micro')
    return f1
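
# Worked example (illustrative numbers, not dataset output): for one logit row
# [2.0, 0.5, 0.1], softmax gives roughly [0.73, 0.16, 0.11]; with
# threshold=0.05 all three classes clear the bar, so y_pred = [1, 1, 1].
# The deliberately low threshold is what lets softmax scores (which sum to 1
# across all 22 classes) still yield multi-label predictions.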
def train_fn(train_dataloader, model, optimizer, criterion, device):
    """Run one training epoch; returns total loss, average loss, and micro-F1."""
    model.train()
    total_loss = 0.0
    preds = []
    true_label = []
    for i, batch in enumerate(train_dataloader):
        label = batch['labels'].to(device)
        output = model(batch)  # Classifier is expected to move the batch inputs to `device` itself
        loss = criterion(output, label)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if i % 100 == 0:
            print(i, total_loss)
            # print(f"prediction label: {np.where(torch.sigmoid(output).detach().cpu().numpy() >= threshold)} \n"
            #       f"true_label: {np.where(label.detach().cpu().numpy() == 1)}")
        preds.extend(output.detach().cpu().numpy())
        true_label.extend(label.detach().cpu().numpy())
    f1 = multi_label_metrics(preds, true_label)
    avg_train_loss = total_loss / len(train_dataloader)
    return total_loss, avg_train_loss, f1
def valid_fn(valid_dataloader, model, criterion, device):
    """Evaluate on the validation set; returns total loss, average loss, and micro-F1."""
    model.eval()
    total_loss = 0.0
    preds = []
    true_label = []
    threshold = 0.05
    with torch.no_grad():
        for batch in valid_dataloader:
            label = batch['labels'].to(device)
            output = model(batch)
            loss = criterion(output, label)
            total_loss += loss.item()
            print(f"prediction label: {np.where(F.softmax(output, dim=1).detach().cpu().numpy() >= threshold)}, "
                  f"true_label: {np.where(label.detach().cpu().numpy() == 1)}")
            preds.extend(output.detach().cpu().numpy())
            true_label.extend(label.detach().cpu().numpy())
            # print(torch.sigmoid(output), label)
    f1 = multi_label_metrics(preds, true_label)
    avg_valid_loss = total_loss / len(valid_dataloader)
    return total_loss, avg_valid_loss, f1
def experiment_fn(train_dataloader, valid_dataloader, device, model_name, n_labels):
    """Build the CLIP-based classifier and run the train/validation loop."""
    encoder = CLIPModel.from_pretrained(model_name)
    model = Classifier(encoder, n_labels, device).to(device)
    optimizer = AdamW(model.parameters(), lr=1e-5, eps=1e-8)
    # criterion = nn.BCEWithLogitsLoss().to(device)  # the conventional multi-label objective
    # Note: CrossEntropyLoss treats the multi-hot targets as soft class
    # probabilities (supported since PyTorch 1.10).
    criterion = nn.CrossEntropyLoss().to(device)
    epochs = 20
    for ep in range(epochs):
        train_loss, avg_train_loss, train_f1 = train_fn(train_dataloader, model, optimizer, criterion, device)
        valid_loss, avg_valid_loss, valid_f1 = valid_fn(valid_dataloader, model, criterion, device)
        print(f"EP: {ep}, train_f1: {train_f1}, valid_f1: {valid_f1}")
if __name__ == "__main__":
    seed = 42
    set_seed(seed)
    n_labels = 22
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data_path = "/home/labuser/Semeval/SemEval_task4/Data/subtask2a_annotation/"
    train_data = pd.read_json(data_path + "train.json", lines=False)
    valid_data = pd.read_json(data_path + "validation.json", lines=False)
    train_text = train_data['text']
    valid_text = valid_data['text']
    train_image = train_data['image']
    valid_image = valid_data['image']
    train_label = train_data['labels']
    valid_label = valid_data['labels']

    # The 22 persuasion-technique labels for subtask 2a.
    label2id = {
        "Presenting Irrelevant Data (Red Herring)": 0,
        "Misrepresentation of Someone's Position (Straw Man)": 1,
        "Whataboutism": 2,
        "Causal Oversimplification": 3,
        "Obfuscation, Intentional vagueness, Confusion": 4,
        "Appeal to authority": 5,
        "Black-and-white Fallacy/Dictatorship": 6,
        "Name calling/Labeling": 7,
        "Loaded Language": 8,
        "Exaggeration/Minimisation": 9,
        "Flag-waving": 10,
        "Doubt": 11,
        "Appeal to fear/prejudice": 12,
        "Slogans": 13,
        "Thought-terminating cliché": 14,
        "Bandwagon": 15,
        "Reductio ad hitlerum": 16,
        "Repetition": 17,
        "Smears": 18,
        "Glittering generalities (Virtue)": 19,
        "Transfer": 20,
        "Appeal to (Strong) Emotions": 21,
    }
    id2label = {value: key for key, value in label2id.items()}
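
    # Illustrative example (hypothetical sample, assuming CLIP_Dataset builds
    # multi-hot targets from label2id): a meme annotated with
    # ["Loaded Language", "Smears"] maps to indices [8, 18], i.e. a 22-dim
    # target vector with ones at positions 8 and 18 and zeros elsewhere.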
"""
for l in label2id.keys():
labels_num = train_data['labels'].str.contains(str(l),regex=False).sum()
print(f'train : the number of label {l}: {labels_num}, the ratio : {labels_num/len(train_data)*100:.1f}% ')
for l in label2id.keys():
v_labels_num = valid_data['labels'].str.contains(str(l),regex=False).sum()
print(f'valid : the number of label {l}: {v_labels_num}, the ratio : {v_labels_num/len(valid_data)*100:.1f}% ')
"""
    train_image_path = "/home/labuser/Semeval/SemEval_task4/Data/train_images/"
    valid_image_path = "/home/labuser/Semeval/SemEval_task4/Data/validation_images/"
    train_pre_text = preprocessing_text(train_text)
    valid_pre_text = preprocessing_text(valid_text)

    model_name = "openai/clip-vit-base-patch32"
    processor = CLIPProcessor.from_pretrained(model_name)
    train_dataset = CLIP_Dataset(train_pre_text, train_image, train_label, label2id, processor, train_image_path)
    valid_dataset = CLIP_Dataset(valid_pre_text, valid_image, valid_label, label2id, processor, valid_image_path)

    batch_size = 8
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)  # no need to shuffle validation
    experiment_fn(train_dataloader, valid_dataloader, device, model_name, n_labels)
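
    # To run (assuming the hard-coded data/image paths above exist):
    #   python CLIP_train.py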