Skip to content

Commit

Permalink
Fix silent fail if network loads wrong weights (#1521)
Browse files Browse the repository at this point in the history
* Add warning if the imported network does not have the right keys

Signed-off-by: Matthias Hadlich <[email protected]>

* Update warning

Signed-off-by: Matthias Hadlich <[email protected]>

* Set load_strict to false for deepedit

Signed-off-by: Matthias Hadlich <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Matthias Hadlich <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: tangy5 <[email protected]>
Co-authored-by: SACHIDANAND ALLE <[email protected]>
  • Loading branch information
4 people authored Aug 26, 2023
1 parent 43e9e2a commit 7208641
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
12 changes: 11 additions & 1 deletion monailabel/tasks/infer/basic_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
output_label_key: str = "pred",
output_json_key: str = "result",
config: Union[None, Dict[str, Any]] = None,
load_strict: bool = False,
load_strict: bool = True,
roi_size=None,
preload=False,
train_mode=False,
Expand Down Expand Up @@ -453,6 +453,16 @@ def _get_network(self, device, data):
if path:
checkpoint = torch.load(path, map_location=torch.device(device))
model_state_dict = checkpoint.get(self.model_state_dict, checkpoint)

if set(self.network.state_dict().keys()) != set(checkpoint.keys()):
logger.warning(
f"Checkpoint keys don't match network.state_dict()! Items that exist in only one dict"
f" but not in the other: {set(self.network.state_dict().keys()) ^ set(checkpoint.keys())}"
)
logger.warning(
"The run will now continue unless load_strict is set to True. "
"If loading fails or the network behaves abnormally, please check the loaded weights"
)
network.load_state_dict(model_state_dict, strict=self.load_strict)
else:
network = torch.jit.load(path, map_location=torch.device(device))
Expand Down
1 change: 1 addition & 0 deletions sample-apps/radiology/lib/infers/deepedit.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
self.spatial_size = spatial_size
self.target_spacing = target_spacing
self.number_intensity_ch = number_intensity_ch
self.load_strict = False

def pre_transforms(self, data=None):
t = [
Expand Down

0 comments on commit 7208641

Please sign in to comment.