
Can Hydra load Mamba's weights directly? #6

Open
EricLina opened this issue Jul 22, 2024 · 8 comments

@EricLina

Hello!
Thanks for your amazing work!

If I have a pretrained mamba model, can I load mamba's weight to hydra directly?

@sukjunhwang (Collaborator)

Yes, I believe copying the Mamba weights into both the forward and backward passes would give you a very good set of initial weights to start from.
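For reference, here is a minimal sketch of that idea, assuming a Mamba-2-style source block: it copies every parameter whose name and shape match and leaves Hydra-specific parameters (e.g., fc_D) at their default initialization. The function name is only illustrative, not part of either repo.

def init_hydra_from_mamba(hydra_module, mamba_module):
    """Copy every parameter whose name and shape match; leave the rest at init.

    This is only a rough starting point: Hydra-specific parameters keep their
    default initialization and are returned in `missing`.
    """
    mamba_sd = mamba_module.state_dict()
    hydra_sd = hydra_module.state_dict()
    # Keep only tensors that exist in both modules with identical shapes.
    compatible = {
        name: tensor
        for name, tensor in mamba_sd.items()
        if name in hydra_sd and hydra_sd[name].shape == tensor.shape
    }
    missing, unexpected = hydra_module.load_state_dict(compatible, strict=False)
    return list(compatible), missing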

@EricLina (Author)

I think your architecture is similar to Mamba-2, but as tridao said, Mamba-1 weights can't be loaded into Mamba-2. Can you give me some guidance on how to copy Mamba's weights into your model?

@sukjunhwang (Collaborator)

I believe the weights of Mamba-2 are also public on HuggingFace.
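For example, a minimal sketch assuming the mamba_ssm package and the state-spaces/mamba2-130m checkpoint on HuggingFace (checkpoint name and layer index are only illustrative):

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Download a pretrained Mamba-2 language model from the HuggingFace Hub.
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba2-130m", device="cuda", dtype=torch.float16
)

# The per-layer Mamba-2 mixer weights live under backbone.layers[i].mixer;
# these are the tensors you would map onto a Hydra block.
print(sorted(model.backbone.layers[0].mixer.state_dict().keys()))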

@EricLina (Author) commented Jul 26, 2024

Thanks for your reply!
I mean: how do I copy the weights of Mamba-1 (with selective scan) into your architecture (or into Mamba-2)?
Actually, I want to transfer the weights of Vision Mamba, which was trained on Mamba-1, to your model.

@sukjunhwang (Collaborator)

Then you could slightly modify our code to use S6 instead of SSD, with some corrections to the SSM parameters.

@EricLina (Author)

I tried to load Mamba's weights into Hydra, but found some differences:

1. Parameters in mamba but not in Hydra:

  • Parameter: x_proj.weight,
    Mamba shape: torch.Size([36, 128])

  • Parameter: dt_proj.bias,
    Mamba shape: torch.Size([128])

  • Parameter: dt_proj.weight,
    Mamba shape: torch.Size([128, 4])

2. Parameters in Hydra but not in mamba:

  • Parameter: dt_bias,
    Hydra shape: torch.Size([2])

  • Parameter: fc_D.weight,
    Hydra shape: torch.Size([2, 128])

  • Parameter: norm.weight,
    Hydra shape: torch.Size([128])

3. Parameters with the same name but different shapes:

  • Parameter: A_log,
    Mamba shape: torch.Size([128, 16]), Hydra shape: torch.Size([2])
  • Parameter: D,
    Mamba shape: torch.Size([128]), Hydra shape: torch.Size([2])
  • Parameter: in_proj.weight,
    Mamba shape: torch.Size([256, 64]), Hydra shape: torch.Size([516, 64])
  • Parameter: conv1d.weight,
    Mamba shape: torch.Size([128, 1, 4]), Hydra shape: torch.Size([384, 1, 7])
  • Parameter: conv1d.bias,
    Mamba shape: torch.Size([128]), Hydra shape: torch.Size([384])
    """

Below is my code:


# Mamba comes from the mamba_ssm package; adjust the Hydra import path to
# wherever the Hydra module lives in your checkout of the repo.
from mamba_ssm import Mamba
from hydra import Hydra


def cmp_mamba1_with_hydra():
    """Compare the state dicts of a Mamba-1 block and a Hydra block of the same width."""
    d_model = 64
    mamba = Mamba(d_model)
    hydra = Hydra(d_model)

    mamba_params = mamba.state_dict()
    hydra_params = hydra.state_dict()

    # Parameters present in Mamba-1 but missing from Hydra.
    param_diff = set(mamba_params.keys()) - set(hydra_params.keys())
    print("\nParameters in mamba but not in Hydra:")
    for name in param_diff:
        print(f"Parameter: {name}, \nMamba shape: {mamba_params[name].shape}")

    # Parameters present in Hydra but missing from Mamba-1.
    param_diff = set(hydra_params.keys()) - set(mamba_params.keys())
    print("\nParameters in Hydra but not in mamba:")
    for name in param_diff:
        print(f"Parameter: {name}, \nHydra shape: {hydra_params[name].shape}")

    # Parameters that exist in both modules but with different shapes.
    param_diff_shape = {
        name: (mamba_params[name].shape, hydra_params[name].shape)
        for name in mamba_params.keys()
        if name in hydra_params and mamba_params[name].shape != hydra_params[name].shape
    }

    print("\nParameters with the same name but different shapes:")
    for name, shapes in param_diff_shape.items():
        print(f"Parameter: {name},\n Mamba shape: {shapes[0]}, Hydra shape: {shapes[1]}")

@Hprairie

You will not be able to load Mamba-1's parameters, from my understanding, as Hydra is a Mamba-2 variant. The reason is that the parameterization of the A matrix is completely different between the two models: in Mamba-1, A is a (D, N) parameter tensor, while in Mamba-2 it is a (D,) parameter tensor, where D is the dimension of a layer and N is the hidden state size.
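To make the shape difference concrete, here is a small sketch assuming the mamba_ssm package (which ships both blocks); the dimensions are arbitrary:

from mamba_ssm import Mamba, Mamba2

d_model = 64
m1 = Mamba(d_model=d_model, d_state=16)               # Mamba-1 (selective scan / S6)
m2 = Mamba2(d_model=d_model, d_state=64, headdim=32)  # Mamba-2 (SSD)

# Mamba-1 keeps one A entry per (channel, state) pair; Mamba-2 keeps one scalar per head.
print(m1.A_log.shape)  # e.g. torch.Size([128, 16]) -> (d_inner, d_state)
print(m2.A_log.shape)  # e.g. torch.Size([4])       -> (nheads,)

This is consistent with the comparison pasted above, where Hydra's A_log is torch.Size([2]), i.e. one value per head rather than a full (d_inner, d_state) matrix.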

@sukjunhwang (Collaborator)

It is not possible to directly reuse Mamba-1 weights in Mamba-2. Unless you import Mamba-2 weights, you will need to retrain from scratch. As @Hprairie mentions, the two differ in their parameters as well as in the layout of the model (e.g., the convolution). Please refer to the Mamba-2 paper for the specific differences between Mamba-1 and Mamba-2.
