Remove weird sphericart.torch mechanics for double backward #139

Open
frostedoyster opened this issue Aug 20, 2024 · 3 comments

Comments

@frostedoyster
Collaborator

frostedoyster commented Aug 20, 2024

In torch, there is no way to know in advance whether the user will request a second derivative (unlike the first derivative, where requires_grad can be checked).
As a result, the current API requires the user to specify at class initialization whether second derivatives will be used. This makes the feature close to useless, for two reasons:

  • if the model is torchscripted, the class can't be re-initialized, so one is stuck with first derivatives only
  • the model can fail silently if second derivatives are requested without having been enabled at initialization
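
To make the second point concrete, here is a minimal sketch of one way a custom autograd function can fail silently under double backward. It is a toy illustration in Python, not sphericart's code: `NaiveCube` is a hypothetical stand-in (f(x) = x**3 instead of spherical harmonics), and the assumption is that the first derivative is computed outside of autograd, so its result carries no graph.

```python
import torch


class NaiveCube(torch.autograd.Function):
    """Hypothetical op with first derivatives only: f(x) = x**3."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x**3

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # detach() mimics a derivative computed outside of autograd
        # (e.g. by a hand-written kernel): the result carries no graph.
        return grad_out * 3.0 * x.detach() ** 2


x = torch.tensor([1.0, 2.0], requires_grad=True)
y = NaiveCube.apply(x)

# First derivative: correct, 3 * x**2.
(g,) = torch.autograd.grad(y.sum(), x, create_graph=True)

# A quantity that depends on both x and the first derivative.
# Analytically g * x = 3 * x**3, so its derivative is 9 * x**2.
loss = (g * x).sum()
(wrong,) = torch.autograd.grad(loss, x)

# No error is raised, but the contribution through g is missing:
# autograd returns g = 3 * x**2 instead of 9 * x**2.
```

In this sketch nothing warns the user; the result is simply wrong by the missing second-derivative term.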

The only way I see to make this feature usable is to compute the second derivatives on the fly, when they are actually needed. This means recomputing the values and first derivatives of the spherical harmonics, but the current approach, which avoids that recomputation, is unsustainable in practice. We should also find a way to mark the second-derivative function as non-differentiable, so that differentiating three or more times raises an error instead of, once again, failing silently. Something similar to @once_differentiable (https://discuss.pytorch.org/t/what-does-the-function-wrapper-once-differentiable-do/31513), but for C++ torch.
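
The idea can be sketched in Python, where `once_differentiable` already exists (the issue asks for a C++ equivalent, which this does not provide). This is a minimal sketch under stated assumptions, not sphericart's implementation: `Cube` and `_CubeBackward` are hypothetical names, and f(x) = x**3 stands in for the spherical harmonics. The first derivative is wrapped in its own autograd function, so the second derivative is only computed when a double backward actually happens, and that second-derivative step is marked as differentiable once.

```python
import torch
from torch.autograd.function import once_differentiable


class _CubeBackward(torch.autograd.Function):
    """First derivative of x**3, wrapped so it can itself be differentiated."""

    @staticmethod
    def forward(ctx, x, grad_out):
        ctx.save_for_backward(x, grad_out)
        return grad_out * 3.0 * x**2

    @staticmethod
    @once_differentiable  # a third differentiation must not pass silently
    def backward(ctx, grad_grad):
        x, grad_out = ctx.saved_tensors
        # The second derivative (6 * x) is computed here, on the fly,
        # only when a double backward is actually requested.
        grad_x = grad_grad * grad_out * 6.0 * x
        grad_grad_out = grad_grad * 3.0 * x**2
        return grad_x, grad_grad_out


class Cube(torch.autograd.Function):
    """Hypothetical stand-in for the spherical harmonics: f(x) = x**3."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x**3

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return _CubeBackward.apply(x, grad_out)


x = torch.tensor([1.0, 2.0], requires_grad=True)
y = Cube.apply(x)
(g,) = torch.autograd.grad(y.sum(), x, create_graph=True)  # 3 * x**2
(h,) = torch.autograd.grad(g.sum(), x, create_graph=True)  # 6 * x

# Differentiating a third time raises instead of returning wrong values.
try:
    torch.autograd.grad(h.sum(), x)
    third_raised = False
except RuntimeError:
    third_raised = True
```

No flag is needed at initialization here: a user who never asks for second derivatives never pays for them. The exact error message on the third differentiation depends on how the second derivative is used downstream, but in both cases a `RuntimeError` is raised.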

@ceriottm
Contributor

Maybe having a separate class for second derivatives?

@frostedoyster
Collaborator Author

This is partially fixed thanks to warnings.
Once we can reasonably expect torch 2.4 as a minimum requirement, we will be able to fix it once and for all using the approach described in
https://pytorch.org/tutorials/advanced/cpp_custom_ops.html

@Luthaf
Contributor

Luthaf commented Oct 12, 2024

I'm not sure I see how the link you shared would fix the issue. By allowing backward to be defined on the Python side?
