Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework Combination class #26

Merged
merged 47 commits into from
Nov 1, 2023
Merged

Rework Combination class #26

merged 47 commits into from
Nov 1, 2023

Conversation

kaminow
Copy link
Collaborator

@kaminow kaminow commented Aug 28, 2023

The previous iteration of the Combination classes required the computation graph for each pose to be held in GPU memory, which will quickly overflow normal GPUs when using all-atom poses. The new version splits the gradient calculation such that the gradient for each pose is done separately and combined appropriately at the end, meaning that each computation graph can be freed from memory after use. The derivation for the math used in the different Combination subclasses can be found in the README_COMBINATION.md file.

General list of changes for each file:

README_COMBINATION.md
Math for separating out the gradients in the Combination classes

mtenn/combination.py
Each method for combining predictions has a torch.autograd.Function, which takes care of combining and assigning the gradients in the backward pass, and a Combination subclass that is essentially a wrapper around the Function

mtenn/conversion_utils/e3nn.py

  • Update import statements
  • Add ComplexOnlyStrategy

mtenn/conversion_utils/schnet.py

  • Update import statements
  • Add ComplexOnlyStrategy

mtenn/model.py

  • Move all non-Model classes to their own files
  • Update GroupedModel forward pass to work with new Combination setup
  • GroupedModel now returns list of predictions for each pose in addition to the final prediction

mtenn/readout.py
Move all Readout-related code

mtenn/representation.py
Move all Representation-related code

mtenn/Strategy.py

  • Move all Strategy-related code
  • Add ComplexOnlyStrategy class that only predicts on the full input

@kaminow kaminow marked this pull request as ready for review October 13, 2023 19:56
@hmacdope hmacdope self-requested a review October 16, 2023 21:32
Copy link
Contributor

@hmacdope hmacdope left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few small things, but looking really good. Addressing failing test(s) is main issue, but good that we are putting them in / making progress towards checking correctness.

mtenn/combination.py Outdated Show resolved Hide resolved
mtenn/combination.py Show resolved Hide resolved
mtenn/combination.py Show resolved Hide resolved
@kaminow
Copy link
Collaborator Author

kaminow commented Oct 18, 2023

after playing around with the tests a bit, it seems like it's just a stochastic failure based on how the random data is initialized. two workarounds that I can think of are:

  1. find a random seed that lets all the tests through as is, and trust that if the math gets messed up at some point then the tests will fail
  2. adjust the parameters to the np.allclose call to be more lenient

@hmacdope do you have thoughts as to which would be better/preferable?

@codecov-commenter
Copy link

codecov-commenter commented Oct 18, 2023

Codecov Report

Merging #26 (58c43f2) into main (54c94b0) will increase coverage by 31.92%.
The diff coverage is 84.01%.

Additional details and impacted files

@kaminow kaminow requested a review from hmacdope October 18, 2023 21:07
Copy link
Contributor

@hmacdope hmacdope left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good @kaminow, merge when ready.

Lets hold off on a release until we sort out CI and env build issues.

@hmacdope
Copy link
Contributor

@kaminow I have fixes a missing ase (?) dep in tests and updated some env files.

@kaminow
Copy link
Collaborator Author

kaminow commented Oct 23, 2023

@hmacdope thanks! any thoughts on why things are still failing for Ubuntu 3.11?

@hmacdope
Copy link
Contributor

I will investigate, seems odd.

@kaminow
Copy link
Collaborator Author

kaminow commented Oct 31, 2023

@hmacdope after some investigation, it seems that there's some requirements broken for the 3.11 version of pytorch_geometric. it seems to be requiring cuda for some reason, while the builds for older Python versions don't, so for Python 3.11 an older version of pytorch_geometric is being installed, prior to when the interaction_graph was added to the model

@kaminow
Copy link
Collaborator Author

kaminow commented Oct 31, 2023

for posterity, this is the error I get when I try to run mamba install pytorch_geometric=2.3.1 in a Python 3.11 env:

warning  libmamba Added empty dependency for problem type SOLVER_RULE_UPDATE
Could not solve for environment specs
The following package could not be installed
└─ pytorch_geometric 2.3.1**  is installable and it requires
   └─ pyg-lib 0.2.0  with the potential options
      ├─ pyg-lib 0.2.0 would require
      │  └─ triton with the potential options
      │     ├─ triton [1.1.2|2.0.0] would require
      │     │  └─ python >=3.10,<3.11.0a0 , which can be installed;
      │     ├─ triton 1.1.2 would require
      │     │  └─ python >=3.7,<3.8.0a0 , which can be installed;
      │     ├─ triton [1.1.2|2.0.0] would require
      │     │  └─ python >=3.8,<3.9.0a0 , which can be installed;
      │     ├─ triton [1.1.2|2.0.0] would require
      │     │  └─ python >=3.9,<3.10.0a0 , which can be installed;
      │     └─ triton 2.0.0 would require
      │        └─ pytorch * cuda* with the potential options
      │           ├─ pytorch [1.0.1|1.1.0|...|1.9.1] would require
      │           │  └─ python >=3.7,<3.8.0a0 , which can be installed;
      │           ├─ pytorch [1.10.0|1.10.1|...|1.9.1] would require
      │           │  └─ python >=3.8,<3.9.0a0 , which can be installed;
      │           ├─ pytorch [1.10.0|1.10.1|...|1.9.1] would require
      │           │  └─ python >=3.9,<3.10.0a0 , which can be installed;
      │           ├─ pytorch 1.11.0 would require
      │           │  └─ python >=3.10,<3.11.0a0 , which can be installed;
      │           ├─ pytorch [1.11.0|1.12.0|...|2.0.0] would require
      │           │  └─ __cuda, which is missing on the system;
      │           ├─ pytorch [1.0.1|1.1.0|...|1.9.1] would require
      │           │  └─ python >=3.6,<3.7.0a0 , which can be installed;
      │           ├─ pytorch [1.0.1|1.1.0|1.2.0|1.3.1] would require
      │           │  └─ python >=2.7,<2.8.0a0 , which can be installed;
      │           └─ pytorch 1.0.1 would require
      │              └─ cudatoolkit >=8.0,<8.1.0a0 , which does not exist (perhaps a missing channel);
      ├─ pyg-lib 0.2.0 would require
      │  └─ python >=3.10,<3.11.0a0 , which can be installed;
      ├─ pyg-lib 0.2.0 would require
      │  └─ python >=3.8,<3.9.0a0 , which can be installed;
      └─ pyg-lib 0.2.0 would require
         └─ python >=3.9,<3.10.0a0 , which can be installed.

@hmacdope
Copy link
Contributor

@kaminow let me take a quick look on their feedstock.

@hmacdope
Copy link
Contributor

for posterity, this is the error I get when I try to run mamba install pytorch_geometric=2.3.1 in a Python 3.11 env:

warning  libmamba Added empty dependency for problem type SOLVER_RULE_UPDATE
Could not solve for environment specs
The following package could not be installed
└─ pytorch_geometric 2.3.1**  is installable and it requires
   └─ pyg-lib 0.2.0  with the potential options
      ├─ pyg-lib 0.2.0 would require
      │  └─ triton with the potential options
      │     ├─ triton [1.1.2|2.0.0] would require
      │     │  └─ python >=3.10,<3.11.0a0 , which can be installed;
      │     ├─ triton 1.1.2 would require
      │     │  └─ python >=3.7,<3.8.0a0 , which can be installed;
      │     ├─ triton [1.1.2|2.0.0] would require
      │     │  └─ python >=3.8,<3.9.0a0 , which can be installed;
      │     ├─ triton [1.1.2|2.0.0] would require
      │     │  └─ python >=3.9,<3.10.0a0 , which can be installed;
      │     └─ triton 2.0.0 would require
      │        └─ pytorch * cuda* with the potential options
      │           ├─ pytorch [1.0.1|1.1.0|...|1.9.1] would require
      │           │  └─ python >=3.7,<3.8.0a0 , which can be installed;
      │           ├─ pytorch [1.10.0|1.10.1|...|1.9.1] would require
      │           │  └─ python >=3.8,<3.9.0a0 , which can be installed;
      │           ├─ pytorch [1.10.0|1.10.1|...|1.9.1] would require
      │           │  └─ python >=3.9,<3.10.0a0 , which can be installed;
      │           ├─ pytorch 1.11.0 would require
      │           │  └─ python >=3.10,<3.11.0a0 , which can be installed;
      │           ├─ pytorch [1.11.0|1.12.0|...|2.0.0] would require
      │           │  └─ __cuda, which is missing on the system;
      │           ├─ pytorch [1.0.1|1.1.0|...|1.9.1] would require
      │           │  └─ python >=3.6,<3.7.0a0 , which can be installed;
      │           ├─ pytorch [1.0.1|1.1.0|1.2.0|1.3.1] would require
      │           │  └─ python >=2.7,<2.8.0a0 , which can be installed;
      │           └─ pytorch 1.0.1 would require
      │              └─ cudatoolkit >=8.0,<8.1.0a0 , which does not exist (perhaps a missing channel);
      ├─ pyg-lib 0.2.0 would require
      │  └─ python >=3.10,<3.11.0a0 , which can be installed;
      ├─ pyg-lib 0.2.0 would require
      │  └─ python >=3.8,<3.9.0a0 , which can be installed;
      └─ pyg-lib 0.2.0 would require
         └─ python >=3.9,<3.10.0a0 , which can be installed.

Pinging @mikemhenry as well as I see he is a maintainer on the PYG feedstock

@hmacdope
Copy link
Contributor

We can try a pin also in the meantime.

@hmacdope
Copy link
Contributor

hmacdope commented Nov 1, 2023

I am fairly sure this is due to the exact pin of pyg-lib==0.2.0 in the pyg feedstock which is pulling down old pytorch versions. Tagging @hadim and @rusty1s? Perhaps they have some insight also. I will also confirm when on my linux box. Regardless, I think we are OK to push forward here and leave CI as indicating a failure.

Copy link
Contributor

@hmacdope hmacdope left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kaminow kaminow merged commit 6f6d8e8 into main Nov 1, 2023
3 of 4 checks passed
@kaminow kaminow deleted the split-comb-calcs branch November 1, 2023 15:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants