Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rewrite the mace equivariant head to avoid a model explosion #52

Merged
merged 2 commits into from
Jun 5, 2024

Conversation

sblackburn86
Copy link
Collaborator

original model:
time embedding + mace output ->
o3.Linear -> o3.BatchNorm -> o3.TensorSquare -> non linearity
The TensorSquare had a lot of weights and was probably overkill for our purpose. Its goal was to mix the time information across the different channels of MACE, but it was also computing a lot of useless stuff.

I replaced with the following:
FullyConnectedTensorProduct(time embedding, mace output) ->
o3.Linear -> o3.BatchNorm -> non linearity
the FCTP mixes time with all the components of the mace output and we can control the dimensionality of its output. No need for a TensorSquare anymore
We probably could simplify even more by allowing only the 0e channel as the output of the FCTP since the next layers should not mix different ells

@sblackburn86 sblackburn86 requested a review from rousseab June 3, 2024 19:39
Copy link
Collaborator

@rousseab rousseab left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@sblackburn86 sblackburn86 merged commit eab8b54 into main Jun 5, 2024
1 check passed
@sblackburn86 sblackburn86 deleted the mace_equivariant_head_with_tensorproduct branch June 5, 2024 12:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants