This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Use a pre-trained bigram P for LF-MMI training #222

Merged · 4 commits · Jul 2, 2021
Changes from 1 commit
snowfall/common.py — 6 changes: 4 additions & 2 deletions
@@ -88,9 +88,11 @@ def load_checkpoint(
src_key = '{}.{}'.format('module', key)
dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0
-    model.load_state_dict(dst_state_dict)
+    model.load_state_dict(dst_state_dict, strict=False)
Collaborator Author


@danpovey
Adding strict=False should prevent PyTorch from complaining
about the extra key P_scores in the checkpoints.

     else:
-        model.load_state_dict(checkpoint['state_dict'])
+        model.load_state_dict(checkpoint['state_dict'], strict=False)
+        # Note we used strict=False above so that the current code
+        # can load models trained with P_scores.

model.num_features = checkpoint['num_features']
model.num_classes = checkpoint['num_classes']
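The effect of strict=False can be sketched in plain Python. This is a minimal illustration of the key-matching behavior, not PyTorch's actual implementation; load_state_dict_sketch is a hypothetical helper that exists only for this example:

```python
def load_state_dict_sketch(model_keys, checkpoint, strict=True):
    """Mimic load_state_dict's key matching (illustrative only).

    With strict=True, any missing or unexpected key raises; with
    strict=False, unexpected keys such as 'P_scores' are skipped.
    """
    unexpected = [k for k in checkpoint if k not in model_keys]
    missing = [k for k in model_keys if k not in checkpoint]
    if strict and (missing or unexpected):
        raise RuntimeError(
            f"Error(s) in loading state_dict: "
            f"missing={missing}, unexpected={unexpected}")
    # Only copy over parameters the model actually declares.
    return {k: v for k, v in checkpoint.items() if k in model_keys}


# A checkpoint saved by a model trained with the bigram P
# carries an extra 'P_scores' entry:
ckpt = {'encoder.weight': 1.0, 'P_scores': 2.0}

# strict=False silently drops 'P_scores' instead of raising:
loaded = load_state_dict_sketch(['encoder.weight'], ckpt, strict=False)
```

With strict=True the same checkpoint would raise a RuntimeError, which is exactly the complaint the patch avoids when loading models trained with P_scores into code that no longer has that parameter.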