Can't load pretrained models anymore #239

Open
nsbruce opened this issue Jul 12, 2024 · 16 comments
Labels
bug Something isn't working

Comments

@nsbruce

nsbruce commented Jul 12, 2024

Describe the bug
Since the torch version was updated, the provided pretrained weights no longer match the EfficientNet model, so they can't be loaded.

To Reproduce

from torchsig.models.iq_models.efficientnet.efficientnet import efficientnet_b4

convnet = efficientnet_b4(pretrained=True, path='/path/to/efficientnet_b4_online.pt')

Output

Traceback (most recent call last):                                                                                             
  File "<frozen runpy>", line 198, in _run_module_as_main                                                                      
  File "<frozen runpy>", line 88, in _run_code                                                                                 
  File "/train.py", line 135, in <module>                                            
    my_model = train_mymodel()
               ^^^^^^^^^^^^^^                                                                                              
  File "/.venv/lib/python3.11/site-packages/click/core.py", line 1157, in __call__          
    return self.main(*args, **kwargs)                                                                                          
           ^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                          
  File "/.venv/lib/python3.11/site-packages/click/core.py", line 1078, in main              
    rv = self.invoke(ctx)                                                                                                      
         ^^^^^^^^^^^^^^^^                                                                                                      
  File "/.venv/lib/python3.11/site-packages/click/core.py", line 1434, in invoke            
    return ctx.invoke(self.callback, **ctx.params)                                                                             
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                             
  File "/.venv/lib/python3.11/site-packages/click/core.py", line 783, in invoke             
    return __callback(*args, **kwargs)                                                                                         
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                         
  File "train.py", line 127, in train_mymodel                                       
    model = MyModel()
    ^^^^^^^^^^^^^^^^
  File "/mymodel.py", line 28, in __init__                                            
    self.encoder = Encoder()                                                                                         
    ^^^^^^^^^^^^^^^^^^^^                                                                                       
  File "/encoder.py", line 8, in __init__                                            
    self.convnet = efficientnet_b4(pretrained=pretrained, path=path)                                                           
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                           
  File "/.venv/lib/python3.11/site-packages/torchsig/models/iq_models/efficientnet/efficientnet.py", line 285, in efficientnet_b4
    mdl.load_state_dict(torch.load(path))                                                                                      
  File "/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2189, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(                                                  
RuntimeError: Error(s) in loading state_dict for EfficientNet:

Missing key(s) in state_dict: "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "blocks.0.0.bn1.weight", "blocks.0.0.bn1.bias", "blocks.0.0.bn1.running_mean", "blocks.0.0.bn1.running_var", "blocks.0.0.bn2.weight", "blocks.0.0.bn2.bias", "blocks.0.0.bn2.running_mean", "blocks.0.0.bn2.running_var", "blocks.0.1.bn1.weight", ..., "bn2.running_mean", "bn2.running_var".

Unexpected key(s) in state_dict: "bn1.bn.weight", "bn1.bn.bias", "bn1.bn.running_mean", "bn1.bn.running_var", "bn1.bn.num_batches_tracked", "blocks.0.0.bn1.bn.weight", "blocks.0.0.bn1.bn.bias", "blocks.0.0.bn1.bn.running_mean", "blocks.0.0.bn1.bn.running_var", "blocks.0.0.bn1.bn.num_batches_tracked", "blocks.0.0.bn2.bn.weight", "blocks.0.0.bn2.bn.bias", "blocks.0.0.bn2.bn.running_mean", "blocks.0.0.bn2.bn.running_var", "blocks.0.0.bn2.bn.num_batches_tracked", "blocks.0.1.bn1.bn.weight",..., "bn2.bn.running_var", "bn2.bn.num_batches_tracked".

Note: I shortened the output considerably, since there are many mismatched keys between the two state dicts.

Expected behavior
The model loads correctly.

Some possible solutions:

  • Downgrade timm / torch again,
  • The developers provide the scripts they used to generate the pretrained weights, so that users can generate the weights again,
  • The developers provide updated pretrained weights,
  • Somebody figures out how to map these two dicts properly.
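For the last option, the key names in the error above differ only by an extra `.bn` segment (for example `blocks.0.0.bn1.bn.weight` in the checkpoint versus `blocks.0.0.bn1.weight` in the model). A minimal sketch of that rename is below, assuming the pattern visible in the truncated output holds for every key. The helper name `remap_wrapped_bn_keys` is hypothetical; it only renames keys, does not check tensor shapes, and returns keys with no counterpart in the model (such as `num_batches_tracked`, if the model does not expect it) separately instead of loading them.

```python
import re

# Hypothetical helper. ASSUMES the only naming difference is an extra ".bn"
# segment after each BatchNorm name, e.g. "bn1.bn.weight" -> "bn1.weight",
# as seen in the (truncated) error output.
_WRAPPED_BN = re.compile(r"(bn\d+)\.bn\.")


def remap_wrapped_bn_keys(checkpoint, target_keys):
    """Rename wrapped BatchNorm keys and keep only keys the target model expects.

    Returns (remapped_state_dict, dropped_keys). Tensor shapes are NOT checked.
    """
    target_keys = set(target_keys)
    remapped = {}
    dropped = []
    for key, value in checkpoint.items():
        new_key = _WRAPPED_BN.sub(r"\1.", key)
        if new_key in target_keys:
            remapped[new_key] = value
        else:
            dropped.append(key)  # report the original checkpoint name
    return remapped, dropped
```

Assuming `mdl` is the torchsig EfficientNet and `path` points at the checkpoint, usage might look like `state, dropped = remap_wrapped_bn_keys(torch.load(path, map_location="cpu"), mdl.state_dict().keys())` followed by `result = mdl.load_state_dict(state, strict=False)`; `result.missing_keys` then lists anything the rename did not cover.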
@nsbruce
Author

nsbruce commented Jul 12, 2024

It looks like the version bump came from @pvallance. Do you have any input on the best approach here?

@nsbruce
Author

nsbruce commented Jul 12, 2024

huggingface/pytorch-image-models#2150
huggingface/pytorch-image-models#2215
https://huggingface.co/timm/swinv2_tiny_window16_256.ms_in1k/discussions/1

These references got me one step closer (past the above error). I adjusted the timm.create_model() calls in efficientnet.py. I'm now getting the following (related) error:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/train.py", line 135, in <module>
    my_model = train_mymodel()
                   ^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
         ^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/train.py", line 127, in train_sigclr
    model = MyModel()
  File "/mymodel.py", line 28, in __init__
    self.encoder = Encoder()
                   ^^^^^^^^^^^^^^^^^^^
  File "/encoder.py", line 8, in __init__
    self.convnet = efficientnet_b4(pretrained=pretrained, path=path)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/torchsig/models/iq_models/efficientnet/efficientnet.py", line 276, in efficientnet_b4
    timm.create_model(
  File "/.venv/lib/python3.11/site-packages/timm/models/_factory.py", line 117, in create_model
    model = create_fn(
            ^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/timm/models/efficientnet.py", line 1913, in efficientnet_b4
    model = _gen_efficientnet(
            ^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/timm/models/efficientnet.py", line 651, in _gen_efficientnet
    model = _create_effnet(variant, pretrained, **model_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/timm/models/efficientnet.py", line 368, in _create_effnet
    model = build_model_with_cfg(
            ^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/timm/models/_builder.py", line 418, in build_model_with_cfg
    load_pretrained(
  File "/.venv/lib/python3.11/site-packages/timm/models/_builder.py", line 215, in load_pretrained
    state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/timm/models/_manipulate.py", line 259, in adapt_input_conv
    O, I, J, K = conv_weight.shape
    ^^^^^^^^^^
ValueError: not enough values to unpack (expected 4, got 3)
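The traceback shows timm's `adapt_input_conv` unpacking the stem weight as `O, I, J, K = conv_weight.shape`, i.e. it expects a 4-D Conv2d kernel. A Conv1d weight has shape `(out_channels, in_channels // groups, kernel_size)`, which is 3-D, so "expected 4, got 3" is consistent with a 1-D torchsig checkpoint being fed to timm's 2-D model before any 1-D conversion. That reading is an inference, not something confirmed in this thread. A small diagnostic sketch for comparing a checkpoint against a model is below; the helper names are hypothetical.

```python
def shapes_of(state_dict):
    """Map each key to its tensor shape as a plain tuple (works for torch tensors)."""
    return {key: tuple(value.shape) for key, value in state_dict.items()}


def find_shape_mismatches(ckpt_shapes, model_shapes):
    """Compare two {key: shape} mappings.

    Returns (mismatched, only_in_ckpt, only_in_model), where `mismatched`
    maps each shared key whose shapes differ to (ckpt_shape, model_shape).
    """
    shared = ckpt_shapes.keys() & model_shapes.keys()
    mismatched = {
        key: (ckpt_shapes[key], model_shapes[key])
        for key in sorted(shared)
        if ckpt_shapes[key] != model_shapes[key]
    }
    only_in_ckpt = sorted(ckpt_shapes.keys() - model_shapes.keys())
    only_in_model = sorted(model_shapes.keys() - ckpt_shapes.keys())
    return mismatched, only_in_ckpt, only_in_model
```

With torch available, it could be called as `find_shape_mismatches(shapes_of(torch.load(path, map_location="cpu")), shapes_of(model.state_dict()))`; a 3-D checkpoint entry paired with a 4-D model entry for `conv_stem.weight` would support the 1-D vs 2-D explanation.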

@pvallance
Collaborator

pvallance commented Jul 13, 2024 via email

@pvallance
Collaborator

pvallance commented Jul 13, 2024 via email

@pvallance
Collaborator

pvallance commented Jul 13, 2024 via email

@nsbruce
Author

nsbruce commented Jul 13, 2024

Hey, thanks for the response @pvallance. This is based on the narrowband example notebooks. In https://github.com/TorchDSP/torchsig/blob/main/examples/02_example_sig53_classifier.ipynb, the model used on the narrowband dataset is:

model = efficientnet_b4(
    pretrained=True,
    path="efficientnet_b4.pt",
)

This worked for me as an encoder with torchsig v0.5.0.

I'm using it like this:

from torchsig.models.iq_models.efficientnet.efficientnet import efficientnet_b4
import torch.nn as nn
import torch

class Encoder(nn.Module):
    def __init__(self, pretrained=True, path="/path/to/sig53/efficientnet_b4_online.pt", device="cuda"):
        super().__init__()
        self.convnet = efficientnet_b4(pretrained=pretrained, path=path)
        self.convnet = self.convnet.to(device)

    def forward(self, x):
        return self.convnet(x)

    def predict(self, x):
        with torch.no_grad():
            out = self.forward(x)
        return out

This causes the error above when I try to instantiate an Encoder().

@pvallance
Collaborator

pvallance commented Jul 19, 2024 via email

@nsbruce
Author

nsbruce commented Aug 9, 2024

Hi @pvallance any update on this?

@pvallance
Collaborator

pvallance commented Aug 9, 2024 via email

@pvallance
Collaborator

pvallance commented Aug 16, 2024 via email

@acchione

This sounds very interesting. Any updates here? I was looking to pull the latest, but I didn't want to dig through this repo if a big refactor is on the way.

@nsbruce
Author

nsbruce commented Sep 12, 2024

Hi @pvallance, lots of interesting updates were released this week! Thank you. The newer notebooks no longer have an example of loading the pre-trained weights.

When I try

from torchsig.models import EfficientNet1d
import torch

mdl = EfficientNet1d(2,53,"b4")
mdl.load_state_dict(torch.load('efficientnet_b4_online.pt'))

I get the following error:

RuntimeError: Error(s) in loading state_dict for EfficientNet:
        Unexpected key(s) in state_dict: "conv_stem.bias", "blocks.0.0.conv_dw.bias", "blocks.0.0.conv_pw.bias", "blocks.0.1.conv_dw.bias", "blocks.0.1.conv_pw.bias", "blocks.1.0.conv_pw.bias", "blocks.1.0.conv_dw.bias", "blocks.1.0.conv_pwl.bias", "blocks.1.1.conv_pw.bias", "blocks.1.1.conv_dw.bias", "blocks.1.1.conv_pwl.bias", "blocks.1.2.conv_pw.bias", "blocks.1.2.conv_dw.bias", "blocks.1.2.conv_pwl.bias", "blocks.1.3.conv_pw.bias", "blocks.1.3.conv_dw.bias", "blocks.1.3.conv_pwl.bias", "blocks.2.0.conv_pw.bias", "blocks.2.0.conv_dw.bias", "blocks.2.0.conv_pwl.bias", "blocks.2.1.conv_pw.bias", "blocks.2.1.conv_dw.bias", "blocks.2.1.conv_pwl.bias", "blocks.2.2.conv_pw.bias", "blocks.2.2.conv_dw.bias", "blocks.2.2.conv_pwl.bias", "blocks.2.3.conv_pw.bias", "blocks.2.3.conv_dw.bias", "blocks.2.3.conv_pwl.bias", "blocks.3.0.conv_pw.bias", "blocks.3.0.conv_dw.bias", "blocks.3.0.conv_pwl.bias", "blocks.3.1.conv_pw.bias", "blocks.3.1.conv_dw.bias", "blocks.3.1.conv_pwl.bias", "blocks.3.2.conv_pw.bias", "blocks.3.2.conv_dw.bias", "blocks.3.2.conv_pwl.bias", "blocks.3.3.conv_pw.bias", "blocks.3.3.conv_dw.bias", "blocks.3.3.conv_pwl.bias", "blocks.3.4.conv_pw.bias", "blocks.3.4.conv_dw.bias", "blocks.3.4.conv_pwl.bias", "blocks.3.5.conv_pw.bias", "blocks.3.5.conv_dw.bias", "blocks.3.5.conv_pwl.bias", "blocks.4.0.conv_pw.bias", "blocks.4.0.conv_dw.bias", "blocks.4.0.conv_pwl.bias", "blocks.4.1.conv_pw.bias", "blocks.4.1.conv_dw.bias", "blocks.4.1.conv_pwl.bias", "blocks.4.2.conv_pw.bias", "blocks.4.2.conv_dw.bias", "blocks.4.2.conv_pwl.bias", "blocks.4.3.conv_pw.bias", "blocks.4.3.conv_dw.bias", "blocks.4.3.conv_pwl.bias", "blocks.4.4.conv_pw.bias", "blocks.4.4.conv_dw.bias", "blocks.4.4.conv_pwl.bias", "blocks.4.5.conv_pw.bias", "blocks.4.5.conv_dw.bias", "blocks.4.5.conv_pwl.bias", "blocks.5.0.conv_pw.bias", "blocks.5.0.conv_dw.bias", "blocks.5.0.conv_pwl.bias", "blocks.5.1.conv_pw.bias", "blocks.5.1.conv_dw.bias", "blocks.5.1.conv_pwl.bias", "blocks.5.2.conv_pw.bias", 
"blocks.5.2.conv_dw.bias", "blocks.5.2.conv_pwl.bias", "blocks.5.3.conv_pw.bias", "blocks.5.3.conv_dw.bias", "blocks.5.3.conv_pwl.bias", "blocks.5.4.conv_pw.bias", "blocks.5.4.conv_dw.bias", "blocks.5.4.conv_pwl.bias", "blocks.5.5.conv_pw.bias", "blocks.5.5.conv_dw.bias", "blocks.5.5.conv_pwl.bias", "blocks.5.6.conv_pw.bias", "blocks.5.6.conv_dw.bias", "blocks.5.6.conv_pwl.bias", "blocks.5.7.conv_pw.bias", "blocks.5.7.conv_dw.bias", "blocks.5.7.conv_pwl.bias", "blocks.6.0.conv_pw.bias", "blocks.6.0.conv_dw.bias", "blocks.6.0.conv_pwl.bias", "blocks.6.1.conv_pw.bias", "blocks.6.1.conv_dw.bias", "blocks.6.1.conv_pwl.bias", "conv_head.bias".

That is far fewer unexpected keys than before (the output was truncated in my first comment on this issue), and there are no missing keys, which I think is good progress. Do you have any advice on how to load the pretrained weights?
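Every remaining unexpected key is a conv bias (`conv_stem.bias`, `blocks.*.conv_*.bias`, `conv_head.bias`), which suggests the new `EfficientNet1d` builds its convs without bias while the old checkpoint has them. Dropping them outright would change outputs if they are non-zero. If each of these convs feeds straight into a BatchNorm, a bias `b` can be folded into that BatchNorm's `running_mean` exactly for eval mode, since `BN(Wx + b)` normalizes `Wx + b - running_mean`, which equals `Wx - (running_mean - b)`. The sketch below does that. The conv-to-BatchNorm pairing is an ASSUMPTION based on timm's EfficientNet module names and should be checked against the actual model; `fold_conv_bias_into_bn` is a hypothetical helper.

```python
import re

# Conv bias keys such as "conv_stem.bias" or "blocks.3.1.conv_dw.bias".
_CONV_BIAS = re.compile(r"^(?P<block>blocks\.\d+\.\d+\.)?(?P<conv>conv_\w+)\.bias$")

# ASSUMED pairing of each conv with the BatchNorm that directly follows it,
# based on timm's EfficientNet module names. Verify against your model.
_STEM_HEAD = {"conv_stem": "bn1", "conv_head": "bn2"}
_DEPTHWISE_SEPARABLE = {"conv_dw": "bn1", "conv_pw": "bn2"}
_INVERTED_RESIDUAL = {"conv_pw": "bn1", "conv_dw": "bn2", "conv_pwl": "bn3"}


def fold_conv_bias_into_bn(state_dict):
    """Return a copy of state_dict with each conv bias folded into the next BatchNorm.

    In eval mode, subtracting the bias from running_mean and dropping the bias
    leaves the network's outputs unchanged.
    """
    out = dict(state_dict)
    for key in list(state_dict):
        match = _CONV_BIAS.match(key)
        if match is None:
            continue
        block = match.group("block") or ""
        if not block:
            pairing = _STEM_HEAD
        elif f"{block}conv_pwl.weight" in state_dict:
            pairing = _INVERTED_RESIDUAL  # pw expand -> dw -> pwl project
        else:
            pairing = _DEPTHWISE_SEPARABLE  # dw -> pw
        bn_name = pairing.get(match.group("conv"))
        mean_key = f"{block}{bn_name}.running_mean"
        if bn_name is None or mean_key not in out:
            continue  # unknown layout: keep the bias so load_state_dict reports it
        out[mean_key] = out[mean_key] - out.pop(key)
    return out
```

Usage might look like `mdl.load_state_dict(fold_conv_bias_into_bn(torch.load("efficientnet_b4_online.pt", map_location="cpu")))` followed by `mdl.eval()`; with the default `strict=True`, any bias the pairing did not cover still surfaces as an unexpected key.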

@ereoh
Collaborator

ereoh commented Sep 12, 2024

Hello @nsbruce,

I'm taking over for @pvallance. Please look at examples/02_example_sig53_classifier.ipynb for more guidance.

@nsbruce
Author

nsbruce commented Sep 12, 2024

Hi @ereoh, I did look through that. No pre-trained weights are loaded in that notebook.

In the old version the weights were loaded like so:

model = efficientnet_b4(
    pretrained=True,
    path="efficientnet_b4.pt",
)

In the new version the pre-trained weights are not used.

@ereoh
Collaborator

ereoh commented Sep 12, 2024

Gotcha. Passing it on to the developers. We really appreciate your patience.

@ereoh ereoh added this to the v0.6.1 milestone Nov 6, 2024
@TorchDSP TorchDSP added this to TorchSig Nov 6, 2024
@ereoh
Collaborator

ereoh commented Nov 6, 2024

Hello, just updating you that for this next release (v0.6.1) we will be updating our pretrained models with the correct weights for both Narrowband and Wideband. Thanks for your patience.

@ereoh ereoh added the bug Something isn't working label Nov 6, 2024
Projects
No open projects
Status: No status
Development

No branches or pull requests

4 participants