-
Notifications
You must be signed in to change notification settings - Fork 6
/
correlations.py
86 lines (64 loc) · 3.18 KB
/
correlations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Model/tokenizer classes from HuggingFace (used elsewhere in this file/project).
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Wildcard import — presumably brings in `np`, `spearmanr`, `torch`, etc.
# used below without a direct import here; verify against src/utils_contributions.
from src.utils_contributions import *
import torch.nn.functional as F
from src.contributions import ModelWrapper, ClassificationModelWrapperCaptum, interpret_sentence_sst2
#import contributions
import json
import statistics
import random
# Fixed seed for reproducibility of any random sampling downstream.
random.seed(10)
import argparse
from collections import defaultdict
# NOTE(review): duplicate of the `import json` above — harmless but redundant.
import json
import torch.nn as nn
# Run on GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
import sys
def compute_correlation(relevancies, methods_list, reference_method):
    """Spearman's rank correlation between each method and a reference method.

    Args:
        relevancies: mapping method name -> sequence of per-example score
            vectors; all methods are assumed to have parallel sequences.
        methods_list: method names to correlate against the reference.
        reference_method: key in `relevancies` used as the reference scores.

    Returns:
        defaultdict(list) mapping each method to ``[mean, std]`` of the
        per-example Spearman correlation coefficients.
    """
    corrs_method = defaultdict(list)
    # Hoist the reference scores lookup out of the per-method loop.
    reference_scores = relevancies[reference_method]
    for method in methods_list:
        # Pair each example's scores with the reference instead of indexing
        # via np.arange(len(...)) as the original did.
        results_corr = [
            spearmanr(scores, ref)[0]
            for scores, ref in zip(relevancies[method], reference_scores)
        ]
        corrs_method[method].append(np.mean(results_corr))
        corrs_method[method].append(np.std(results_corr))
    return corrs_method
def _save_json(obj, path):
    """Serialize `obj` as JSON to `path` (overwrites any existing file)."""
    with open(path, 'w') as f:
        json.dump(obj, f)


def main(args):
    """Compute method-vs-reference Spearman correlations for a dataset.

    Loads precomputed attribution scores from ``./data/`` and writes the
    correlation statistics (and, for sst2, special-token rank statistics)
    back to ``./data/`` as JSON.

    Args:
        args: parsed argparse namespace with ``model`` and ``dataset``
            (``'sva'`` or ``'sst2'``) attributes.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    model_name = args.model
    dataset_name = args.dataset
    attributions_file = f'./data/{model_name}_{dataset_name}_attributions.npy'

    if dataset_name == 'sva':
        # [()] unwraps the 0-d object array produced by np.save on a dict.
        relevancies = np.load(attributions_file, allow_pickle=True)[()]
        methods_list = ['raw', 'rollout', 'norm', 'ours']
        reference_method = 'blankout'
        corrs_method = compute_correlation(relevancies, methods_list, reference_method)
        _save_json(corrs_method, f'./data/{model_name}_{dataset_name}_correlations.json')
    elif dataset_name == 'sst2':
        relevancies = np.load(attributions_file, allow_pickle=True)[()]
        # Compute mean and std correlation between methods.
        # (Original had a duplicated assignment `methods_list = methods_list = [...]`.)
        methods_list = ['raw', 'rollout', 'norm', 'ours', 'grad']
        # NOTE: 'grad' is also in methods_list, so its self-correlation is 1.0
        # by construction — preserved from the original behavior.
        reference_method = 'grad'
        corrs_method = compute_correlation(relevancies, methods_list, reference_method)

        # Compute mean and std rank of special tokens per method.
        special_tok_attributions_file = (
            f'./data/{model_name}_{dataset_name}_special_tok_attributions.npy'
        )
        special_tokens_relevancies = np.load(special_tok_attributions_file, allow_pickle=True)[()]
        special_tokens_method = defaultdict(list)
        for method in special_tokens_relevancies.keys():
            special_tokens_method[method].append(statistics.mean(special_tokens_relevancies[method]))
            special_tokens_method[method].append(statistics.stdev(special_tokens_relevancies[method]))

        _save_json(special_tokens_method,
                   f'./data/{model_name}_{dataset_name}_rank_special_tokens.json')
        _save_json(corrs_method, f'./data/{model_name}_{dataset_name}_correlations.json')
    else:
        # Original silently did nothing here; fail loudly instead.
        raise ValueError(f"Unknown dataset '{dataset_name}'; expected 'sst2' or 'sva'.")
if __name__ == "__main__":
    # CLI entry point: parse the two required flags and hand off to main().
    cli = argparse.ArgumentParser()
    cli.add_argument('-model', help="model used", type=str)
    cli.add_argument('-dataset', help="sst2/sva", type=str)
    main(cli.parse_args())