From 53ccd15190977976513417fd926b7182e000a8fd Mon Sep 17 00:00:00 2001
From: Boris Bonev
Date: Sun, 17 Dec 2023 12:40:29 +0100
Subject: [PATCH 1/8] Moved convolutions and exposed them directly

---
 torch_harmonics/__init__.py                            | 3 +--
 torch_harmonics/{s2_convolutions.py => convolution.py} | 0
 2 files changed, 1 insertion(+), 2 deletions(-)
 rename torch_harmonics/{s2_convolutions.py => convolution.py} (100%)

diff --git a/torch_harmonics/__init__.py b/torch_harmonics/__init__.py
index 806ce3a..3366517 100644
--- a/torch_harmonics/__init__.py
+++ b/torch_harmonics/__init__.py
@@ -32,8 +32,7 @@
 __version__ = '0.6.4'
 
 from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
+from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2
 from . import quadrature
-from . import s2_convolutions
-from . import disco_convolutions
 from . import random_fields
 from . import examples
diff --git a/torch_harmonics/s2_convolutions.py b/torch_harmonics/convolution.py
similarity index 100%
rename from torch_harmonics/s2_convolutions.py
rename to torch_harmonics/convolution.py

From 4b623fbe16d0ee8aa172dfad931f712139f3ef06 Mon Sep 17 00:00:00 2001
From: Boris Bonev
Date: Sun, 17 Dec 2023 15:47:32 +0100
Subject: [PATCH 2/8] Added transposition to the unit test

---
 tests/test_convolution.py                     | 220 ++++++++++++++++++
 ..._convolutions.py => _disco_convolution.py} |   4 +-
 torch_harmonics/convolution.py                |   2 +-
 3 files changed, 223 insertions(+), 3 deletions(-)
 create mode 100644 tests/test_convolution.py
 rename torch_harmonics/{disco_convolutions.py => _disco_convolution.py} (99%)

diff --git a/tests/test_convolution.py b/tests/test_convolution.py
new file mode 100644
index 0000000..b544523
--- /dev/null
+++ b/tests/test_convolution.py
@@ -0,0 +1,220 @@
+# coding=utf-8
+
+# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+#    contributors may be used to endorse or promote products derived from
+#    this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+import unittest
+from parameterized import parameterized
+from functools import partial
+import math
+import numpy as np
+import torch
+from torch.autograd import gradcheck
+from torch_harmonics import *
+
+
+def _compute_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, kernel_size: int, theta_cutoff: float):
+    """
+    helper routine to compute the support but densely
+    """
+
+    # compute the support
+    dtheta = (theta_cutoff - 0.0) / kernel_size
+    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
+    itheta = ikernel * dtheta
+
+    norm_factor = (
+        2
+        * math.pi
+        * (
+            1
+            - math.cos(theta_cutoff - dtheta)
+            + math.cos(theta_cutoff - dtheta)
+            + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta
+        )
+    )
+
+    vals = torch.where(
+        ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff),
+        (1 - (theta - itheta).abs() / dtheta) / norm_factor,
+        0,
+    )
+    return vals
+
+
+def _precompute_convolution_tensor_dense(
+    in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi
+):
+    """
+    Helper routine to compute the convolution Tensor in a dense fashion
+    """
+
+    assert len(in_shape) == 2
+    assert len(out_shape) == 2
+
+    if len(kernel_shape) == 1:
+        kernel_handle = partial(_compute_vals_isotropic, kernel_size=kernel_shape[0], theta_cutoff=theta_cutoff)
+    else:
+        raise ValueError("kernel_shape should be either one- or two-dimensional.")
+
+    nlat_in, nlon_in = in_shape
+    nlat_out, nlon_out = out_shape
+
+    lats_in, _ = quadrature._precompute_latitudes(nlat_in, grid=grid_in)
+    lats_in = torch.from_numpy(lats_in).float()
+    lats_out, _ = quadrature._precompute_latitudes(nlat_out, grid=grid_out)
+    lats_out = torch.from_numpy(lats_out).float()  # array for accumulating non-zero indices
+
+    # compute the phi differences. We need to make the linspace exclusive to not double the last point
+    lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
+    lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1)[:-1]
+
+    out = torch.zeros(kernel_shape[0], nlat_out, nlon_out, nlat_in, nlon_in)
+
+    for t in range(nlat_out):
+        for p in range(nlon_out):
+            alpha = -lats_out[t]
+            beta = lons_in - lons_out[p]
+            gamma = lats_in.reshape(-1, 1)
+
+            # compute latitude of the rotated position
+            z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma)
+
+            # compute cartesian coordinates of the rotated position
+            x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha)
+            y = torch.sin(beta) * torch.sin(gamma)
+
+            # normalize instead of clipping to ensure correct range
+            norm = torch.sqrt(x * x + y * y + z * z)
+            x = x / norm
+            y = y / norm
+            z = z / norm
+
+            # compute spherical coordinates
+            theta = torch.arccos(z)
+            phi = torch.arctan2(y, x)
+
+            # find the indices where the rotated position falls into the support of the kernel
+            out[:, t, p, :, :] = kernel_handle(theta, phi)
+
+    return out
+
+
+class TestDiscreteContinuousConvolution(unittest.TestCase):
+    def setUp(self):
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
+
+    @parameterized.expand(
+        [
+            # regular convolution
+            [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", False, 1e-5],
+            [8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "equiangular", False, 1e-5],
+            [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "legendre-gauss", False, 1e-5],
+            [8, 4, 2, (16, 32), (16, 32), [2], "legendre-gauss", "equiangular", False, 1e-5],
+            [8, 4, 2, (16, 32), (16, 32), [2], "legendre-gauss", "legendre-gauss", False, 1e-5],
+            # transpose convolution
+            [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", True, 1e-5],
+            [8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "equiangular", True, 1e-5],
+            [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "legendre-gauss", True, 1e-5],
+            [8, 4, 2, (16, 32), (16, 32), [2], "legendre-gauss", "equiangular", True, 1e-5],
+            [8, 4, 2, (16, 32), (16, 32), [2], "legendre-gauss", "legendre-gauss", True, 1e-5],
+        ]
+    )
+    def test_disco_convolution(
+        self,
+        batch_size,
+        in_channels,
+        out_channels,
+        in_shape,
+        out_shape,
+        kernel_shape,
+        grid_in,
+        grid_out,
+        transpose,
+        tol,
+    ):
+        Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2
+        conv = Conv(
+            in_channels,
+            out_channels,
+            in_shape,
+            out_shape,
+            kernel_shape,
+            groups=1,
+            grid_in=grid_in,
+            grid_out=grid_out,
+            bias=False,
+        )
+
+        nlat_in, nlon_in = in_shape
+        nlat_out, nlon_out = out_shape
+
+        theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1)
+
+        if transpose:
+            psi_dense = _precompute_convolution_tensor_dense(
+                out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
+            )
+        else:
+            psi_dense = _precompute_convolution_tensor_dense(
+                in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
+            )
+
+        self.assertTrue(
+            torch.allclose(conv.psi.to_dense().cpu(), psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
+        )
+
+        x = torch.randn(batch_size, in_channels, *in_shape)
+
+
+        # perform the reference computation
+        x_ref = x.clone().detach()
+        x_ref.requires_grad_(True)
+        if transpose:
+            y_ref = torch.einsum("oif,biqr->bofqr", conv.weight, x_ref)
+            y_ref = torch.einsum("fqrtp,bofqr->botp", psi_dense, y_ref * conv.quad_weights)
+        else:
+            y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref * conv.quad_weights)
+            y_ref = torch.einsum("oif,biftp->botp", conv.weight, y_ref)
+
+        # use the convolution module
+        y = conv(x)
+
+        # print
+        print((y - y_ref).abs().max())
+
+        # compare result
+        self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torch_harmonics/disco_convolutions.py b/torch_harmonics/_disco_convolution.py
similarity index 99%
rename from torch_harmonics/disco_convolutions.py
rename to torch_harmonics/_disco_convolution.py
index cb6667b..fd32135 100644
--- a/torch_harmonics/disco_convolutions.py
+++ b/torch_harmonics/_disco_convolution.py
@@ -377,7 +377,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in
     assert psi.shape[-1] == nlat_in * nlon_in
     assert nlon_in % nlon_out == 0
-
+    assert nlon_in >= nlat_out
     pscale = nlon_in // nlon_out
 
     # add a dummy dimension for nkernel and move the batch and channel dims to the end
@@ -414,7 +414,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
     assert psi.shape[-2] == nlat_in
     assert n_out % nlon_out == 0
     nlat_out = n_out // nlon_out
-
+    assert nlon_out >= nlat_in
     pscale = nlon_out // nlon_in
 
     # we do a semi-transposition to faciliate the computation
diff --git a/torch_harmonics/convolution.py b/torch_harmonics/convolution.py
index 031da33..e87e670 100644
--- a/torch_harmonics/convolution.py
+++ b/torch_harmonics/convolution.py
@@ -39,7 +39,7 @@
 from functools import partial
 
 from torch_harmonics.quadrature import _precompute_latitudes
-from torch_harmonics.disco_convolutions import (
+from torch_harmonics._disco_convolution import (
     _disco_s2_contraction_torch,
     _disco_s2_transpose_contraction_torch,
     _disco_s2_contraction_triton,

From dadad7fa4631236601465f3b9ca9c2871b5b04d3 Mon Sep 17 00:00:00 2001
From: Boris Bonev
Date: Mon, 18 Dec 2023 15:27:33 +0100
Subject: [PATCH 3/8] Minor bugfix in CPU version of DISCO transpose code

---
 tests/test_convolution.py             | 21 +++++++++++++--------
 torch_harmonics/_disco_convolution.py |  2 +-
 torch_harmonics/convolution.py        |  2 +-
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/tests/test_convolution.py b/tests/test_convolution.py
index b544523..63ed7f0 100644
--- a/tests/test_convolution.py
+++ b/tests/test_convolution.py
@@ -138,15 +138,17 @@ def setUp(self):
             # regular convolution
             [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", False, 1e-5],
             [8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "equiangular", False, 1e-5],
-            [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "legendre-gauss", False, 1e-5],
-            [8, 4, 2, (16, 32), (16, 32), [2], "legendre-gauss", "equiangular", False, 1e-5],
-            [8, 4, 2, (16, 32), (16, 32), [2], "legendre-gauss", "legendre-gauss", False, 1e-5],
+            [8, 4, 2, (18, 36), (6, 12), [4], "equiangular", "equiangular", False, 1e-5],
+            [8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "legendre-gauss", False, 1e-5],
+            [8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "equiangular", False, 1e-5],
+            [8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "legendre-gauss", False, 1e-5],
             # transpose convolution
             [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", True, 1e-5],
             [8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "equiangular", True, 1e-5],
-            [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "legendre-gauss", True, 1e-5],
-            [8, 4, 2, (16, 32), (16, 32), [2], "legendre-gauss", "equiangular", True, 1e-5],
-            [8, 4, 2, (16, 32), (16, 32), [2], "legendre-gauss", "legendre-gauss", True, 1e-5],
+            [8, 4, 2, (6, 12), (18, 36), [4], "equiangular", "equiangular", True, 1e-5],
+            [8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "legendre-gauss", True, 1e-5],
+            [8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "equiangular", True, 1e-5],
+            [8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "legendre-gauss", True, 1e-5],
         ]
     )
     def test_disco_convolution(
@@ -209,12 +211,15 @@ def test_disco_convolution(
         # use the convolution module
         y = conv(x)
 
-        # print
-        print((y - y_ref).abs().max())
+        # # print
+        # print((y - y_ref).abs().max())
 
         # compare result
         self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol))
 
+        # compute gradients and compare results
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torch_harmonics/_disco_convolution.py b/torch_harmonics/_disco_convolution.py
index fd32135..2f331be 100644
--- a/torch_harmonics/_disco_convolution.py
+++ b/torch_harmonics/_disco_convolution.py
@@ -429,7 +429,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl
 
     # interleave zeros along the longitude dimension to allow for fractional offsets to be considered
     x_ext = torch.zeros(kernel_size, nlat_in, nlon_out, batch_size * n_chans, device=x.device, dtype=x.dtype)
-    x_ext[:, :, (pscale-1)::pscale, :] = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0)
+    x_ext[:, :, ::pscale, :] = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0)
 
     # we need to go backwards through the vector, so we flip the axis
     x_ext = x_ext.contiguous()
diff --git a/torch_harmonics/convolution.py b/torch_harmonics/convolution.py
index e87e670..1855ecc 100644
--- a/torch_harmonics/convolution.py
+++ b/torch_harmonics/convolution.py
@@ -310,7 +310,7 @@ def __init__(
 
     def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor:
         # extract shape
-        B, F, H, W = x.shape
+        B, C, H, W = x.shape
         x = x.reshape(B, self.groups, self.groupsize, H, W)
 
         # do weight multiplication

From 48e26b09182c67803a007e08972ad396696f50bc Mon Sep 17 00:00:00 2001
From: Boris Bonev
Date: Mon, 18 Dec 2023 15:39:14 +0100
Subject: [PATCH 4/8] Adding convolution tests to CI

---
 .github/workflows/tests.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index ed8f343..095723c 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -23,4 +23,4 @@ jobs:
     - name: Test with pytest
       run: |
         python -m pip install pytest pytest-cov parameterized
-        python -m pytest --cov-report term --cov-config=.coveragerc --cov=torch_harmonics ./tests/test_sht.py
\ No newline at end of file
+        python -m pytest --cov-report term --cov-config=.coveragerc --cov=torch_harmonics ./tests/test_sht.py ./tests/test_convolution.py
\ No newline at end of file

From fb7585556affb0dd7a1516c2479bf15b3ac525ce Mon Sep 17 00:00:00 2001
From: Boris Bonev
Date: Mon, 18 Dec 2023 16:21:12 +0100
Subject: [PATCH 5/8] Added gradient check

---
 tests/test_convolution.py | 27 +++++++++++++++------------
 tests/test_sht.py         |  4 ++--
 2 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/tests/test_convolution.py b/tests/test_convolution.py
index 63ed7f0..78dd896 100644
--- a/tests/test_convolution.py
+++ b/tests/test_convolution.py
@@ -133,6 +133,8 @@ def setUp(self):
         else:
             self.device = torch.device("cpu")
 
+        self.device = torch.device("cpu")
+
     @parameterized.expand(
         [
             # regular convolution
@@ -175,7 +177,7 @@ def test_disco_convolution(
             grid_in=grid_in,
             grid_out=grid_out,
             bias=False,
-        )
+        ).to(self.device)
 
         nlat_in, nlon_in = in_shape
         nlat_out, nlon_out = out_shape
@@ -185,18 +187,18 @@ def test_disco_convolution(
         if transpose:
             psi_dense = _precompute_convolution_tensor_dense(
                 out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
-            )
+            ).to(self.device)
         else:
             psi_dense = _precompute_convolution_tensor_dense(
                 in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
-            )
+            ).to(self.device)
 
         self.assertTrue(
-            torch.allclose(conv.psi.to_dense().cpu(), psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
+            torch.allclose(conv.psi.to_dense(), psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
         )
 
-        x = torch.randn(batch_size, in_channels, *in_shape)
-
+        # create an input signal
+        x = torch.randn(batch_size, in_channels, *in_shape, requires_grad=True).to(self.device)
 
         # perform the reference computation
         x_ref = x.clone().detach()
@@ -211,15 +213,16 @@ def test_disco_convolution(
         # use the convolution module
         y = conv(x)
 
-        # # print
-        # print((y - y_ref).abs().max())
-
-        # compare result
+        # compare results
         self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol))
 
         # compute gradients and compare results
-
+        grad_input = torch.randn_like(y)
+        y_ref.backward(grad_input)
+        y.backward(grad_input)
+
+        # compare
+        self.assertTrue(torch.allclose(x.grad, x_ref.grad, rtol=tol, atol=tol))
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_sht.py b/tests/test_sht.py
index 9d50778..62e13ae 100644
--- a/tests/test_sht.py
+++ b/tests/test_sht.py
@@ -149,9 +149,9 @@ def test_sht_grad(self, nlat, nlon, batch_size, norm, grid, tol):
         coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128)
         signal = isht(coeffs)
 
-        input = torch.randn_like(signal, requires_grad=True)
+        grad_input = torch.randn_like(signal, requires_grad=True)
         err_handle = lambda x : torch.mean(torch.norm( isht(sht(x)) - signal , p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) )
-        test_result = gradcheck(err_handle, input, eps=1e-6, atol=tol)
+        test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=tol)
 
         self.assertTrue(test_result)

From e36421928f6cacab7f61a9c85591e5e593f3e8dc Mon Sep 17 00:00:00 2001
From: Boris Bonev
Date: Mon, 18 Dec 2023 16:23:06 +0100
Subject: [PATCH 6/8] Checking the weight grad as well

---
 tests/test_convolution.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/tests/test_convolution.py b/tests/test_convolution.py
index 78dd896..ff672cb 100644
--- a/tests/test_convolution.py
+++ b/tests/test_convolution.py
@@ -197,6 +197,10 @@ def test_disco_convolution(
             torch.allclose(conv.psi.to_dense(), psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
         )
 
+        # create a copy of the weight
+        w_ref = conv.weight.detach().clone()
+        w_ref.requires_grad_(True)
+
         # create an input signal
         x = torch.randn(batch_size, in_channels, *in_shape, requires_grad=True).to(self.device)
 
@@ -204,11 +208,11 @@ def test_disco_convolution(
         x_ref = x.clone().detach()
         x_ref.requires_grad_(True)
         if transpose:
-            y_ref = torch.einsum("oif,biqr->bofqr", conv.weight, x_ref)
+            y_ref = torch.einsum("oif,biqr->bofqr", w_ref, x_ref)
             y_ref = torch.einsum("fqrtp,bofqr->botp", psi_dense, y_ref * conv.quad_weights)
         else:
             y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref * conv.quad_weights)
-            y_ref = torch.einsum("oif,biftp->botp", conv.weight, y_ref)
+            y_ref = torch.einsum("oif,biftp->botp", w_ref, y_ref)
 
         # use the convolution module
         y = conv(x)
@@ -220,9 +224,10 @@ def test_disco_convolution(
         grad_input = torch.randn_like(y)
         y_ref.backward(grad_input)
         y.backward(grad_input)
-
+
         # compare
         self.assertTrue(torch.allclose(x.grad, x_ref.grad, rtol=tol, atol=tol))
+        self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol))
 
 if __name__ == "__main__":
     unittest.main()

From 51451b0105db3c58c731068802d046870ee26a72 Mon Sep 17 00:00:00 2001
From: Boris Bonev
Date: Tue, 19 Dec 2023 17:30:27 +0100
Subject: [PATCH 7/8] Added test for anisotropic kernels

---
 Changelog.md                   | 15 ++++++--
 tests/test_convolution.py      | 43 +++++++++++++++++++----
 torch_harmonics/convolution.py | 63 +++++++++++++++++++++++++---------
 3 files changed, 95 insertions(+), 26 deletions(-)

diff --git a/Changelog.md b/Changelog.md
index 79459ed..d5c0e75 100644
--- a/Changelog.md
+++ b/Changelog.md
@@ -2,10 +2,19 @@
 
 ## Versioning
 
+### v0.6.5
+
+* Discrete-continuous (DISCO) convolutions on the sphere
+* Isotropic and anisotropic DISCO convolutions
+* Accelerated DISCO convolutions on GPU via Triton implementation
+* Unittests for DISCO convolutions
+
 ### v0.6.4
 
-* reworking distributed to allow for uneven split tensors, effectively removing the necessity of padding the transformed tensors
-* distributed SHT tests are now using unittest. Test extended to vector SHT versions. Tests are defined in `torch_harmonics/distributed/distributed_tests.py`
-* base pytorch container version bumped up to 23.11 in Dockerfile
+
+* Reworking distributed to allow for uneven split tensors, effectively removing the necessity of padding the transformed tensors
+* Distributed SHT tests are now using unittest. Test extended to vector SHT versions
+* Tests are defined in `torch_harmonics/distributed/distributed_tests.py`
+* Base pytorch container version bumped up to 23.11 in Dockerfile
 
 ### v0.6.3

diff --git a/tests/test_convolution.py b/tests/test_convolution.py
index ff672cb..37781cf 100644
--- a/tests/test_convolution.py
+++ b/tests/test_convolution.py
@@ -39,14 +39,14 @@
 from torch_harmonics import *
 
 
-def _compute_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, kernel_size: int, theta_cutoff: float):
+def _compute_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float):
     """
-    helper routine to compute the support but densely
+    helper routine to compute the values of the isotropic kernel densely
     """
 
     # compute the support
-    dtheta = (theta_cutoff - 0.0) / kernel_size
-    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
+    dtheta = (theta_cutoff - 0.0) / ntheta
+    ikernel = torch.arange(ntheta).reshape(-1, 1, 1)
     itheta = ikernel * dtheta
 
     norm_factor = (
@@ -67,6 +67,28 @@ def _compute_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, kernel_size:
     )
     return vals
 
+def _compute_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float):
+    """
+    helper routine to compute the values of the anisotropic kernel densely
+    """
+
+    # compute the support
+    dtheta = (theta_cutoff - 0.0) / ntheta
+    dphi = 2.0 * math.pi / nphi
+    kernel_size = (ntheta-1)*nphi + 1
+    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
+    itheta = ((ikernel - 1) // nphi + 1) * dtheta
+    iphi = ((ikernel - 1) % nphi) * dphi
+
+    norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)
+
+    # find the indices where the rotated position falls into the support of the kernel
+    cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff)
+    cond_phi = ((phi - iphi).abs() <= dphi) | ((2*math.pi - (phi - iphi).abs()) <= dphi)
+    theta_vals = torch.where(cond_theta, (1 - (theta - itheta).abs() / dtheta) / norm_factor, 0.0)
+    phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2*math.pi - (phi - iphi).abs()) ) / dphi ), 0.0)
+    vals = torch.where(ikernel > 0, theta_vals * phi_vals, theta_vals)
+    return vals
 
 def _precompute_convolution_tensor_dense(
     in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi
@@ -79,7 +101,11 @@ def _precompute_convolution_tensor_dense(
     assert len(out_shape) == 2
 
     if len(kernel_shape) == 1:
-        kernel_handle = partial(_compute_vals_isotropic, kernel_size=kernel_shape[0], theta_cutoff=theta_cutoff)
+        kernel_handle = partial(_compute_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff)
+        kernel_size = kernel_shape[0]
+    elif len(kernel_shape) == 2:
+        kernel_handle = partial(_compute_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff)
+        kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
     else:
         raise ValueError("kernel_shape should be either one- or two-dimensional.")
@@ -95,7 +121,7 @@ def _precompute_convolution_tensor_dense(
     lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1]
     lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1)[:-1]
 
-    out = torch.zeros(kernel_shape[0], nlat_out, nlon_out, nlat_in, nlon_in)
+    out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in)
 
     for t in range(nlat_out):
         for p in range(nlon_out):
@@ -118,7 +144,7 @@ def _precompute_convolution_tensor_dense(
 
             # compute spherical coordinates
             theta = torch.arccos(z)
-            phi = torch.arctan2(y, x)
+            phi = torch.arctan2(y, x) + torch.pi
 
             # find the indices where the rotated position falls into the support of the kernel
             out[:, t, p, :, :] = kernel_handle(theta, phi)
@@ -140,6 +166,7 @@ def setUp(self):
             # regular convolution
             [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", False, 1e-5],
             [8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "equiangular", False, 1e-5],
+            [8, 4, 2, (16, 32), (8, 16), [2, 3], "equiangular", "equiangular", False, 1e-5],
             [8, 4, 2, (18, 36), (6, 12), [4], "equiangular", "equiangular", False, 1e-5],
             [8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "legendre-gauss", False, 1e-5],
             [8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "equiangular", False, 1e-5],
@@ -147,6 +174,7 @@ def setUp(self):
             # transpose convolution
             [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", True, 1e-5],
             [8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "equiangular", True, 1e-5],
+            [8, 4, 2, (8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 1e-5],
             [8, 4, 2, (6, 12), (18, 36), [4], "equiangular", "equiangular", True, 1e-5],
             [8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "legendre-gauss", True, 1e-5],
             [8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "equiangular", True, 1e-5],
@@ -202,6 +230,7 @@ def test_disco_convolution(
         w_ref.requires_grad_(True)
 
         # create an input signal
+        torch.manual_seed(333)
         x = torch.randn(batch_size, in_channels, *in_shape, requires_grad=True).to(self.device)
 
         # perform the reference computation
diff --git a/torch_harmonics/convolution.py b/torch_harmonics/convolution.py
index 1855ecc..c3fd8bf 100644
--- a/torch_harmonics/convolution.py
+++ b/torch_harmonics/convolution.py
@@ -47,14 +47,14 @@
 )
 
 
-def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, kernel_size: int, theta_cutoff: float):
+def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float):
     """
     Computes the index set that falls into the isotropic kernel's support and returns both indices and values.
     """
 
     # compute the support
-    dtheta = (theta_cutoff - 0.0) / kernel_size
-    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
+    dtheta = (theta_cutoff - 0.0) / ntheta
+    ikernel = torch.arange(ntheta).reshape(-1, 1, 1)
     itheta = ikernel * dtheta
 
     norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)
@@ -64,6 +64,29 @@ def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, kern
     vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor
     return iidx, vals
 
+def _compute_support_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float):
+    """
+    Computes the index set that falls into the anisotropic kernel's support and returns both indices and values.
+    """
+
+    # compute the support
+    dtheta = (theta_cutoff - 0.0) / ntheta
+    dphi = 2.0 * math.pi / nphi
+    kernel_size = (ntheta-1)*nphi + 1
+    ikernel = torch.arange(kernel_size).reshape(-1, 1, 1)
+    itheta = ((ikernel - 1) // nphi + 1) * dtheta
+    iphi = ((ikernel - 1) % nphi) * dphi
+
+    norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta)
+
+    # find the indices where the rotated position falls into the support of the kernel
+    cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff)
+    cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2*math.pi - (phi - iphi).abs()) <= dphi)
+    iidx = torch.argwhere(cond_theta & cond_phi)
+    vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor
+    vals *= torch.where(iidx[:, 0] > 0, (1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2*math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()) ) / dphi ), 1.0)
+    return iidx, vals
+
 
 def _precompute_convolution_tensor(
     in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi
@@ -88,7 +111,9 @@ def _precompute_convolution_tensor(
     assert len(out_shape) == 2
 
     if len(kernel_shape) == 1:
-        kernel_handle = partial(_compute_support_vals_isotropic, kernel_size=kernel_shape[0], theta_cutoff=theta_cutoff)
+        kernel_handle = partial(_compute_support_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff)
+    elif len(kernel_shape) == 2:
+        kernel_handle = partial(_compute_support_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff)
     else:
         raise ValueError("kernel_shape should be either one- or two-dimensional.")
@@ -128,9 +153,9 @@ def _precompute_convolution_tensor(
             y = y / norm
             z = z / norm
 
-            # compute spherical coordinates
+            # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range
             theta = torch.arccos(z)
-            phi = torch.arctan2(y, x)
+            phi = torch.arctan2(y, x) + torch.pi
 
             # find the indices where the rotated position falls into the support of the kernel
             iidx, vals = kernel_handle(theta, phi)
@@ -146,8 +171,7 @@ def _precompute_convolution_tensor(
 
 # TODO:
-# - parameter initialization
-# - add anisotropy
+# - derive conv and conv transpose from single module
 
 class DiscreteContinuousConvS2(nn.Module):
     """
     Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
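Note on the anisotropic kernel introduced above: the (ntheta, nphi) collocation grid is flattened into a single kernel index, with index 0 reserved for the collocation point at theta = 0, which carries no phi dependence. A minimal standalone sketch of that index mapping, derived only from the formulas in _compute_support_vals_anisotropic above (kernel_index_to_angles is an illustrative helper, not part of the library):

    import math

    def kernel_index_to_angles(ik: int, ntheta: int, nphi: int, theta_cutoff: float):
        # collocation spacings, as in _compute_support_vals_anisotropic
        dtheta = theta_cutoff / ntheta
        dphi = 2.0 * math.pi / nphi
        if ik == 0:
            # index 0 is the isotropic center at the pole of the kernel
            return 0.0, 0.0
        itheta = ((ik - 1) // nphi + 1) * dtheta
        iphi = ((ik - 1) % nphi) * dphi
        return itheta, iphi

    # kernel_shape = [2, 3] yields kernel_size = (2 - 1) * 3 + 1 = 4 basis functions
    points = {kernel_index_to_angles(ik, 2, 3, math.pi / 8) for ik in range(4)}
    assert len(points) == 4

This flattened count is also why the weight tensors below are allocated with self.kernel_size rather than kernel_shape[0].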
@@ -175,9 +199,13 @@ def __init__(
         if isinstance(kernel_shape, int):
             kernel_shape = [kernel_shape]
 
-        self.kernel_size = 1
-        for kdim in kernel_shape:
-            self.kernel_size *= kdim
+        if len(kernel_shape) == 1:
+            self.kernel_size = kernel_shape[0]
+        elif len(kernel_shape) == 2:
+            self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
+        else:
+            raise ValueError("kernel_shape should be either one- or two-dimensional.")
+
 
         # compute theta cutoff based on the bandlimit of the input field
         if theta_cutoff is None:
@@ -209,7 +237,7 @@ def __init__(
             raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
         self.groupsize = in_channels // self.groups
         scale = math.sqrt(1.0 / self.groupsize)
-        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, kernel_shape[0]))
+        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))
 
         if bias:
             self.bias = nn.Parameter(torch.zeros(out_channels))
@@ -266,9 +294,12 @@ def __init__(
         if isinstance(kernel_shape, int):
             kernel_shape = [kernel_shape]
 
-        self.kernel_size = 1
-        for kdim in kernel_shape:
-            self.kernel_size *= kdim
+        if len(kernel_shape) == 1:
+            self.kernel_size = kernel_shape[0]
+        elif len(kernel_shape) == 2:
+            self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
+        else:
+            raise ValueError("kernel_shape should be either one- or two-dimensional.")
 
         # bandlimit
         if theta_cutoff is None:
@@ -301,7 +332,7 @@ def __init__(
             raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
         self.groupsize = in_channels // self.groups
         scale = math.sqrt(1.0 / self.groupsize)
-        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, kernel_shape[0]))
+        self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))
 
         if bias:
             self.bias = nn.Parameter(torch.zeros(out_channels))

From 92b6f38f77a7296d236e9299befe3eeeaaff1a61 Mon Sep 17 00:00:00 2001
From: Boris Bonev
Date: Fri, 22 Dec 2023 16:23:00 +0100
Subject: [PATCH 8/8] Changed the code to only implicitly use sparse tensors in
 the modules, in order to enable compatibility with DDP

---
 tests/test_convolution.py      |  4 +++-
 torch_harmonics/convolution.py | 32 ++++++++++++++++++--------------
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/tests/test_convolution.py b/tests/test_convolution.py
index 37781cf..114bb8c 100644
--- a/tests/test_convolution.py
+++ b/tests/test_convolution.py
@@ -221,8 +221,10 @@ def test_disco_convolution(
             in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
         ).to(self.device)
 
+        psi = torch.sparse_coo_tensor(conv.psi_idx, conv.psi_vals, size=(conv.kernel_size, conv.nlat_out, conv.nlat_in * conv.nlon_in)).to_dense()
+
         self.assertTrue(
-            torch.allclose(conv.psi.to_dense(), psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
+            torch.allclose(psi, psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in))
         )
 
         # create a copy of the weight
diff --git a/torch_harmonics/convolution.py b/torch_harmonics/convolution.py
index c3fd8bf..363389e 100644
--- a/torch_harmonics/convolution.py
+++ b/torch_harmonics/convolution.py
@@ -222,10 +222,12 @@ def __init__(
         idx, vals = _precompute_convolution_tensor(
             in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
         )
-        psi = torch.sparse_coo_tensor(
-            idx, vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
-        ).coalesce()
-        self.register_buffer("psi", psi, persistent=False)
+        # psi = torch.sparse_coo_tensor(
+        #     idx, vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)
+        # ).coalesce()
+        self.register_buffer("psi_idx", idx, persistent=False)
+        self.register_buffer("psi_vals", vals, persistent=False)
+        # self.register_buffer("psi", psi, persistent=False)
 
         # groups
         self.groups = groups
@@ -248,10 +250,12 @@ def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tens
         # pre-multiply x with the quadrature weights
         x = self.quad_weights * x
 
+        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
+
         if x.is_cuda and use_triton_kernel:
-            x = _disco_s2_contraction_triton(x, self.psi, self.nlon_out)
+            x = _disco_s2_contraction_triton(x, psi, self.nlon_out)
         else:
-            x = _disco_s2_contraction_torch(x, self.psi, self.nlon_out)
+            x = _disco_s2_contraction_torch(x, psi, self.nlon_out)
 
         # extract shape
         B, C, K, H, W = x.shape
@@ -317,10 +321,12 @@ def __init__(
         idx, vals = _precompute_convolution_tensor(
             out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
         )
-        psi = torch.sparse_coo_tensor(
-            idx, vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
-        ).coalesce()
-        self.register_buffer("psi", psi, persistent=False)
+        # psi = torch.sparse_coo_tensor(
+        #     idx, vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)
+        # ).coalesce()
+        self.register_buffer("psi_idx", idx, persistent=False)
+        self.register_buffer("psi_vals", vals, persistent=False)
+        # self.register_buffer("psi", psi, persistent=False)
 
         # groups
         self.groups = groups
@@ -351,10 +357,12 @@ def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tens
         # pre-multiply x with the quadrature weights
         x = self.quad_weights * x
 
+        psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
+
         if x.is_cuda and use_triton_kernel:
-            out = _disco_s2_transpose_contraction_triton(x, self.psi, self.nlon_out)
+            out = _disco_s2_transpose_contraction_triton(x, psi, self.nlon_out)
         else:
-            out = _disco_s2_transpose_contraction_torch(x, self.psi, self.nlon_out)
+            out = _disco_s2_transpose_contraction_torch(x, psi, self.nlon_out)
 
         if self.bias is not None:
             out = out + self.bias.reshape(1, -1, 1, 1)
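Note on the final patch: its stated motivation is DDP compatibility. DistributedDataParallel synchronizes registered buffers across ranks, and that synchronization presumably fails for a registered sparse COO buffer; keeping only the dense index and value tensors as buffers and materializing the sparse psi transiently inside forward sidesteps the problem. A minimal sketch of the pattern under that assumption (SparseContraction is an illustrative module, not part of the library):

    import torch
    import torch.nn as nn

    class SparseContraction(nn.Module):
        def __init__(self, idx: torch.Tensor, vals: torch.Tensor, size):
            super().__init__()
            self.size = tuple(size)
            # dense buffers move with .to()/.cuda() and can be broadcast by DDP
            self.register_buffer("psi_idx", idx, persistent=False)
            self.register_buffer("psi_vals", vals, persistent=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # materialize the sparse tensor on the fly, as in the patched forward()
            psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=self.size).coalesce()
            # stand-in for the DISCO contraction
            return torch.sparse.mm(psi, x)

    # usage sketch: a 2x2 sparse operator applied to a dense matrix
    idx = torch.tensor([[0, 1], [1, 0]])
    vals = torch.tensor([1.0, 2.0])
    y = SparseContraction(idx, vals, (2, 2))(torch.randn(2, 3))

With this layout the sparse tensor only ever exists inside forward, so wrapping the module in DistributedDataParallel touches nothing but strided tensors.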