Skip to content

Commit

Permalink
moving to an abstract base class for disco
Browse files Browse the repository at this point in the history
  • Loading branch information
bonevbs committed Jan 6, 2024
1 parent 1d387f8 commit 5ec67a6
Showing 1 changed file with 89 additions and 61 deletions.
150 changes: 89 additions & 61 deletions torch_harmonics/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import abc
from typing import List, Tuple, Union, Optional

import math
Expand Down Expand Up @@ -88,7 +89,7 @@ def _compute_support_vals_anisotropic(theta: torch.Tensor, phi: torch.Tensor, nt
return iidx, vals


def _precompute_convolution_tensor(
def _precompute_convolution_tensor_s2(
in_shape, out_shape, kernel_shape, grid_in="equiangular", grid_out="equiangular", theta_cutoff=0.01 * math.pi
):
"""
Expand Down Expand Up @@ -170,9 +171,90 @@ def _precompute_convolution_tensor(
return out_idx, out_vals


def _precompute_convolution_tensor_2d(
in_grid, out_grid, kernel_shape, radius_cutoff=0.01
):
"""
Precomputes the translated filters at positions $T^{-1}_j \omega_i = T^{-1}_j T_i \nu$. Similar to the S2 routine,
only that it assumes a non-periodic subset of the euclidean plane
"""

# check that input arrays are valid point clouds in 2D
assert len(in_grid) == 2
assert len(out_grid) == 2
assert in_grid.shape[0] == 2
assert out_grid.shape[0] == 2

n_in = in_grid.shape[-1]
n_out = out_grid.shape[-1]

if len(kernel_shape) == 1:
kernel_handle = partial(_compute_support_vals_isotropic, ntheta=kernel_shape[0], theta_cutoff=radius_cutoff)
elif len(kernel_shape) == 2:
kernel_handle = partial(_compute_support_vals_anisotropic, ntheta=kernel_shape[0], nphi=kernel_shape[1], theta_cutoff=radius_cutoff)
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")

in_grid = in_grid.reshape(2, 1, n_in)
out_grid = out_grid.reshape(2, n_out, 1)

diffs = in_grid - out_grid
r = torch.sqrt(diffs[0]**2 + diffs[1]**2)
phi = torch.arctan2(diffs[1], diffs[0]) + torch.pi

idx, vals = kernel_handle(r, diffs)
idx.permute(1, 0)

return idx, vals

class DiscreteContinuousConv(nn.Module, abc.ABC):
"""
Abstract base class for DISCO convolutions
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_shape: Union[int, List[int]],
groups: Optional[int] = 1,
bias: Optional[bool] = True,
):
super().__init__()

if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape]
if len(kernel_shape) == 1:
self.kernel_size = kernel_shape[0]
elif len(kernel_shape) == 2:
self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")

# groups
self.groups = groups

# weight tensor
if in_channels % self.groups != 0:
raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
if out_channels % self.groups != 0:
raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
self.groupsize = in_channels // self.groups
scale = math.sqrt(1.0 / self.groupsize)
self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

if bias:
self.bias = nn.Parameter(torch.zeros(out_channels))
else:
self.bias = None

@abc.abstractmethod
def forward(self, x: torch.Tensor):
raise NotImplementedError


# TODO:
# - derive conv and conv transpose from single module
class DiscreteContinuousConvS2(nn.Module):
class DiscreteContinuousConvS2(DiscreteContinuousConv):
"""
Discrete-continuous convolutions (DISCO) on the 2-Sphere as described in [1].
Expand All @@ -192,21 +274,11 @@ def __init__(
bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None,
):
super().__init__()
super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape

if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape]
if len(kernel_shape) == 1:
self.kernel_size = kernel_shape[0]
elif len(kernel_shape) == 2:
self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")


# compute theta cutoff based on the bandlimit of the input field
if theta_cutoff is None:
theta_cutoff = (kernel_shape[0]+1) * torch.pi / float(self.nlat_in - 1)
Expand All @@ -219,7 +291,7 @@ def __init__(
quad_weights = 2.0 * torch.pi * torch.from_numpy(wgl).float().reshape(-1, 1) / self.nlon_in
self.register_buffer("quad_weights", quad_weights, persistent=False)

idx, vals = _precompute_convolution_tensor(
idx, vals = _precompute_convolution_tensor_s2(
in_shape, out_shape, kernel_shape, grid_in=grid_in, grid_out=grid_out, theta_cutoff=theta_cutoff
)
# psi = torch.sparse_coo_tensor(
Expand All @@ -229,23 +301,6 @@ def __init__(
self.register_buffer("psi_vals", vals, persistent=False)
# self.register_buffer("psi", psi, persistent=False)

# groups
self.groups = groups

# weight tensor
if in_channels % self.groups != 0:
raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
if out_channels % self.groups != 0:
raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
self.groupsize = in_channels // self.groups
scale = math.sqrt(1.0 / self.groupsize)
self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

if bias:
self.bias = nn.Parameter(torch.zeros(out_channels))
else:
self.bias = None

def get_psi(self):
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_out, self.nlat_in * self.nlon_in)).coalesce()
return psi
Expand Down Expand Up @@ -274,8 +329,7 @@ def forward(self, x: torch.Tensor, use_triton_kernel: bool = True) -> torch.Tens

return out


class DiscreteContinuousConvTransposeS2(nn.Module):
class DiscreteContinuousConvTransposeS2(DiscreteContinuousConv):
"""
Discrete-continuous transpose convolutions (DISCO) on the 2-Sphere as described in [1].
Expand All @@ -295,20 +349,11 @@ def __init__(
bias: Optional[bool] = True,
theta_cutoff: Optional[float] = None,
):
super().__init__()
super().__init__(in_channels, out_channels, kernel_shape, groups, bias)

self.nlat_in, self.nlon_in = in_shape
self.nlat_out, self.nlon_out = out_shape

if isinstance(kernel_shape, int):
kernel_shape = [kernel_shape]
if len(kernel_shape) == 1:
self.kernel_size = kernel_shape[0]
elif len(kernel_shape) == 2:
self.kernel_size = (kernel_shape[0]-1)*kernel_shape[1] + 1
else:
raise ValueError("kernel_shape should be either one- or two-dimensional.")

# bandlimit
if theta_cutoff is None:
theta_cutoff = (kernel_shape[0]+1) * torch.pi / float(self.nlat_in - 1)
Expand All @@ -322,7 +367,7 @@ def __init__(
self.register_buffer("quad_weights", quad_weights, persistent=False)

# switch in_shape and out_shape since we want transpose conv
idx, vals = _precompute_convolution_tensor(
idx, vals = _precompute_convolution_tensor_s2(
out_shape, in_shape, kernel_shape, grid_in=grid_out, grid_out=grid_in, theta_cutoff=theta_cutoff
)
# psi = torch.sparse_coo_tensor(
Expand All @@ -332,23 +377,6 @@ def __init__(
self.register_buffer("psi_vals", vals, persistent=False)
# self.register_buffer("psi", psi, persistent=False)

# groups
self.groups = groups

# weight tensor
if in_channels % self.groups != 0:
raise ValueError("Error, the number of input channels has to be an integer multiple of the group size")
if out_channels % self.groups != 0:
raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
self.groupsize = in_channels // self.groups
scale = math.sqrt(1.0 / self.groupsize)
self.weight = nn.Parameter(scale * torch.randn(out_channels, self.groupsize, self.kernel_size))

if bias:
self.bias = nn.Parameter(torch.zeros(out_channels))
else:
self.bias = None

def get_psi(self):
psi = torch.sparse_coo_tensor(self.psi_idx, self.psi_vals, size=(self.kernel_size, self.nlat_in, self.nlat_out * self.nlon_out)).coalesce()
return psi
Expand Down

0 comments on commit 5ec67a6

Please sign in to comment.