Rework Combination class #26
Conversation
Few small things, but looking really good. Addressing the failing test(s) is the main issue, but it's good that we're putting them in and making progress towards checking correctness.
After playing around with the tests a bit, it seems like it's just a stochastic failure based on how the random data is initialized. Two workarounds that I can think of are:
@hmacdope do you have thoughts as to which would be better/preferable?
Looks good @kaminow, merge when ready.
Let's hold off on a release until we sort out CI and env build issues.
@kaminow I have fixed a missing ase (?) dep in tests and updated some env files.
@hmacdope thanks! Any thoughts on why things are still failing for Ubuntu 3.11?
I will investigate, seems odd.
@hmacdope after some investigation, it seems that there are some broken requirements for the 3.11 version of
for posterity, this is the error I get when I try to run
@kaminow let me take a quick look at their feedstock.
Pinging @mikemhenry as well, as I see he is a maintainer on the PyG feedstock.
We can try a pin also in the meantime.
I am fairly sure this is due to the exact pin of |
LGTM
The previous iteration of the `Combination` classes required the computation graph for each pose to be held in GPU memory, which will quickly overflow normal GPUs when using all-atom poses. The new version splits the gradient calculation so that the gradient for each pose is computed separately and combined appropriately at the end, meaning that each computation graph can be freed from memory after use. The derivation for the math used in the different `Combination` subclasses can be found in the `README_COMBINATION.md` file.

General list of changes for each file:
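The per-pose gradient splitting described above can be sketched roughly as follows. This is an illustrative example, not mtenn's actual code: the function name, the use of `model` as a per-pose predictor, and the choice of a simple mean as the combination rule are all assumptions for the sake of the sketch.

```python
import torch

def grouped_predict(model, poses):
    """Illustrative sketch of per-pose gradient splitting (not mtenn's
    actual implementation). Combines per-pose predictions by their mean,
    backpropagating each pose separately so its graph can be freed."""
    n = len(poses)
    pose_preds = []
    for pose in poses:
        pred = model(pose)  # computation graph for this pose only
        # d(mean)/d(pred_i) = 1/n, so scale before backprop; parameter
        # gradients accumulate across poses
        (pred / n).sum().backward()
        pose_preds.append(pred.detach())  # graph freed once detached
    # Final combined prediction plus the per-pose predictions
    return torch.stack(pose_preds).mean(), pose_preds
```

The accumulated parameter gradients match what a single backward pass over all poses at once would produce, but only one pose's graph is ever alive at a time.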
- `README_COMBINATION.md`
  - Math for separating out the gradients in the `Combination` classes
- `mtenn/combination.py`
  - Each method for combining predictions has a `torch.autograd.Function`, which takes care of combining and assigning the gradients in the backward pass, and a `Combination` subclass that is essentially a wrapper around the `Function`
- `mtenn/conversion_utils/e3nn.py`
  - `import` statements
  - `ComplexOnlyStrategy`
- `mtenn/conversion_utils/schnet.py`
  - `import` statements
  - `ComplexOnlyStrategy`
- `mtenn/model.py`
  - `Model` classes to their own files
  - `GroupedModel` forward pass to work with new `Combination` setup
  - `GroupedModel` now returns list of predictions for each pose in addition to the final prediction
- `mtenn/readout.py`
  - Move all `Readout`-related code
- `mtenn/representation.py`
  - Move all `Representation`-related code
- `mtenn/Strategy.py`
  - Move all `Strategy`-related code
  - `ComplexOnlyStrategy` class that only predicts on the full input
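As a rough illustration of the `Function`-plus-wrapper pattern described for `mtenn/combination.py`, here is a minimal mean-combination sketch. The class names and the mean rule are assumptions made for the example; the actual subclasses and their derivations are in the PR and `README_COMBINATION.md`.

```python
import torch

class MeanCombinationFunc(torch.autograd.Function):
    """Illustrative only: combine per-pose predictions by their mean and
    assign each input an equal share of the upstream gradient."""

    @staticmethod
    def forward(ctx, *pose_preds):
        ctx.n = len(pose_preds)
        return torch.stack(pose_preds).mean()

    @staticmethod
    def backward(ctx, grad_out):
        # d(mean)/d(pred_i) = 1/n for every pose prediction
        return tuple(grad_out / ctx.n for _ in range(ctx.n))


class MeanCombination(torch.nn.Module):
    """Combination subclass as a thin wrapper around the Function,
    mirroring the pattern described in the change list above."""

    def forward(self, pose_preds):
        return MeanCombinationFunc.apply(*pose_preds)
```

Because the `Function` receives already-computed per-pose predictions, each pose's graph can be built, backpropagated, and freed independently; the custom `backward` just routes the right share of the gradient back to each pose.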