diff --git a/groupy/gconv/pytorch_gconv/__init__.py b/groupy/gconv/pytorch_gconv/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/groupy/gconv/pytorch_gconv/p4_conv.py b/groupy/gconv/pytorch_gconv/p4_conv.py
new file mode 100644
index 0000000..63f0876
--- /dev/null
+++ b/groupy/gconv/pytorch_gconv/p4_conv.py
@@ -0,0 +1,31 @@
+from groupy.gconv.pytorch_gconv.splitgconv2d import SplitGConv2D
+from groupy.gconv.make_gconv_indices import make_c4_z2_indices, \
+    make_c4_p4_indices, flatten_indices
+
+
+class P4ConvZ2(SplitGConv2D):
+
+    @property
+    def input_stabilizer_size(self):
+        return 1
+
+    @property
+    def output_stabilizer_size(self):
+        return 4
+
+    def make_transformation_indices(self, ksize):
+        return flatten_indices(make_c4_z2_indices(ksize=ksize))
+
+
+class P4ConvP4(SplitGConv2D):
+
+    @property
+    def input_stabilizer_size(self):
+        return 4
+
+    @property
+    def output_stabilizer_size(self):
+        return 4
+
+    def make_transformation_indices(self, ksize):
+        return flatten_indices(make_c4_p4_indices(ksize=ksize))
diff --git a/groupy/gconv/pytorch_gconv/p4m_conv.py b/groupy/gconv/pytorch_gconv/p4m_conv.py
new file mode 100644
index 0000000..064372a
--- /dev/null
+++ b/groupy/gconv/pytorch_gconv/p4m_conv.py
@@ -0,0 +1,31 @@
+from groupy.gconv.pytorch_gconv.splitgconv2d import SplitGConv2D
+from groupy.gconv.make_gconv_indices import make_d4_z2_indices, \
+    make_d4_p4m_indices, flatten_indices
+
+
+class P4MConvZ2(SplitGConv2D):
+
+    @property
+    def input_stabilizer_size(self):
+        return 1
+
+    @property
+    def output_stabilizer_size(self):
+        return 8
+
+    def make_transformation_indices(self, ksize):
+        return flatten_indices(make_d4_z2_indices(ksize=ksize))
+
+
+class P4MConvP4M(SplitGConv2D):
+
+    @property
+    def input_stabilizer_size(self):
+        return 8
+
+    @property
+    def output_stabilizer_size(self):
+        return 8
+
+    def make_transformation_indices(self, ksize):
+        return flatten_indices(make_d4_p4m_indices(ksize=ksize))
diff --git a/groupy/gconv/pytorch_gconv/splitgconv2d.py b/groupy/gconv/pytorch_gconv/splitgconv2d.py
new file mode 100644
index 0000000..be245cd
--- /dev/null
+++ b/groupy/gconv/pytorch_gconv/splitgconv2d.py
@@ -0,0 +1,155 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as nninit
+from torch.autograd import Variable
+
+
+def _pair(x):
+    if hasattr(x, '__getitem__'):
+        return x
+    else:
+        return (x, x)
+
+
+class SplitGConv2D(nn.Module):
+    """
+    Group convolution base class for split plane groups.
+
+    A plane group (aka wallpaper group) is a group of distance-preserving
+    transformations that includes two independent discrete translations.
+
+    A group is called split (or symmorphic) if every element in this group
+    can be written as the composition of an element from the "stabilizer of
+    the origin" and a translation. The stabilizer of the origin consists of
+    those transformations in the group that leave the origin fixed. For
+    example, the stabilizer in the rotation-translation group p4 is the set
+    of rotations around the origin, which is (isomorphic to) the group C4.
+
+    Most plane groups are split, but some include glide-reflection
+    generators; such groups are not split. For split groups G, the G-conv
+    can be split into a "filter transform" and "translational convolution" part.
+
+    Different subclasses of this class implement the filter transform for
+    various groups, while this class implements the common functionality.
+
+    This PyTorch implementation mimics the original Chainer implementation.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 ksize=3,
+                 flat_channels=False,
+                 stride=1,
+                 pad=0,
+                 bias=True,
+                 *args, **kwargs):
+        """
+        :param in_channels:
+        :param out_channels:
+        :param ksize:
+        :param flat_channels:
+        :param stride:
+        :param pad:
+        :param bias:
+        :return:
+        """
+
+        super(SplitGConv2D, self).__init__(*args, **kwargs)
+
+        if not isinstance(ksize, int):
+            raise TypeError('ksize must be an integer (only square filters '
+                            'are supported).')
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.ksize = ksize
+        self.stride = _pair(stride)
+        self.pad = _pair(pad)
+        self.flat_channels = flat_channels
+        self.use_bias = bias
+
+        self.weight = nn.Parameter(torch.Tensor(self.out_channels,
+                                                self.in_channels,
+                                                self.input_stabilizer_size,
+                                                self.ksize,
+                                                self.ksize))
+        nninit.xavier_normal(self.weight)
+
+        if self.use_bias:
+            self.bias = nn.Parameter(
+                torch.zeros(self.out_channels))
+
+        # Shorthands
+        ni, no = in_channels, out_channels
+        nti, nto = self.input_stabilizer_size, self.output_stabilizer_size
+        n = self.ksize
+
+        self.expand_shape = (no, nto, ni, nti * n * n)
+        self.weight_shape = (no * nto, ni * nti, n, n)
+        self.weight_flat_shape = (no, 1, ni, nti * n * n)
+
+        transform_indices = self._create_indices(self.expand_shape)
+        self.register_buffer('transform_indices', transform_indices)
+
+    def _create_indices(self, expand_shape):
+        no, nto, ni, r = expand_shape
+        transform_indices = self.make_transformation_indices(ksize=self.ksize)
+        transform_indices = transform_indices.astype(np.int64)
+        transform_indices = transform_indices.reshape(1, nto, 1, r)
+        transform_indices = torch.from_numpy(transform_indices)
+        transform_indices = transform_indices.expand(*expand_shape)
+        return transform_indices
+
+    @property
+    def input_stabilizer_size(self):
+        raise NotImplementedError()
+
+    @property
+    def output_stabilizer_size(self):
+        raise NotImplementedError()
+
+    def make_transformation_indices(self, ksize):
+        raise NotImplementedError()
+
+    def forward(self, x):
+        # Transform the filters
+        w_flat_ = self.weight.view(self.weight_flat_shape)
+        w_flat = w_flat_.expand(*self.expand_shape)
+        w = torch.gather(w_flat, 3, Variable(self.transform_indices)) \
+            .view(self.weight_shape)
+
+        # If flat_channels is False, we need to flatten the input feature maps
+        # to have a single 1d feature dimension.
+        if not self.flat_channels:
+            batch_size = x.size(0)
+            in_ny, in_nx = x.size()[-2:]
+            x = x.view(batch_size,
+                       self.in_channels * self.input_stabilizer_size,
+                       in_ny,
+                       in_nx)
+
+        # Perform the 2D convolution
+        y = F.conv2d(x, w, stride=self.stride, padding=self.pad)
+
+        # Unfold the output feature maps
+        # We do this even if flat_channels is True, because we need to add the
+        # same bias to each G-feature map
+        batch_size, _, ny_out, nx_out = y.size()
+        y = y.view(batch_size, self.out_channels, self.output_stabilizer_size,
+                   ny_out, nx_out)
+
+        # Add a bias to each G-feature map
+        if self.use_bias:
+            b = self.bias.view(1, self.out_channels, 1, 1, 1)
+            b = b.expand_as(y)
+            y = y + b
+
+        # Flatten feature channels if needed
+        if self.flat_channels:
+            n, nc, ng, nx, ny = y.size()
+            y = y.view(n, nc * ng, nx, ny)
+
+        return y
diff --git a/groupy/gconv/pytorch_gconv/test_gconv.py b/groupy/gconv/pytorch_gconv/test_gconv.py
new file mode 100644
index 0000000..25228f6
--- /dev/null
+++ b/groupy/gconv/pytorch_gconv/test_gconv.py
@@ -0,0 +1,121 @@
+import numpy as np
+import torch
+from torch.autograd import Variable
+
+
+def test_p4_net_equivariance():
+    from groupy.gfunc import Z2FuncArray, P4FuncArray
+    import groupy.garray.C4_array as c4a
+    from groupy.gconv.pytorch_gconv.p4_conv import P4ConvZ2, P4ConvP4
+
+    im = np.random.randn(1, 1, 11, 11).astype('float32')
+    check_equivariance(
+        im=im,
+        layers=[
+            P4ConvZ2(in_channels=1, out_channels=2, ksize=3),
+            P4ConvP4(in_channels=2, out_channels=3, ksize=3)
+        ],
+        input_array=Z2FuncArray,
+        output_array=P4FuncArray,
+        point_group=c4a,
+    )
+
+
+def test_p4m_net_equivariance():
+    from groupy.gfunc import Z2FuncArray, P4MFuncArray
+    import groupy.garray.D4_array as d4a
+    from groupy.gconv.pytorch_gconv.p4m_conv import P4MConvZ2, P4MConvP4M
+
+    im = np.random.randn(1, 1, 11, 11).astype('float32')
+    check_equivariance(
+        im=im,
+        layers=[
+            P4MConvZ2(in_channels=1, out_channels=2, ksize=3),
+            P4MConvP4M(in_channels=2, out_channels=3, ksize=3)
+        ],
+        input_array=Z2FuncArray,
+        output_array=P4MFuncArray,
+        point_group=d4a,
+    )
+
+
+def test_g_z2_conv_equivariance():
+    from groupy.gfunc import Z2FuncArray, P4FuncArray, P4MFuncArray
+    import groupy.garray.C4_array as c4a
+    import groupy.garray.D4_array as d4a
+    from groupy.gconv.pytorch_gconv.p4_conv import P4ConvZ2
+    from groupy.gconv.pytorch_gconv.p4m_conv import P4MConvZ2
+
+    im = np.random.randn(1, 1, 11, 11).astype('float32')
+    check_equivariance(
+        im=im,
+        layers=[P4ConvZ2(1, 2, 3)],
+        input_array=Z2FuncArray,
+        output_array=P4FuncArray,
+        point_group=c4a,
+    )
+
+    check_equivariance(
+        im=im,
+        layers=[P4MConvZ2(1, 2, 3)],
+        input_array=Z2FuncArray,
+        output_array=P4MFuncArray,
+        point_group=d4a,
+    )
+
+
+def test_p4_p4_conv_equivariance():
+    from groupy.gfunc import P4FuncArray
+    import groupy.garray.C4_array as c4a
+    from groupy.gconv.pytorch_gconv.p4_conv import P4ConvP4
+
+    im = np.random.randn(1, 1, 4, 11, 11).astype('float32')
+    check_equivariance(
+        im=im,
+        layers=[P4ConvP4(1, 2, 3)],
+        input_array=P4FuncArray,
+        output_array=P4FuncArray,
+        point_group=c4a,
+    )
+
+
+def test_p4m_p4m_conv_equivariance():
+    from groupy.gfunc import P4MFuncArray
+    import groupy.garray.D4_array as d4a
+    from groupy.gconv.pytorch_gconv.p4m_conv import P4MConvP4M
+
+    im = np.random.randn(1, 1, 8, 11, 11).astype('float32')
+    check_equivariance(
+        im=im,
+        layers=[P4MConvP4M(1, 2, 3)],
+        input_array=P4MFuncArray,
+        output_array=P4MFuncArray,
+        point_group=d4a,
+    )
+
+
+def check_equivariance(im, layers, input_array,
+                       output_array, point_group):
+
+    # Transform the image
+    f = input_array(im)
+    g = point_group.rand()
+    gf = g * f
+    im1 = gf.v
+
+    # Apply layers to both images
+    im = Variable(torch.from_numpy(im))
+    im1 = Variable(torch.from_numpy(im1))
+
+    fmap = im
+    fmap1 = im1
+    for layer in layers:
+        print(layer)
+        fmap = layer(fmap)
+        fmap1 = layer(fmap1)
+
+    # Transform the computed feature maps
+    fmap1_garray = output_array(fmap1.data.numpy())
+    r_fmap1_data = (g.inv() * fmap1_garray).v
+
+    fmap_data = fmap.data.numpy()
+    assert np.allclose(fmap_data, r_fmap1_data, rtol=1e-5, atol=1e-3)
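
Usage sketch (not part of the patch): a minimal forward pass through the new layers, assuming the groupy package from this patch is importable and the Variable-based PyTorch API of the same era. The layer names and constructor arguments come from p4_conv.py and splitgconv2d.py above; the batch size, channel counts and image size are illustrative only.

import torch
from torch.autograd import Variable
from groupy.gconv.pytorch_gconv.p4_conv import P4ConvZ2, P4ConvP4

# A Z2 input: (batch, in_channels, height, width).
x = Variable(torch.randn(2, 3, 28, 28))

# Lift from Z2 to p4, then convolve on p4; pad=1 keeps the spatial size.
l1 = P4ConvZ2(in_channels=3, out_channels=8, ksize=3, pad=1)
l2 = P4ConvP4(in_channels=8, out_channels=16, ksize=3, pad=1)

y = l2(l1(x))
# With flat_channels=False (the default) the output keeps a separate group
# axis: (batch, out_channels, output_stabilizer_size, height, width),
# i.e. (2, 16, 4, 28, 28) here. Setting flat_channels=True on a layer makes
# it consume and produce the merged 4-D layout (2, 64, 28, 28) instead.
print(y.size())

The equivariance tests in test_gconv.py exercise exactly this kind of stacking: they transform the input with a random group element and check that the resulting feature maps transform accordingly.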