Skip to content

Commit

Permalink
Merge remote-tracking branch 'refs/remotes/origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
mpelchat04 committed Aug 20, 2020
2 parents 56fd2fd + c12b8a8 commit 74eca6f
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions airborne_lidar/airborne_lidar_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,27 @@ def get_model(nb_classes, args):

# Using pretrained model.
if args['training']['finetune']:
state = torch.load(args['training']['finetune'])
net.load_state_dict(state['state_dict'], strict=False)
# state = torch.load(args['training']['finetune'])
# old_in_channels = state['state_dict'].cv0.weight.shape(0)
# old_nb_classes = state['state_dict'].fcout.weight.shape(0)
# net = Net(old_in_channels, output_channels=old_nb_classes, args=args)
# net.load_state_dict(state['state_dict'], strict=False)

# replace in_channels and out nb classes shapes.

pretrained_dict = torch.load(args['training']['finetune'])['state_dict']
model_dict = net.state_dict()

# 1. filter out unnecessary keys
# {k: v for k, v in pretrained_state.iteritems() if k in model_state and v.size() == model_state[k].size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
net.load_state_dict(model_dict)

# else:
# net = Net(input_channels, output_channels=nb_classes, args=args)

return net, features

Expand Down

0 comments on commit 74eca6f

Please sign in to comment.