-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
114 lines (104 loc) · 5.57 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import random
import torch
import argparse
import warnings
import torch.backends.cudnn as cudnn
# Command-line configuration for SimKGC training and evaluation.
# Parsed once at import time; other modules read the resulting `args`.
parser = argparse.ArgumentParser(description='SimKGC arguments')

# --- model / data paths ---
parser.add_argument('--pretrained-model', default='bert-base-uncased', type=str, metavar='N',
                    help='path to pretrained model')
parser.add_argument('--task', default='wn18rr', type=str, metavar='N',
                    help='dataset name')
parser.add_argument('--train-path', default='', type=str, metavar='N',
                    help='path to training data')
parser.add_argument('--valid-path', default='', type=str, metavar='N',
                    help='path to valid data')
parser.add_argument('--model-dir', default='', type=str, metavar='N',
                    help='path to model dir')
parser.add_argument('--continue_training', default=None,
                    help='path to checkpoint if training needs to be continued')

# --- optimization schedule ---
parser.add_argument('--patience', default=None, type=int,
                    help='number of epochs for patience')
parser.add_argument('--warmup', default=400, type=int, metavar='N',
                    help='warmup steps')
parser.add_argument('--max-to-keep', default=5, type=int, metavar='N',
                    help='max number of checkpoints to keep')
parser.add_argument('--grad-clip', default=10.0, type=float, metavar='N',
                    help='gradient clipping')

# --- model architecture / loss ---
parser.add_argument('--pooling', default='cls', type=str, metavar='N',
                    help='bert pooling')
parser.add_argument('--dropout', default=0.1, type=float, metavar='N',
                    help='dropout on final linear layer')
parser.add_argument('--use-amp', action='store_true',
                    help='Use amp if available')
parser.add_argument('--t', default=0.05, type=float,
                    help='temperature parameter')
parser.add_argument('--use-link-graph', action='store_true',
                    help='use neighbors from link graph as context')
parser.add_argument('--eval-every-n-step', default=10000, type=int,
                    help='evaluate every n steps')
parser.add_argument('--pre-batch', default=0, type=int,
                    help='number of pre-batch used for negatives')
parser.add_argument('--pre-batch-weight', default=0.5, type=float,
                    help='the weight for logits from pre-batch negatives')
parser.add_argument('--additive-margin', default=0.0, type=float, metavar='N',
                    help='additive margin for InfoNCE loss function')
parser.add_argument('--finetune-t', action='store_true',
                    help='make temperature as a trainable parameter or not')
parser.add_argument('--max-num-tokens', default=50, type=int,
                    help='maximum number of tokens')
parser.add_argument('--use-self-negative', action='store_true',
                    help='use head entity as negative')

# --- training loop ---
parser.add_argument('-j', '--workers', default=3, type=int, metavar='N',
                    help='number of data loading workers')
parser.add_argument('--epochs', default=10, type=int, metavar='N',
                    help='number of total epochs to run')
# BUGFIX: help text previously claimed "(default: 256)" while the real default
# is 128; use %(default)s so the help string can never drift from the code.
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N',
                    help='mini-batch size (default: %(default)s), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=2e-5, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--lr-scheduler', default='linear', type=str,
                    help='Lr scheduler to use')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: %(default)s)',
                    dest='weight_decay')
# BUGFIX: help text previously claimed "(default: 10)" while the real default is 50.
parser.add_argument('-p', '--print-freq', default=50, type=int,
                    metavar='N', help='print frequency (default: %(default)s)')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')

# --- evaluation only ---
parser.add_argument('--is-test', action='store_true',
                    help='is in test mode or not')
parser.add_argument('--rerank-n-hop', default=2, type=int,
                    help='use n-hops node for re-ranking entities, only used during evaluation')
parser.add_argument('--neighbor-weight', default=0.0, type=float,
                    help='weight for re-ranking entities')
parser.add_argument('--eval-model-path', default='', type=str, metavar='N',
                    help='path to model, only used for evaluation')

# Parse eagerly at import time; downstream modules do `from config import args`.
args = parser.parse_args()
# --- Post-parse validation and environment setup (runs at import time) ---
# NOTE(review): these asserts are stripped under `python -O`; explicit raises
# (e.g. ValueError) would be more robust — confirm before changing.
assert not args.train_path or os.path.exists(args.train_path)
assert args.pooling in ['cls', 'mean', 'max']
assert args.task.lower() in ['wn18rr', 'fb15k237', 'wiki5m_ind', 'wiki5m_trans', 'mopenkb']
assert args.lr_scheduler in ['linear', 'cosine']
# Resolve model_dir: create it for training runs, otherwise derive it from the
# directory of the evaluation checkpoint.
if args.model_dir:
    os.makedirs(args.model_dir, exist_ok=True)
else:
    assert os.path.exists(args.eval_model_path), 'One of args.model_dir and args.eval_model_path should be valid path'
    args.model_dir = os.path.dirname(args.eval_model_path)
# Seed Python and torch RNGs for reproducibility; deterministic cuDNN kernels
# trade some speed for repeatable results.
if args.seed is not None:
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.deterministic = True
# Probe AMP availability: fall back to full precision if the amp module
# cannot be imported (e.g. an older torch build).
try:
    if args.use_amp:
        import torch.cuda.amp
except Exception:
    args.use_amp = False
    warnings.warn('AMP training is not available, set use_amp=False')
# CPU-only fallback: AMP requires CUDA; print every step since CPU steps are slow.
if not torch.cuda.is_available():
    args.use_amp = False
    args.print_freq = 1
    warnings.warn('GPU is not available, set use_amp=False and print_freq=1')