[Fix]: Fix load weight when change num_classes
RangiLyu authored Jul 17, 2021
1 parent 4ecfb1c commit be08279
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tools/train.py
@@ -20,7 +20,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import ProgressBar
 
-from nanodet.util import mkdir, Logger, cfg, load_config, convert_old_model
+from nanodet.util import mkdir, Logger, cfg, load_config, convert_old_model, load_model_weight
 from nanodet.data.collate import collate_function
 from nanodet.data.dataset import build_dataset
 from nanodet.trainer.task import TrainingTask
@@ -75,7 +75,7 @@ def main(args):
             warnings.warn('Warning! Old .pth checkpoint is deprecated. '
                           'Convert the checkpoint with tools/convert_old_checkpoint.py ')
             ckpt = convert_old_model(ckpt)
-        task.load_state_dict(ckpt['state_dict'], strict=False)
+        load_model_weight(task.model, ckpt, logger)
         logger.log('Loaded model weight from {}'.format(cfg.schedule.load_model))
 
     model_resume_path = os.path.join(cfg.save_dir, 'model_last.ckpt') if 'resume' in cfg.schedule else None
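Why the change matters: with PyTorch, load_state_dict(strict=False) tolerates missing or unexpected keys but still raises a size-mismatch error when a checkpoint tensor's shape differs from the model's, which is exactly what happens to the detection head when num_classes in the config no longer matches the pretrained checkpoint. Switching to load_model_weight(task.model, ckpt, logger) lets the mismatched tensors be handled before loading. The sketch below only illustrates that idea; load_matching_weights is a hypothetical helper, not nanodet's actual load_model_weight, and the 'model.' prefix handling is an assumption about how the Lightning task wraps the model.

def load_matching_weights(model, ckpt, logger=None):
    # Hypothetical shape-aware loader (assumed behaviour, not nanodet's code):
    # keep only checkpoint tensors whose names and shapes still match the
    # model, so a head resized for a new num_classes keeps its fresh
    # initialization instead of triggering a size-mismatch error.
    state_dict = ckpt.get('state_dict', ckpt)
    # Assumption: Lightning prefixes the wrapped model's parameters with
    # 'model.'; strip it so keys line up with the bare model's state_dict.
    state_dict = {
        (k[len('model.'):] if k.startswith('model.') else k): v
        for k, v in state_dict.items()
    }
    model_state = model.state_dict()
    filtered = {}
    for name, tensor in state_dict.items():
        if name in model_state and model_state[name].shape == tensor.shape:
            filtered[name] = tensor
        elif logger is not None:
            logger.log('Skip loading {}: unknown key or shape mismatch'.format(name))
    # strict=False tolerates the keys deliberately skipped above.
    model.load_state_dict(filtered, strict=False)

With a filter like this, pointing schedule.load_model at a pretrained checkpoint after changing num_classes loads every weight that still fits, while the resized head layers simply start from their fresh initialization.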
