-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy path main.py
83 lines (76 loc) · 2.86 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import pytorch_lightning as pl
from vpr_model import VPRModel
from dataloaders.GSVCitiesDataloader import GSVCitiesDataModule
def main() -> None:
    """Train a DINOv2 + SALAD visual place recognition model on GSV-Cities.

    Builds the GSV-Cities data module, the ``VPRModel`` (``dinov2_vitb14``
    backbone, ``SALAD`` aggregator, multi-similarity loss/miner), a checkpoint
    callback that tracks Recall@1 on ``pitts30k_val``, and a single-GPU
    PyTorch Lightning trainer, then runs ``trainer.fit``.
    """
    datamodule = GSVCitiesDataModule(
        batch_size=60,
        img_per_place=4,
        min_img_per_place=4,
        shuffle_all=False,  # shuffle all images or keep shuffling in-city only
        random_sample_from_each_place=True,
        image_size=(224, 224),
        num_workers=10,
        show_data_stats=True,
        # Available validation sets: pitts30k_val, pitts30k_test, msls_val.
        # The checkpoint callback below monitors 'pitts30k_val/R1', which is
        # presumably only logged when 'pitts30k_val' is validated -- keep it
        # in this list.
        val_set_names=['pitts30k_val', 'pitts30k_test', 'msls_val'],
    )

    model = VPRModel(
        # ---- Encoder
        backbone_arch='dinov2_vitb14',
        backbone_config={
            'num_trainable_blocks': 4,
            'return_token': True,
            'norm_layer': True,
        },
        # ---- Aggregator
        agg_arch='SALAD',
        agg_config={
            # NOTE(review): presumably must match the backbone's embedding
            # width (768 for ViT-B/14) -- confirm if the backbone changes.
            'num_channels': 768,
            'num_clusters': 64,
            'cluster_dim': 128,
            'token_dim': 256,
        },
        # ---- Optimization
        lr=6e-5,
        optimizer='adamw',
        weight_decay=9.5e-9,  # 0.001 for sgd and 0 for adam
        # NOTE(review): momentum is an SGD hyper-parameter; presumably ignored
        # when optimizer='adamw' -- confirm in VPRModel.
        momentum=0.9,
        lr_sched='linear',
        # Assuming LinearLR semantics: the LR is scaled from 1.0x down to
        # 0.2x of `lr` over `total_iters` scheduler steps.
        lr_sched_args={
            'start_factor': 1,
            'end_factor': 0.2,
            'total_iters': 4000,
        },
        # ---- Loss functions
        # example: ContrastiveLoss, TripletMarginLoss, MultiSimilarityLoss,
        # FastAPLoss, CircleLoss, SupConLoss
        loss_name='MultiSimilarityLoss',
        # example: TripletMarginMiner, MultiSimilarityMiner, PairMarginMiner
        miner_name='MultiSimilarityMiner',
        miner_margin=0.1,
        faiss_gpu=False,
    )

    # Model checkpointing with PyTorch Lightning: keep the best 3 checkpoints
    # according to Recall@1 on Pittsburgh30k val (plus the last one).
    # Only the weights are saved, not the optimizer/trainer state.
    checkpoint_cb = pl.callbacks.ModelCheckpoint(
        monitor='pitts30k_val/R1',
        filename=f'{model.encoder_arch}' + '_({epoch:02d})_R1[{pitts30k_val/R1:.4f}]_R5[{pitts30k_val/R5:.4f}]',
        auto_insert_metric_name=False,
        save_weights_only=True,
        save_top_k=3,
        save_last=True,
        mode='max',
    )

    # ------------------
    # Instantiate the trainer.
    trainer = pl.Trainer(
        accelerator='gpu',
        devices=1,
        default_root_dir='./logs/',  # logs can be visualized with e.g. TensorBoard
        num_nodes=1,
        num_sanity_val_steps=0,  # 0 disables the sanity validation run before training
        precision='16-mixed',  # mixed 16-bit precision to reduce memory usage
        max_epochs=4,
        check_val_every_n_epoch=1,  # run validation every epoch
        callbacks=[checkpoint_cb],  # only the checkpointing callback (more can be added)
        reload_dataloaders_every_n_epochs=1,  # reload the dataset to reshuffle the order
        log_every_n_steps=20,
    )

    # Run training; the trainer receives both the model and the datamodule.
    trainer.fit(model=model, datamodule=datamodule)


if __name__ == '__main__':
    main()