
Question regarding VQ-VAE optimization #92

Open
FranklinHu1 opened this issue Aug 29, 2024 · 2 comments

Comments

@FranklinHu1

Hello ESM team,

Thank you for releasing the ESM3 code! I have a question regarding the training protocol you have described for the VQ-VAE model.

In the paper, you describe a two-stage training protocol: in the first stage, you train the encoder and codebook with a smaller, more efficient decoder; in the second stage, you freeze the encoder and codebook and train a larger, more expressive decoder. In the codebase, the structure token encoder returns z_q and the min_encoding_indices (https://github.com/evolutionaryscale/esm/blob/main/esm/models/vqvae.py#L301), whereas the structure token decoder's decode method takes only structure token indices (https://github.com/evolutionaryscale/esm/blob/main/esm/models/vqvae.py#L380), i.e. only the discrete integer values.

In the code for the codebook, the forward method (https://github.com/evolutionaryscale/esm/blob/main/esm/layers/codebook.py#L57) returns the straight through embeddings (with gradients from the straight through estimator), the encoding indices (which do not have gradients), and the commitment loss.

If you use the structure token decoder as is and only pass the embedding indices from the encoder to the decoder during training, then there is a disconnect and backpropagation will not populate gradients for the encoder parameters. To get gradients for the encoder from the backward pass, you would have to pass z_q instead, which has gradient information.
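The disconnect described above can be demonstrated in a few lines of PyTorch. This is a toy sketch, not the ESM3 code: a tiny linear "encoder" and a two-entry codebook stand in for the real modules. The straight-through path (z_st) carries gradients back to the encoder, while the integer indices do not:

```python
import torch

torch.manual_seed(0)
codebook = torch.nn.Embedding(2, 4)   # toy codebook: 2 codes of dim 4
encoder = torch.nn.Linear(4, 4)       # toy encoder

z_e = encoder(torch.randn(1, 4))                  # continuous encoder output
dists = torch.cdist(z_e, codebook.weight)         # distance to each code
indices = dists.argmin(dim=-1)                    # discrete token ids (argmin cuts the graph)
z_q = codebook(indices)                           # nearest code vectors

# Straight-through estimator: forward value is z_q, backward flows into z_e.
z_st = z_e + (z_q - z_e).detach()

# Decoding from z_st lets reconstruction loss reach the encoder parameters.
z_st.sum().backward()
assert encoder.weight.grad is not None

# Decoding from the integer indices alone cannot: they carry no gradient.
assert not indices.requires_grad
```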

Based on this, my interpretation of the VQ-VAE training protocol, as described in the paper, is as follows:

  1. In the first stage, the embedding vectors z_q generated from the encoder and codebook are passed to a decoder which operates directly on these embedding vectors. Training the model in this way allows the reconstruction losses over the decoder outputs to backpropagate and optimize the encoder as well.
  2. In the second stage, the encoder and codebook are frozen and the embedding indices are passed to a decoder that now has its own internal embedding layer for these indices. This matches the structure token decoder already present in the code.
  3. In the full ESM model, only tokens (discrete integer indices) are used for each track. For structure, these tokens are generated from the encoder and codebook and decoded by the second stage decoder which takes in indices.
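The two-stage protocol described in the list above can be sketched as follows. All module names and sizes here are assumptions for illustration, not the actual ESM3 components:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the paper's components (names and sizes assumed).
enc = nn.Linear(16, 8)                          # encoder
codes = nn.Embedding(64, 8)                     # codebook
small_dec = nn.Linear(8, 16)                    # stage-1 decoder: consumes embedding vectors
big_dec = nn.Sequential(nn.Embedding(64, 8),    # stage-2 decoder: consumes token indices
                        nn.Linear(8, 16))

def quantize(z_e):
    """Nearest-code lookup with a straight-through estimator."""
    idx = torch.cdist(z_e, codes.weight).argmin(-1)
    return z_e + (codes(idx) - z_e).detach(), idx

x = torch.randn(4, 16)

# Stage 1: reconstruction loss flows through z_st back into the encoder.
z_st, _ = quantize(enc(x))
loss1 = (small_dec(z_st) - x).pow(2).mean()
loss1.backward()
assert enc.weight.grad is not None
stage1_grad = enc.weight.grad.clone()

# Stage 2: encoder and codebook frozen; the big decoder trains from indices alone.
with torch.no_grad():
    _, idx = quantize(enc(x))
loss2 = (big_dec(idx) - x).pow(2).mean()
loss2.backward()
assert big_dec[0].weight.grad is not None       # stage-2 decoder learns
assert torch.equal(enc.weight.grad, stage1_grad)  # encoder untouched in stage 2
```

In stage 2, the decoder's own `nn.Embedding` layer replaces the codebook lookup, which is why only integer indices need to cross the encoder/decoder boundary in the full model.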

Is this a correct interpretation of the ESM3 VQ-VAE training approach? Or is there anything I am missing? Any clarification would be greatly appreciated.

Thank you very much!

@FranklinHu1
Author

Sorry, I don't think that's related to my question?

@ebetica
Contributor

ebetica commented Oct 16, 2024

I think that's correct. I believe we used the straight-through estimator, i.e. we used quantize(z_q) in the decoder and passed the gradient through to z_q despite the values not matching up.
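The straight-through trick described here is a one-liner; this sketch uses `torch.round` as a stand-in for the codebook lookup, which is an assumption for illustration:

```python
import torch

z_e = torch.randn(3, 8, requires_grad=True)   # continuous encoder output
z_q = torch.round(z_e).detach()               # "quantized" value with no gradient of its own

# Forward value equals z_q; in backward, the gradient passes straight through to z_e.
z_st = z_e + (z_q - z_e).detach()
assert torch.allclose(z_st, z_q)

z_st.sum().backward()
assert torch.equal(z_e.grad, torch.ones_like(z_e))  # identity gradient w.r.t. z_e
```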

@Jostino

Jostino commented Oct 16, 2024

> Sorry, I don't think that's related to my question?

Sorry, my last reply wasn't made by me but by someone who hijacked a cookie from my browser to spam on GitHub.
