From 942aa4ead5c02ddcfb154dbd82742da3ea00c6d0 Mon Sep 17 00:00:00 2001 From: Boris Bonev Date: Wed, 20 Dec 2023 11:02:04 +0100 Subject: [PATCH] Bbonev/disco refactor (#27) * Moved convolutions and exposed them directly * Added transposition to the unit test * Minor bugfix in CPU version of DISCO transpose code * Adding convolution tests to CI * Added gradient check * Checking the weight grad as well * Added test for anisotropic kernels --- .github/workflows/tests.yml | 2 +- Changelog.md | 15 +- tests/test_convolution.py | 262 ++++++++++++++++++ tests/test_sht.py | 4 +- torch_harmonics/__init__.py | 3 +- ..._convolutions.py => _disco_convolution.py} | 6 +- .../{s2_convolutions.py => convolution.py} | 67 +++-- 7 files changed, 330 insertions(+), 29 deletions(-) create mode 100644 tests/test_convolution.py rename torch_harmonics/{disco_convolutions.py => _disco_convolution.py} (98%) rename torch_harmonics/{s2_convolutions.py => convolution.py} (83%) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ed8f343..095723c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,4 +23,4 @@ jobs: - name: Test with pytest run: | python -m pip install pytest pytest-cov parameterized - python -m pytest --cov-report term --cov-config=.coveragerc --cov=torch_harmonics ./tests/test_sht.py \ No newline at end of file + python -m pytest --cov-report term --cov-config=.coveragerc --cov=torch_harmonics ./tests/test_sht.py ./tests/test_convolution.py \ No newline at end of file diff --git a/Changelog.md b/Changelog.md index 79459ed..d5c0e75 100644 --- a/Changelog.md +++ b/Changelog.md @@ -2,10 +2,19 @@ ## Versioning +### v0.6.5 + +* Discrrete-continuous (DISCO) convolutions on the sphere +* Isotropic and anisotropic DISCO convolutions +* Accelerated DISCO convolutions on GPU via Triton implementation +* Unittests for DISCO convolutions + ### v0.6.4 -* reworking distributed to allow for uneven split tensors, effectively removing the necessity of padding the transformed tensors -* distributed SHT tests are now using unittest. Test extended to vector SHT versions. Tests are defined in `torch_harmonics/distributed/distributed_tests.py` -* base pytorch container version bumped up to 23.11 in Dockerfile + +* Reworking distributed to allow for uneven split tensors, effectively removing the necessity of padding the transformed tensors +* Distributed SHT tests are now using unittest. Test extended to vector SHT versions +* Tests are defined in `torch_harmonics/distributed/distributed_tests.py` +* Base pytorch container version bumped up to 23.11 in Dockerfile ### v0.6.3 diff --git a/tests/test_convolution.py b/tests/test_convolution.py new file mode 100644 index 0000000..37781cf --- /dev/null +++ b/tests/test_convolution.py @@ -0,0 +1,262 @@ +# coding=utf-8 + +# SPDX-FileCopyrightText: Copyright (c) 2022 The torch-harmonics Authors. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +import unittest +from parameterized import parameterized +from functools import partial +import math +import numpy as np +import torch +from torch.autograd import gradcheck +from torch_harmonics import * + + +def _compute_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float): + """ + helper routine to compute the values of the isotropic kernel densely + """ + + # compute the support + dtheta = (theta_cutoff - 0.0) / ntheta + ikernel = torch.arange(ntheta).reshape(-1, 1, 1) + itheta = ikernel * dtheta + + norm_factor = ( + 2 + * math.pi + * ( + 1 + - math.cos(theta_cutoff - dtheta) + + math.cos(theta_cutoff - dtheta) + + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta + ) + ) + + vals = torch.where( + ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff), + (1 - (theta - itheta).abs() / dtheta) / norm_factor, + 0, + ) + return vals + +def _compute_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float): + """ + helper routine to compute the values of the anisotropic kernel densely + """ + + # compute the support + dtheta = (theta_cutoff - 0.0) / ntheta + dphi = 2.0 * math.pi / nphi + kernel_size = (ntheta-1)*nphi + 1 + ikernel = torch.arange(kernel_size).reshape(-1, 1, 1) + itheta = ((ikernel - 1) // nphi + 1) * dtheta + iphi = ((ikernel - 1) % nphi) * dphi + + norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta) + + # find the indices where the rotated position falls into the support of the kernel + cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff) + cond_phi = ((phi - iphi).abs() <= dphi) | ((2*math.pi - (phi - iphi).abs()) <= dphi) + theta_vals = torch.where(cond_theta, (1 - (theta - itheta).abs() / dtheta) / norm_factor, 0.0) + phi_vals = torch.where(cond_phi, (1 - torch.minimum((phi - iphi).abs(), (2*math.pi - (phi - iphi).abs()) ) / dphi ), 0.0) + vals = torch.where(ikernel > 0, theta_vals * phi_vals, theta_vals) + return vals + +def _precompute_convolution_tensor_dense( + in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi +): + """ + Helper routine to compute the convolution Tensor in a dense fashion + """ + + assert len(in_shape) == 2 + assert len(out_shape) == 2 + + if len(kernel_shape) == 1: + kernel_handle = partial(_compute_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff) + kernel_size = kernel_shape[0] + elif len(kernel_shape) == 2: + kernel_handle = partial(_compute_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff) + kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1 + else: + raise ValueError("kernel_shape should be either one- or two-dimensional.") + + nlat_in, nlon_in = in_shape + nlat_out, nlon_out = out_shape + + lats_in, _ = quadrature._precompute_latitudes(nlat_in, grid=grid_in) + lats_in = torch.from_numpy(lats_in).float() + lats_out, _ = quadrature._precompute_latitudes(nlat_out, grid=grid_out) + lats_out = torch.from_numpy(lats_out).float() # array for accumulating non-zero indices + + # compute the phi differences. We need to make the linspace exclusive to not double the last point + lons_in = torch.linspace(0, 2 * math.pi, nlon_in + 1)[:-1] + lons_out = torch.linspace(0, 2 * math.pi, nlon_out + 1)[:-1] + + out = torch.zeros(kernel_size, nlat_out, nlon_out, nlat_in, nlon_in) + + for t in range(nlat_out): + for p in range(nlon_out): + alpha = -lats_out[t] + beta = lons_in - lons_out[p] + gamma = lats_in.reshape(-1, 1) + + # compute latitude of the rotated position + z = -torch.cos(beta) * torch.sin(alpha) * torch.sin(gamma) + torch.cos(alpha) * torch.cos(gamma) + + # compute cartesian coordinates of the rotated position + x = torch.cos(alpha) * torch.cos(beta) * torch.sin(gamma) + torch.cos(gamma) * torch.sin(alpha) + y = torch.sin(beta) * torch.sin(gamma) + + # normalize instead of clipping to ensure correct range + norm = torch.sqrt(x * x + y * y + z * z) + x = x / norm + y = y / norm + z = z / norm + + # compute spherical coordinates + theta = torch.arccos(z) + phi = torch.arctan2(y, x) + torch.pi + + # find the indices where the rotated position falls into the support of the kernel + out[:, t, p, :, :] = kernel_handle(theta, phi) + + return out + + +class TestDiscreteContinuousConvolution(unittest.TestCase): + def setUp(self): + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + + self.device = torch.device("cpu") + + @parameterized.expand( + [ + # regular convolution + [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", False, 1e-5], + [8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "equiangular", False, 1e-5], + [8, 4, 2, (16, 32), (8, 16), [2, 3], "equiangular", "equiangular", False, 1e-5], + [8, 4, 2, (18, 36), (6, 12), [4], "equiangular", "equiangular", False, 1e-5], + [8, 4, 2, (16, 32), (8, 16), [3], "equiangular", "legendre-gauss", False, 1e-5], + [8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "equiangular", False, 1e-5], + [8, 4, 2, (16, 32), (8, 16), [3], "legendre-gauss", "legendre-gauss", False, 1e-5], + # transpose convolution + [8, 4, 2, (16, 32), (16, 32), [2], "equiangular", "equiangular", True, 1e-5], + [8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "equiangular", True, 1e-5], + [8, 4, 2, (8, 16), (16, 32), [2, 3], "equiangular", "equiangular", True, 1e-5], + [8, 4, 2, (6, 12), (18, 36), [4], "equiangular", "equiangular", True, 1e-5], + [8, 4, 2, (8, 16), (16, 32), [3], "equiangular", "legendre-gauss", True, 1e-5], + [8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "equiangular", True, 1e-5], + [8, 4, 2, (8, 16), (16, 32), [3], "legendre-gauss", "legendre-gauss", True, 1e-5], + ] + ) + def test_disco_convolution( + self, + batch_size, + in_channels, + out_channels, + in_shape, + out_shape, + kernel_shape, + grid_in, + grid_out, + transpose, + tol, + ): + Conv = DiscreteContinuousConvTransposeS2 if transpose else DiscreteContinuousConvS2 + conv = Conv( + in_channels, + out_channels, + in_shape, + out_shape, + kernel_shape, + groups=1, + grid_in=grid_in, + grid_out=grid_out, + bias=False, + ).to(self.device) + + nlat_in, nlon_in = in_shape + nlat_out, nlon_out = out_shape + + theta_cutoff = (kernel_shape[0] + 1) * torch.pi / float(nlat_in - 1) + + if transpose: + psi_dense = _precompute_convolution_tensor_dense( + out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff + ).to(self.device) + else: + psi_dense = _precompute_convolution_tensor_dense( + in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff + ).to(self.device) + + self.assertTrue( + torch.allclose(conv.psi.to_dense(), psi_dense[:, :, 0].reshape(-1, nlat_out, nlat_in * nlon_in)) + ) + + # create a copy of the weight + w_ref = conv.weight.detach().clone() + w_ref.requires_grad_(True) + + # create an input signal + torch.manual_seed(333) + x = torch.randn(batch_size, in_channels, *in_shape, requires_grad=True).to(self.device) + + # perform the reference computation + x_ref = x.clone().detach() + x_ref.requires_grad_(True) + if transpose: + y_ref = torch.einsum("oif,biqr->bofqr", w_ref, x_ref) + y_ref = torch.einsum("fqrtp,bofqr->botp", psi_dense, y_ref * conv.quad_weights) + else: + y_ref = torch.einsum("ftpqr,bcqr->bcftp", psi_dense, x_ref * conv.quad_weights) + y_ref = torch.einsum("oif,biftp->botp", w_ref, y_ref) + + # use the convolution module + y = conv(x) + + # compare results + self.assertTrue(torch.allclose(y, y_ref, rtol=tol, atol=tol)) + + # compute gradients and compare results + grad_input = torch.randn_like(y) + y_ref.backward(grad_input) + y.backward(grad_input) + + # compare + self.assertTrue(torch.allclose(x.grad, x_ref.grad, rtol=tol, atol=tol)) + self.assertTrue(torch.allclose(conv.weight.grad, w_ref.grad, rtol=tol, atol=tol)) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_sht.py b/tests/test_sht.py index 9d50778..62e13ae 100644 --- a/tests/test_sht.py +++ b/tests/test_sht.py @@ -149,9 +149,9 @@ def test_sht_grad(self, nlat, nlon, batch_size, norm, grid, tol): coeffs[:, :lmax, :mmax] = torch.randn(batch_size, lmax, mmax, device=self.device, dtype=torch.complex128) signal = isht(coeffs) - input = torch.randn_like(signal, requires_grad=True) + grad_input = torch.randn_like(signal, requires_grad=True) err_handle = lambda x : torch.mean(torch.norm( isht(sht(x)) - signal , p='fro', dim=(-1,-2)) / torch.norm(signal, p='fro', dim=(-1,-2)) ) - test_result = gradcheck(err_handle, input, eps=1e-6, atol=tol) + test_result = gradcheck(err_handle, grad_input, eps=1e-6, atol=tol) self.assertTrue(test_result) diff --git a/torch_harmonics/__init__.py b/torch_harmonics/__init__.py index 806ce3a..3366517 100644 --- a/torch_harmonics/__init__.py +++ b/torch_harmonics/__init__.py @@ -32,8 +32,7 @@ __version__ = '0.6.4' from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT +from .convolution import DiscreteContinuousConvS2, DiscreteContinuousConvTransposeS2 from . import quadrature -from . import s2_convolutions -from . import disco_convolutions from . import random_fields from . import examples diff --git a/torch_harmonics/disco_convolutions.py b/torch_harmonics/_disco_convolution.py similarity index 98% rename from torch_harmonics/disco_convolutions.py rename to torch_harmonics/_disco_convolution.py index cb6667b..2f331be 100644 --- a/torch_harmonics/disco_convolutions.py +++ b/torch_harmonics/_disco_convolution.py @@ -377,7 +377,7 @@ def _disco_s2_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nlon_out: in assert psi.shape[-1] == nlat_in * nlon_in assert nlon_in % nlon_out == 0 - + assert nlon_in >= nlat_out pscale = nlon_in // nlon_out # add a dummy dimension for nkernel and move the batch and channel dims to the end @@ -414,7 +414,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl assert psi.shape[-2] == nlat_in assert n_out % nlon_out == 0 nlat_out = n_out // nlon_out - + assert nlon_out >= nlat_in pscale = nlon_out // nlon_in # we do a semi-transposition to faciliate the computation @@ -429,7 +429,7 @@ def _disco_s2_transpose_contraction_torch(x: torch.Tensor, psi: torch.Tensor, nl # interleave zeros along the longitude dimension to allow for fractional offsets to be considered x_ext = torch.zeros(kernel_size, nlat_in, nlon_out, batch_size * n_chans, device=x.device, dtype=x.dtype) - x_ext[:, :, (pscale-1)::pscale, :] = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0) + x_ext[:, :, ::pscale, :] = x.reshape(batch_size * n_chans, kernel_size, nlat_in, nlon_in).permute(1, 2, 3, 0) # we need to go backwards through the vector, so we flip the axis x_ext = x_ext.contiguous() diff --git a/torch_harmonics/s2_convolutions.py b/torch_harmonics/convolution.py similarity index 83% rename from torch_harmonics/s2_convolutions.py rename to torch_harmonics/convolution.py index 031da33..c3fd8bf 100644 --- a/torch_harmonics/s2_convolutions.py +++ b/torch_harmonics/convolution.py @@ -39,7 +39,7 @@ from functools import partial from torch_harmonics.quadrature import _precompute_latitudes -from torch_harmonics.disco_convolutions import ( +from torch_harmonics._disco_convolution import ( _disco_s2_contraction_torch, _disco_s2_transpose_contraction_torch, _disco_s2_contraction_triton, @@ -47,14 +47,14 @@ ) -def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, kernel_size: int, theta_cutoff: float): +def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, theta_cutoff: float): """ Computes the index set that falls into the isotropic kernel's support and returns both indices and values. """ # compute the support - dtheta = (theta_cutoff - 0.0) / kernel_size - ikernel = torch.arange(kernel_size).reshape(-1, 1, 1) + dtheta = (theta_cutoff - 0.0) / ntheta + ikernel = torch.arange(ntheta).reshape(-1, 1, 1) itheta = ikernel * dtheta norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta) @@ -64,6 +64,29 @@ def _compute_support_vals_isotropic(theta: torch.Tensor, phi: torch.Tensor, kern vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor return iidx, vals +def _compute_support_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, ntheta: int, nphi: int, theta_cutoff: float): + """ + Computes the index set that falls into the anisotropic kernel's support and returns both indices and values. + """ + + # compute the support + dtheta = (theta_cutoff - 0.0) / ntheta + dphi = 2.0 * math.pi / nphi + kernel_size = (ntheta-1)*nphi + 1 + ikernel = torch.arange(kernel_size).reshape(-1, 1, 1) + itheta = ((ikernel - 1) // nphi + 1) * dtheta + iphi = ((ikernel - 1) % nphi) * dphi + + norm_factor = 2 * math.pi * (1 - math.cos(theta_cutoff - dtheta) + math.cos(theta_cutoff - dtheta) + (math.sin(theta_cutoff - dtheta) - math.sin(theta_cutoff)) / dtheta) + + # find the indices where the rotated position falls into the support of the kernel + cond_theta = ((theta - itheta).abs() <= dtheta) & (theta <= theta_cutoff) + cond_phi = (ikernel == 0) | ((phi - iphi).abs() <= dphi) | ((2*math.pi - (phi - iphi).abs()) <= dphi) + iidx = torch.argwhere(cond_theta & cond_phi) + vals = (1 - (theta[iidx[:, 1], iidx[:, 2]] - itheta[iidx[:, 0], 0, 0]).abs() / dtheta) / norm_factor + vals *= torch.where(iidx[:, 0] > 0, (1 - torch.minimum((phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs(), (2*math.pi - (phi[iidx[:, 1], iidx[:, 2]] - iphi[iidx[:, 0], 0, 0]).abs()) ) / dphi ), 1.0) + return iidx, vals + def _precompute_convolution_tensor( in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi @@ -88,7 +111,9 @@ def _precompute_convolution_tensor( assert len(out_shape) == 2 if len(kernel_shape) == 1: - kernel_handle = partial(_compute_support_vals_isotropic, kernel_size=kernel_shape[0], theta_cutoff=theta_cutoff) + kernel_handle = partial(_compute_support_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=theta_cutoff) + elif len(kernel_shape) == 2: + kernel_handle = partial(_compute_support_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=theta_cutoff) else: raise ValueError("kernel_shape should be either one- or two-dimensional.") @@ -128,9 +153,9 @@ def _precompute_convolution_tensor( y = y / norm z = z / norm - # compute spherical coordinates + # compute spherical coordinates, where phi needs to fall into the [0, 2pi) range theta = torch.arccos(z) - phi = torch.arctan2(y, x) + phi = torch.arctan2(y, x) + torch.pi # find the indices where the rotated position falls into the support of the kernel iidx, vals = kernel_handle(theta, phi) @@ -146,8 +171,7 @@ def _precompute_convolution_tensor( # TODO: -# - parameter initialization -# - add anisotropy +# - derive conv and conv transpose from single module class DiscreteContinuousConvS2(nn.Module): """ Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1]. @@ -175,9 +199,13 @@ def __init__( if isinstance(kernel_shape, int): kernel_shape = [kernel_shape] - self.kernel_size = 1 - for kdim in kernel_shape: - self.kernel_size *= kdim + if len(kernel_shape) == 1: + self.kernel_size = kernel_shape[0] + elif len(kernel_shape) == 2: + self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1 + else: + raise ValueError("kernel_shape should be either one- or two-dimensional.") + # compute theta cutoff based on the bandlimit of the input field if theta_cutoff is None: @@ -209,7 +237,7 @@ def __init__( raise ValueError("Error, the number of output channels has to be an integer multiple of the group size") self.groupsize = in_channels // self.groups scale = math.sqrt(1.0 / self.groupsize) - self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, kernel_shape[0])) + self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size)) if bias: self.bias = nn.Parameter(torch.zeros(out_channels)) @@ -266,9 +294,12 @@ def __init__( if isinstance(kernel_shape, int): kernel_shape = [kernel_shape] - self.kernel_size = 1 - for kdim in kernel_shape: - self.kernel_size *= kdim + if len(kernel_shape) == 1: + self.kernel_size = kernel_shape[0] + elif len(kernel_shape) == 2: + self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1 + else: + raise ValueError("kernel_shape should be either one- or two-dimensional.") # bandlimit if theta_cutoff is None: @@ -301,7 +332,7 @@ def __init__( raise ValueError("Error, the number of output channels has to be an integer multiple of the group size") self.groupsize = in_channels // self.groups scale = math.sqrt(1.0 / self.groupsize) - self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, kernel_shape[0])) + self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size)) if bias: self.bias = nn.Parameter(torch.zeros(out_channels)) @@ -310,7 +341,7 @@ def __init__( def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tensor: # extract shape - B, F, H, W = x.shape + B, C, H, W = x.shape x = x.reshape(B, self.groups, self.groupsize, H, W) # do weight multiplication