Raise warning for torch double backward (#141)
* Add `__call__` for pytorch (see the usage sketch after the changed-files summary below)
* Change all APIs 
* Test with desiderata
* Remove solid/normalized option from examples
* Add warning and tests
* Run CI on ARM64 macOS to fix a CI bug
 
---------

Co-authored-by: Michele Ceriotti <[email protected]>
Co-authored-by: Guillaume Fraux <[email protected]>
3 people authored this commit on Sep 2, 2024
1 parent b8c9d10 · commit 2bee94d
Showing 4 changed files with 152 additions and 8 deletions.
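
The first bullet of the commit message mentions adding `__call__` for the PyTorch classes. Since `SphericalHarmonics` is a `torch.nn.Module`, this presumably makes a calculator directly callable as an alternative to `compute`; a minimal sketch of that usage (the direct-call form and its equivalence to `compute` are inferred from the commit message, not shown in this diff):

    import torch
    import sphericart.torch

    calculator = sphericart.torch.SphericalHarmonics(l_max=8)
    xyz = torch.rand(10, 3)  # ten arbitrary Cartesian points

    sph_a = calculator.compute(xyz)  # explicit API, used in the tests below
    sph_b = calculator(xyz)          # __call__ form added by this commit (assumed equivalent)
    assert torch.allclose(sph_a, sph_b)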
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -14,7 +14,7 @@ jobs:
       matrix:
         include:
           - os: ubuntu-20.04
-          - os: macos-13
+          - os: macos-14
     steps:
       - uses: actions/checkout@v3

24 changes: 18 additions & 6 deletions sphericart-torch/python/sphericart/torch/__init__.py
@@ -102,12 +102,24 @@ class SphericalHarmonics(torch.nn.Module):
     By default, only single backpropagation with respect to ``xyz`` is
     enabled (this includes mixed second derivatives where ``xyz`` appears
     as only one of the differentiation steps). To activate support
-    for double backpropagation with respect to ``xyz``, please set
-    ``backward_second_derivatives=True`` at class creation. Warning: if
-    ``backward_second_derivatives`` is not set to ``True`` and double
-    differentiation with respect to ``xyz`` is requested, the results may
-    be incorrect and no warnings will be displayed. This is necessary to
-    provide optimal performance for both use cases.
+    for double backpropagation with respect to ``xyz``, please set
+    ``backward_second_derivatives=True`` at class creation. Warning: if
+    ``backward_second_derivatives`` is not set to ``True`` and double
+    differentiation with respect to ``xyz`` is requested, the results may
+    be incorrect, but a warning will be displayed. This is necessary to
+    provide optimal performance for both use cases. In particular, the
+    following will happen:
+
+    - when using ``torch.autograd.grad`` as the second backpropagation
+      step, a warning will be displayed and torch will raise an error.
+    - when using ``torch.autograd.grad`` with ``allow_unused=True`` as
+      the second backpropagation step, the results will be incorrect
+      and only a warning will be displayed.
+    - when using ``backward`` as the second backpropagation step, the
+      results will be incorrect and only a warning will be displayed.
+    - when using ``torch.autograd.functional.hessian``, the results will
+      be incorrect and only a warning will be displayed.
+
     This class supports TorchScript.
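
For reference, the supported path described by the new docstring (second derivatives enabled at class creation, then two differentiation passes) looks like the following sketch, assembled from the tests added below; it is not itself part of the diff:

    import torch
    import sphericart.torch

    calculator = sphericart.torch.SphericalHarmonics(
        l_max=8, backward_second_derivatives=True
    )
    xyz = torch.tensor([[0.67, 0.53, -0.22]], requires_grad=True)

    sph = calculator.compute(xyz)
    l0 = torch.sum(sph)
    # first differentiation; create_graph=True keeps the graph alive
    # so that a second differentiation is possible
    (d1,) = torch.autograd.grad(l0, xyz, create_graph=True)
    # second differentiation: valid because backward_second_derivatives=True
    (d2,) = torch.autograd.grad(torch.sum(d1), xyz)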
120 changes: 120 additions & 0 deletions sphericart-torch/python/tests/test_autograd.py
@@ -89,3 +89,123 @@ def f(xyz):

    proportionality_factor = analytical_hessian[2, 2] / hessian[2, 2]
    assert torch.allclose(analytical_hessian, hessian * proportionality_factor)


def test_second_derivative_error(xyz):
    # This is a problematic test, as it probes some potential silent failures

    # Initialize a calculator
    calculator = sphericart.torch.SphericalHarmonics(
        l_max=8, backward_second_derivatives=False
    )

    # Fill a single xyz point with arbitrary numbers
    xyz = torch.tensor(
        [
            [0.67, 0.53, -0.22],
        ],
        requires_grad=True,
    )

    # Compute the spherical harmonics and differentiate twice.
    # With backward_second_derivatives=False, the second differentiation
    # warns and, depending on how it is invoked (see the cases below),
    # may also raise an error.
    sph = calculator.compute(xyz)
    l0 = torch.sum(sph)
    d1 = torch.autograd.grad(
        outputs=l0,
        inputs=xyz,
        retain_graph=True,
        create_graph=True,
    )[0]
    l1 = torch.sum(d1)

    # case 1: autograd.grad raises an error; since this is the first time
    # second derivatives are requested and `backward_second_derivatives=False`,
    # a warning is also displayed
    with pytest.warns(
        UserWarning,
        match="Second derivatives of the spherical harmonics with respect "
        "to the Cartesian coordinates were not requested at class creation.",
    ):
        with pytest.raises(
            RuntimeError,
            match="One of the differentiated Tensors appears to not have "
            "been used in the graph. Set allow_unused=True if this is the "
            "desired behavior.",
        ):
            torch.autograd.grad(
                outputs=l1,
                inputs=xyz,
                retain_graph=True,
                create_graph=True,
            )

    # case 2: autograd.grad with allow_unused=True fails silently
    torch.autograd.grad(
        outputs=l1,
        inputs=xyz,
        retain_graph=True,
        create_graph=True,
        allow_unused=True,
    )

    # case 3: backward fails silently
    l1.backward()

    # case 4: autograd.functional.hessian fails silently
    def f(xyz):  # dummy function
        sph = calculator.compute(xyz)[0]  # discard the sample dimension
        return torch.sum(sph)

    torch.autograd.functional.hessian(f, xyz)[
        0, :, 0, :
    ]  # discard the two sample dimensions


def test_third_derivative_error(xyz):
    # Initialize a calculator
    calculator = sphericart.torch.SphericalHarmonics(
        l_max=8, backward_second_derivatives=True
    )

    # Fill a single xyz point with arbitrary numbers
    xyz = torch.tensor(
        [
            [0.67, 0.53, -0.22],
        ],
        requires_grad=True,
    )

    # Compute the spherical harmonics and differentiate three times.
    # The third differentiation must raise an error.
    sph = calculator.compute(xyz)
    l0 = torch.sum(sph)
    d1 = torch.autograd.grad(
        outputs=l0,
        inputs=xyz,
        retain_graph=True,
        create_graph=True,
    )[0]
    l1 = torch.sum(d1)
    d2 = torch.autograd.grad(
        outputs=l1,
        inputs=xyz,
        retain_graph=True,
        create_graph=True,
    )[0]
    s2 = torch.sum(d2)
    with pytest.raises(
        RuntimeError,
        match="element 0 of tensors does not require grad and does not have a grad_fn",
    ):
        torch.autograd.grad(
            outputs=s2,
            inputs=xyz,
            retain_graph=False,
            create_graph=False,
        )
    with pytest.raises(
        RuntimeError,
        match="element 0 of tensors does not require grad and does not have a grad_fn",
    ):
        s2.backward()
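
These tests exercise the failure modes; a natural complement (not part of this commit) would be a numerical check that the second derivatives are actually correct when they are enabled. A sketch using `torch.autograd.gradgradcheck`, assuming the calculator accepts float64 inputs, which the finite-difference tolerances require:

    import torch
    import sphericart.torch

    calculator = sphericart.torch.SphericalHarmonics(
        l_max=2, backward_second_derivatives=True
    )
    xyz = torch.rand(4, 3, dtype=torch.float64, requires_grad=True)

    # gradgradcheck compares the analytical double backward against
    # finite differences and returns True on success
    assert torch.autograd.gradgradcheck(calculator.compute, (xyz,))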
14 changes: 13 additions & 1 deletion sphericart-torch/src/autograd.cpp
@@ -374,7 +374,19 @@ std::vector<torch::Tensor> SphericartAutogradBackward::backward(
     auto gradgrad_wrt_xyz = torch::Tensor();

     bool double_backward = ddsph.defined(); // If the double backward was not requested in
-                                            // advance, this tensor will be uninitialized
+                                            // advance, this tensor will be uninitialized
+
+    if (!double_backward) {
+        TORCH_WARN_ONCE(
+            "Second derivatives of the spherical harmonics with respect to the Cartesian "
+            "coordinates were not requested at class creation. The second derivative of "
+            "the spherical harmonics with respect to the Cartesian coordinates will be "
+            "treated as zero, potentially causing incorrect results. Make sure you either "
+            "do not need (i.e., are not using) these second derivatives, or that you set "
+            "`backward_second_derivatives=True` when creating the SphericalHarmonics or "
+            "SolidHarmonics class."
+        );
+    }

     if (grad_out.requires_grad()) {
         // gradgrad_wrt_grad_out, unlike gradgrad_wrt_xyz, is needed for mixed
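
On the Python side, `TORCH_WARN_ONCE` surfaces as a `UserWarning` the first time the unsupported double backward runs in a process, which is what the `pytest.warns` check above relies on. A sketch of observing it directly, assuming a fresh process so the once-only warning has not yet fired:

    import warnings

    import torch
    import sphericart.torch

    calculator = sphericart.torch.SphericalHarmonics(
        l_max=8, backward_second_derivatives=False
    )
    xyz = torch.tensor([[0.67, 0.53, -0.22]], requires_grad=True)

    l0 = torch.sum(calculator.compute(xyz))
    (d1,) = torch.autograd.grad(l0, xyz, create_graph=True)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # the silently-wrong path: allow_unused=True suppresses the RuntimeError
        torch.autograd.grad(torch.sum(d1), xyz, allow_unused=True)

    assert any("Second derivatives" in str(w.message) for w in caught)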
