From be082797eb75db56d7048dbc58d6b7dae47b3be0 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Sat, 17 Jul 2021 17:10:58 +0800 Subject: [PATCH] [Fix]: Fix load weight when change num_classes --- tools/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/train.py b/tools/train.py index c2894f292..66c07d73f 100644 --- a/tools/train.py +++ b/tools/train.py @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import ProgressBar -from nanodet.util import mkdir, Logger, cfg, load_config, convert_old_model +from nanodet.util import mkdir, Logger, cfg, load_config, convert_old_model, load_model_weight from nanodet.data.collate import collate_function from nanodet.data.dataset import build_dataset from nanodet.trainer.task import TrainingTask @@ -75,7 +75,7 @@ def main(args): warnings.warn('Warning! Old .pth checkpoint is deprecated. ' 'Convert the checkpoint with tools/convert_old_checkpoint.py ') ckpt = convert_old_model(ckpt) - task.load_state_dict(ckpt['state_dict'], strict=False) + load_model_weight(task.model, ckpt, logger) logger.log('Loaded model weight from {}'.format(cfg.schedule.load_model)) model_resume_path = os.path.join(cfg.save_dir, 'model_last.ckpt') if 'resume' in cfg.schedule else None