-
Notifications
You must be signed in to change notification settings - Fork 0
/
step4_metatune_vlm.py
275 lines (234 loc) · 9.37 KB
/
step4_metatune_vlm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
"""
current target:
1. using the wandb to sweep model parameters
-----------------------------------------------------
after using sweep:
https://github.com/wandb/client/issues/982
"""
# -------------------
# for cmd running
# -------------------
import sys
import socket
# Host-dependent project paths: hostname prefix "t" = tank/HPC cluster,
# prefix "a" = local workstation.
# NOTE(review): if the hostname matches neither prefix, cfg_file and out_dir
# are never bound — any later use would raise NameError. Confirm the set of
# intended hosts, or add an explicit else-branch failure.
HOST = socket.gethostname()
if HOST.startswith("t"):
    sys.path.append('/tank/space/xugy07/MetaVLScratch')
    cfg_file = '/tank/space/xugy07/MetaVLScratch/Config/maml_cfg.yaml'
    # out_dir = '/tank/space/xugy07/MetaVLScratch/runs'
    out_dir = "/egr/research-hlr/xugy07/MetaCompExperiment"
elif HOST.startswith("a"):
    sys.path.append('/home/xu/MetaVL')
    cfg_file = '/home/xu/MetaVL/Config/maml_cfg.yaml'
    out_dir = '/home/xu/MetaVL/runs'
import os
from time import gmtime, strftime
import json
# config
import argparse
from yacs.config import CfgNode
# tools
from transformers import BertTokenizer
import stanza
# utils
from ProjUtils.MetaTrainUtils import add_episode_data_dir_into_cfgnode, add_meta_exp_dir_into_cfgnode, \
    str2bool, add_base_model_dir_into_cfgnode
from ProjUtils.ConfigUtils import save_cfg_node
from ProjUtils.SeedUtils import fix_seed
from ProjUtils.WandbUtils import convert_cfgnode_to_dict
# model
from VLModels.VLModelWrapper import VLBERTModel, LXMERTModel
# data
# trainer
# from Trainer.NormalTrainer.MAMLTrainer import MAMLTrainer
from DataPreprocessor.StepN_TaskGenerator.FaissMAMLTrainer import FaissMAMLTrainer
from ProjUtils.Constant import MetaExpList, PureMetaExpList, ProjDir, ExpDir
# logging
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from DataPreprocessor.StepN_TaskGenerator.mscoco_transform_json_to_tensor import OneEpisodeDataSet
# ------------------ wandb ------------------------#
import wandb
# ----------------------------------------------------#
def meta_train(cfg_node):
    """Build the shared tools and base VL model, then run FAISS-MAML training.

    Args:
        cfg_node: fully-populated yacs CfgNode (expects .base_model and
            .exp_type, plus everything FaissMAMLTrainer reads).
    """
    # Single global device for the whole training run.
    device = 'cuda'
    # Shared tools: BERT tokenizer + stanza English NLP pipeline.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    stanza_pipeline = stanza.Pipeline('en')
    # Dispatch table instead of an if/elif chain to pick the base VL model.
    model_cls = {'vlbert': VLBERTModel, 'lxmert': LXMERTModel}.get(cfg_node.base_model)
    if model_cls is None:
        raise NotImplementedError('Unknown model type {}'.format(cfg_node.base_model))
    base_vl_model = model_cls()
    # Trainer wraps model, tokenizer, NLP pipeline and config together.
    trainer = FaissMAMLTrainer(cfg_node, base_vl_model, tokenizer, stanza_pipeline, device)
    # Only the faiss-based experiment types trigger training here.
    # (Method name spelling comes from the project API.)
    if cfg_node.exp_type in ["faiss_dgmaml", "faiss_maml"]:
        trainer.faiss_mamal_training()
def main(args):
    """Assemble the experiment config from YAML + CLI args, persist it, and
    launch meta-training inside a wandb run (sweep-compatible).

    Args:
        args: argparse.Namespace built by the parser in ``__main__``; must
            carry config_file, base_model, dataset, seed, out_dir,
            novel_comps, exp_type, task_type and the MetaTrain overrides.
    """
    # ---------------------------------
    # init cfg_node using yaml file
    # ---------------------------------
    cfg_node = CfgNode(new_allowed=True)
    cfg_node.merge_from_file(args.config_file)
    # --------------------------
    # modify cfg_node using args
    # --------------------------
    # normal part
    cfg_node.base_model = args.base_model
    cfg_node.dataset = args.dataset
    cfg_node.seed = args.seed
    cfg_node.output_dir = args.out_dir
    cfg_node.is_novel_comps = args.novel_comps
    cfg_node.exp_type = args.exp_type
    cfg_node.task_type = args.task_type
    # Validation set size is pinned to 100 validation batches.
    cfg_node.val.val_item_num = cfg_node.val.batch_size * 100
    # First-order MAML flag follows directly from the experiment type.
    cfg_node.MetaTrain.fomaml = cfg_node.exp_type == "fomaml"
    # meta part
    cfg_node.MetaTrain.inner_update_steps = args.inner_update_steps
    cfg_node.MetaTrain.episode_batch_size = args.episode_batch_size
    cfg_node.MetaTrain.episode_num = args.episode_num
    cfg_node.MetaTrain.target_comp_cpt_num = args.target_cpt_num
    cfg_node.MetaTrain.shot_num = args.shot_num
    cfg_node.MetaTrain.mask_sup_pair = str2bool(args.mask_sup_pair)
    cfg_node.MetaTrain.sup_weight = args.sup_weight
    # modify exp type (ablation: train without the pretrained checkpoint)
    cfg_node.MetaTrain.train_from_scratch = str2bool(args.train_from_scratch)
    if cfg_node.MetaTrain.train_from_scratch:
        cfg_node.exp_type += "FromScratch"
    # ----------------------
    # adding specific dirs
    # ----------------------
    # base_output_dir: loading the pretrained model
    add_base_model_dir_into_cfgnode(cfg_node)
    # meta_output_dir: recording the training process
    add_meta_exp_dir_into_cfgnode(cfg_node)
    # episode data dir
    add_episode_data_dir_into_cfgnode(cfg_node)
    # print config
    logger.info("Running with config:\n{}".format(cfg_node))
    # -----------------------------
    # save config to output_dir
    # -----------------------------
    cfg_file = os.path.join(cfg_node.meta_output_dir, 'config.yaml')
    logger.info("Saving config into: {}".format(cfg_file))
    save_cfg_node(cfg_node, cfg_file)
    # NOTE: the inline grid-sweep config that used to live here was dead code
    # (a no-op string expression); sweeps are configured via YAML instead.
    # ----------------------------------------------------------------
    # wandb_config:
    # 1. wandb flattens nested names using dots in its backend
    # 2. dict access is dict_CfgNode['a']['b'], not dict_CfgNode.a.b
    # 3. avoid dots in config variable names; use a dash or underscore instead
    # 4. to access wandb.config keys below the root, use [ ] syntax, not . syntax
    # ----------------------------------------------------------------
    dict_CfgNode = convert_cfgnode_to_dict(cfg_node)
    # ---------------------------------------
    # sweep name derived from meta_output_dir (otherwise wandb picks a random
    # name), e.g. mscoco_fomaml_vlbert_OuterLr_0.0001_..._Pretrain_True,
    # suffixed with a UTC timestamp to keep runs distinguishable.
    # ---------------------------------------
    sweep_name = "_".join(cfg_node.meta_output_dir.split("/")[-3:])
    sweep_name = sweep_name + "_" + strftime("%Y-%m-%d %H:%M:%S", gmtime())
    # NOTE(review): mode='disabled' turns wandb logging off entirely — confirm
    # this is intended for sweep runs.
    with wandb.init(project="MetaCompLearn",
                    entity="xugy07",
                    config=dict_CfgNode,
                    name=sweep_name,
                    mode='disabled'):  # otherwise just random name
        # Access all hyperparameter values through wandb.config so sweep
        # overrides (if any) take effect.
        dict_CfgNode = dict(wandb.config)
        print('-' * 20, "wandb_config", '-' * 20)
        print(json.dumps(dict_CfgNode, sort_keys=True, indent=4))
        print('-' * 50)
        # rebuild the CfgNode from the (possibly sweep-modified) dict
        cfg_node = CfgNode(init_dict=dict_CfgNode)
        print('-' * 20, "cfg_node", '-' * 20)
        print(cfg_node)
        print('-' * 50)
        meta_train(cfg_node)
if __name__ == "__main__":
"""
1. cfg
2. dataset -- dataloader
3. trainer:
3.1 gradient
3.2 optimizer
3.3 update
"""
# -----------------------
# argparse from cmd
# -----------------------
parser = argparse.ArgumentParser()
# previous settings
parser.add_argument('--base_model', type=str, default='vlbert')
parser.add_argument('--dataset', type=str, default='MSCOCO')
parser.add_argument('--config_file', type=str, default='{}/Config/maml_cfg.yaml'.format(ProjDir))
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--cfg', nargs='*')
# parser.add_argument('--owutput_dir', type=str, default=out_dir)
parser.add_argument('--out_dir', type=str, default=ExpDir)
parser.add_argument('--novel_comps', action='store_true')
parser.add_argument('--exp_type', choices=["ground", "supervise", "maml", "fomaml", "reptile", "dgmaml",
"faiss_maml", "faiss_dgmaml"],
default="dgmaml")
parser.add_argument('--task_type', choices=["random", "comp", "object", "faiss"],
default="faiss")
# expose new settings for wandb
parser.add_argument('--inner_update_steps', type=int, default=1)
parser.add_argument('--episode_batch_size', type=int, default=8)
parser.add_argument('--episode_num', type=int, default=2000)
parser.add_argument('--target_cpt_num', type=int, default=1)
parser.add_argument('--shot_num', type=int, default=8)
parser.add_argument('--mask_sup_pair', type=str, default="True")
parser.add_argument('--sup_weight', type=float, default=1.0)
# ablaiton
parser.add_argument('--train_from_scratch', type=str, default="False")
args = parser.parse_args()
# -----------------------
# fix seed
# -----------------------
fix_seed(args.seed)
# -----------------------
# main func
# -----------------------
main(args)