Hello ESM team,

Thank you for releasing the ESM3 code! I have a question regarding the training protocol you describe for the VQ-VAE model.
In the paper, you describe a two-stage training protocol: in the first stage, you train the encoder and codebook with a smaller, more efficient decoder; in the second stage, you freeze the encoder and codebook and train a larger, more expressive decoder. In the codebase, the structure token encoder returns z_q and the min_encoding_indices (https://github.com/evolutionaryscale/esm/blob/main/esm/models/vqvae.py#L301), whereas the structure token decoder's decode method takes only structure token indices (https://github.com/evolutionaryscale/esm/blob/main/esm/models/vqvae.py#L380), i.e. only the discrete integer values.
In the code for the codebook, the forward method (https://github.com/evolutionaryscale/esm/blob/main/esm/layers/codebook.py#L57) returns the straight-through embeddings (which carry gradients via the straight-through estimator), the encoding indices (which do not have gradients), and the commitment loss.
If you use the structure token decoder as-is and pass only the embedding indices from the encoder to the decoder during training, there is a disconnect: backpropagation will not populate gradients for the encoder parameters. To get gradients for the encoder from the backward pass, you would have to pass z_q instead, which carries gradient information.
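To make the gradient point concrete, here is a minimal PyTorch sketch (toy shapes, with a plain nn.Embedding standing in for the codebook; this is not the ESM3 implementation). Decoding from the integer indices leaves the encoder output without gradients, while the straight-through embeddings keep the path intact:

```python
import torch
import torch.nn as nn

# Toy stand-ins: a 512-entry codebook of 64-dim vectors and a fake encoder output.
codebook = nn.Embedding(512, 64)
z_e = torch.randn(1, 10, 64, requires_grad=True)  # pre-quantization encoder output

# Nearest-neighbour lookup: argmin produces integer indices and is non-differentiable.
dists = torch.cdist(z_e, codebook.weight.unsqueeze(0))
indices = dists.argmin(dim=-1)

# Case 1: decode from indices only -- there is no gradient path back to the encoder.
codebook(indices).sum().backward()
print(z_e.grad)  # None

# Case 2: straight-through embeddings -- the forward pass uses the quantized vectors,
# the backward pass treats quantization as the identity, so gradients reach z_e.
z_q = z_e + (codebook(indices) - z_e).detach()
z_q.sum().backward()
print(z_e.grad.abs().sum() > 0)  # tensor(True)
```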
Based on this, my interpretation of the VQ-VAE training protocol, as described in the paper, is as follows (a toy sketch of both stages follows the list):

1. In the first stage, the embedding vectors z_q produced by the encoder and codebook are passed to a decoder that operates directly on these embedding vectors. Training the model this way allows the reconstruction losses over the decoder outputs to backpropagate into and optimize the encoder as well.
2. In the second stage, the encoder and codebook are frozen, and the embedding indices are passed to a decoder that now has its own internal embedding layer for these indices. This matches the structure token decoder already present in the code.
3. In the full ESM3 model, only tokens (discrete integer indices) are used for each track. For structure, these tokens are generated by the encoder and codebook and decoded by the second-stage decoder, which takes indices.
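Here is a self-contained toy sketch of how I picture the two stages wired up. All names (ToyCodebook, small_decoder, large_decoder), sizes, and the auxiliary losses are placeholders I made up for illustration, not the actual ESM3 modules or training code:

```python
import torch
import torch.nn as nn

D, K = 64, 512  # toy embedding dim and codebook size

class ToyCodebook(nn.Module):
    """Nearest-neighbour quantizer: straight-through embeddings, indices, auxiliary loss."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(K, D)

    def forward(self, z_e):
        indices = torch.cdist(z_e, self.embed.weight.unsqueeze(0)).argmin(dim=-1)
        z_q = self.embed(indices)
        # Generic VQ-VAE auxiliary losses: codebook term updates the embeddings,
        # commitment term keeps the encoder close to the chosen codes.
        vq_loss = ((z_q - z_e.detach()) ** 2).mean() + 0.25 * ((z_q.detach() - z_e) ** 2).mean()
        z_q_st = z_e + (z_q - z_e).detach()  # straight-through estimator
        return z_q_st, indices, vq_loss

encoder       = nn.Linear(3, D)                                     # stand-in structure encoder
codebook      = ToyCodebook()
small_decoder = nn.Linear(D, 3)                                     # stage 1: consumes embeddings
large_decoder = nn.Sequential(nn.Embedding(K, D), nn.Linear(D, 3))  # stage 2: consumes indices

coords = torch.randn(1, 10, 3)  # toy "structure"

# Stage 1: encoder + codebook + small decoder trained end to end on z_q.
z_q, indices, vq_loss = codebook(encoder(coords))
loss = ((small_decoder(z_q) - coords) ** 2).mean() + vq_loss
loss.backward()  # reconstruction loss reaches the encoder through the straight-through z_q

# Stage 2: freeze encoder and codebook; train a larger decoder on the indices alone.
for p in list(encoder.parameters()) + list(codebook.parameters()):
    p.requires_grad_(False)
with torch.no_grad():
    _, indices, _ = codebook(encoder(coords))
loss = ((large_decoder(indices) - coords) ** 2).mean()
loss.backward()  # only the large decoder receives gradients
```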
Is this a correct interpretation of the ESM3 VQ-VAE training approach? Or is there anything I am missing? Any clarification would be greatly appreciated.
Thank you very much!
I think that's correct. I believe we used the straight-through estimator, i.e. we used quantize(z_q) in the decoder and passed the gradient through to z_q even though the two do not match.