diff --git a/README.md b/README.md index c8ff05d..a10c886 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ Install scientific python stack + nosetests $ pip install numpy scipy matplotlib nose ``` -Install [chainer](http://chainer.org/) with CUDNN and HDF5 or install [tensorflow](https://www.tensorflow.org/) +Install [chainer](http://chainer.org/) with CUDNN and HDF5, install [tensorflow](https://www.tensorflow.org/) or install [PyTorch](http://pytorch.org/). Clone the latest GrouPy from github and run setup.py @@ -39,7 +39,7 @@ $ nosetests -v ### TensorFlow -``` +```python import numpy as np import tensorflow as tf from groupy.gconv.tensorflow_gconv.splitgconv2d import gconv2d, gconv2d_util @@ -71,7 +71,7 @@ print y.shape # (10, 9, 9, 512) ### Chainer -``` +```python from chainer import Variable import cupy as cp from groupy.gconv.chainer_gconv import P4ConvZ2, P4ConvP4 @@ -88,6 +88,26 @@ y = C2(C1(x)) print y.data.shape # (10, 64, 4, 9, 9) ``` +### Pytorch + +```python +import torch +from torch.autograd import Variable +from groupy.gconv.pytorch_gconv import P4ConvZ2, P4ConvP4 + +# Construct G-Conv layers +C1 = P4ConvZ2(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1) +C2 = P4ConvP4(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1) + +# Create 10 images with 3 channels and 9x9 pixels: +x = Variable(torch.randn(10, 3, 9, 9)) + +# fprop +y = C2(C1(x)) +print y.data.shape # (10, 64, 4, 9, 9) +``` + + ## Functionality diff --git a/groupy/gconv/pytorch_gconv/__init__.py b/groupy/gconv/pytorch_gconv/__init__.py new file mode 100644 index 0000000..e1d57af --- /dev/null +++ b/groupy/gconv/pytorch_gconv/__init__.py @@ -0,0 +1 @@ +from groupy.gconv.pytorch_gconv.splitgconv2d import P4ConvZ2, P4ConvP4, P4MConvZ2, P4MConvP4M \ No newline at end of file diff --git a/groupy/gconv/pytorch_gconv/check_gconv2d.py b/groupy/gconv/pytorch_gconv/check_gconv2d.py new file mode 100644 index 0000000..7f7be95 --- /dev/null +++ b/groupy/gconv/pytorch_gconv/check_gconv2d.py @@ -0,0 +1,114 @@ +import numpy as np +import torch +from torch.autograd import Variable +from groupy.gconv.pytorch_gconv.splitgconv2d import P4ConvZ2, P4ConvP4, P4MConvZ2, P4MConvP4M + + +def test_p4_net_equivariance(): + from groupy.gfunc import Z2FuncArray, P4FuncArray + import groupy.garray.C4_array as c4a + + im = np.random.randn(1, 1, 11, 11).astype('float32') + check_equivariance( + im=im, + layers=[ + P4ConvZ2(in_channels=1, out_channels=2, kernel_size=3), + P4ConvP4(in_channels=2, out_channels=3, kernel_size=3) + ], + input_array=Z2FuncArray, + output_array=P4FuncArray, + point_group=c4a, + ) + + +def test_p4m_net_equivariance(): + from groupy.gfunc import Z2FuncArray, P4MFuncArray + import groupy.garray.D4_array as d4a + + im = np.random.randn(1, 1, 11, 11).astype('float32') + check_equivariance( + im=im, + layers=[ + P4MConvZ2(in_channels=1, out_channels=2, kernel_size=3), + P4MConvP4M(in_channels=2, out_channels=3, kernel_size=3) + ], + input_array=Z2FuncArray, + output_array=P4MFuncArray, + point_group=d4a, + ) + + +def test_g_z2_conv_equivariance(): + from groupy.gfunc import Z2FuncArray, P4FuncArray, P4MFuncArray + import groupy.garray.C4_array as c4a + import groupy.garray.D4_array as d4a + + im = np.random.randn(1, 1, 11, 11).astype('float32') + check_equivariance( + im=im, + layers=[P4ConvZ2(1, 2, 3)], + input_array=Z2FuncArray, + output_array=P4FuncArray, + point_group=c4a, + ) + + check_equivariance( + im=im, + layers=[P4MConvZ2(1, 2, 3)], + input_array=Z2FuncArray, + output_array=P4MFuncArray, + point_group=d4a, + ) + + +def test_p4_p4_conv_equivariance(): + from groupy.gfunc import P4FuncArray + import groupy.garray.C4_array as c4a + + im = np.random.randn(1, 1, 4, 11, 11).astype('float32') + check_equivariance( + im=im, + layers=[P4ConvP4(1, 2, 3)], + input_array=P4FuncArray, + output_array=P4FuncArray, + point_group=c4a, + ) + + +def test_p4m_p4m_conv_equivariance(): + from groupy.gfunc import P4MFuncArray + import groupy.garray.D4_array as d4a + + im = np.random.randn(1, 1, 8, 11, 11).astype('float32') + check_equivariance( + im=im, + layers=[P4MConvP4M(1, 2, 3)], + input_array=P4MFuncArray, + output_array=P4MFuncArray, + point_group=d4a, + ) + + +def check_equivariance(im, layers, input_array, output_array, point_group): + + # Transform the image + f = input_array(im) + g = point_group.rand() + gf = g * f + im1 = gf.v + # Apply layers to both images + im = Variable(torch.Tensor(im)) + im1 = Variable(torch.Tensor(im1)) + + fmap = im + fmap1 = im1 + for layer in layers: + fmap = layer(fmap) + fmap1 = layer(fmap1) + + # Transform the computed feature maps + fmap1_garray = output_array(fmap1.data.numpy()) + r_fmap1_data = (g.inv() * fmap1_garray).v + + fmap_data = fmap.data.numpy() + assert np.allclose(fmap_data, r_fmap1_data, rtol=1e-5, atol=1e-3) diff --git a/groupy/gconv/pytorch_gconv/check_transform_filter.py b/groupy/gconv/pytorch_gconv/check_transform_filter.py new file mode 100644 index 0000000..a7fbb18 --- /dev/null +++ b/groupy/gconv/pytorch_gconv/check_transform_filter.py @@ -0,0 +1,104 @@ +import numpy as np +import tensorflow as tf +import torch + +from groupy.gconv.tensorflow_gconv.transform_filter import transform_filter_2d_nchw, transform_filter_2d_nhwc +from groupy.gconv.make_gconv_indices import make_c4_z2_indices, make_c4_p4_indices,\ + make_d4_z2_indices, make_d4_p4m_indices, flatten_indices +from groupy.gconv.pytorch_gconv.splitgconv2d import trans_filter as pytorch_trans_filter_ + +# Comparing tensorflow and pytorch filter transformation + + +def check_c4_z2(): + inds = make_c4_z2_indices(ksize=3) + w = np.random.randn(6, 7, 1, 3, 3) + + rt = tf_trans_filter(w, inds) + rp = pytorch_trans_filter(w, inds) + diff = np.abs(rt - rp).sum() + print ('>>>>> DIFFERENCE:', diff) + assert diff == 0 + + +def check_c4_p4(): + inds = make_c4_p4_indices(ksize=3) + w = np.random.randn(6, 7, 4, 3, 3) + + rt = tf_trans_filter(w, inds) + rp = pytorch_trans_filter(w, inds) + + diff = np.abs(rt - rp).sum() + print ('>>>>> DIFFERENCE:', diff) + assert diff == 0 + + +def check_d4_z2(): + inds = make_d4_z2_indices(ksize=3) + w = np.random.randn(6, 7, 1, 3, 3) + + rt = tf_trans_filter(w, inds) + rp = pytorch_trans_filter(w, inds) + + diff = np.abs(rt - rp).sum() + print ('>>>>> DIFFERENCE:', diff) + assert diff == 0 + + +def check_d4_p4m(): + inds = make_d4_p4m_indices(ksize=3) + w = np.random.randn(6, 7, 8, 3, 3) + + rt = tf_trans_filter(w, inds) + rp = pytorch_trans_filter(w, inds) + + diff = np.abs(rt - rp).sum() + print ('>>>>> DIFFERENCE:', diff) + assert diff == 0 + + +def tf_trans_filter(w, inds): + + flat_inds = flatten_indices(inds) + no, ni, nti, n, _ = w.shape + shape_info = (no, inds.shape[0], ni, nti, n) + + w = w.transpose((3, 4, 2, 1, 0)).reshape((n, n, nti * ni, no)) + + wt = tf.constant(w) + rwt = transform_filter_2d_nhwc(wt, flat_inds, shape_info) + + sess = tf.Session() + rwt = sess.run(rwt) + sess.close() + + nto = inds.shape[0] + rwt = rwt.transpose(3, 2, 0, 1).reshape(no, nto, ni, nti, n, n) + return rwt + + +def tf_trans_filter2(w, inds): + + flat_inds = flatten_indices(inds) + no, ni, nti, n, _ = w.shape + shape_info = (no, inds.shape[0], ni, nti, n) + + w = w.reshape(no, ni * nti, n, n) + + wt = tf.constant(w) + rwt = transform_filter_2d_nchw(wt, flat_inds, shape_info) + + sess = tf.Session() + rwt = sess.run(rwt) + sess.close() + + nto = inds.shape[0] + rwt = rwt.reshape(no, nto, ni, nti, n, n) + return rwt + + +def pytorch_trans_filter(w, inds): + w = torch.DoubleTensor(w) + rp = pytorch_trans_filter_(w, inds) + rp = rp.numpy() + return rp diff --git a/groupy/gconv/pytorch_gconv/pooling.py b/groupy/gconv/pytorch_gconv/pooling.py new file mode 100644 index 0000000..dec52b2 --- /dev/null +++ b/groupy/gconv/pytorch_gconv/pooling.py @@ -0,0 +1,9 @@ +import torch.nn.functional as F + + +def plane_group_spatial_max_pooling(x, ksize, stride=None, pad=0): + xs = x.size() + x = x.view(xs[0], xs[1] * xs[2], xs[3], xs[4]) + x = F.max_pool2d(input=x, kernel_size=ksize, stride=stride, padding=pad) + x = x.view(xs[0], xs[1], xs[2], x.size()[2], x.size()[3]) + return x diff --git a/groupy/gconv/pytorch_gconv/splitgconv2d.py b/groupy/gconv/pytorch_gconv/splitgconv2d.py new file mode 100644 index 0000000..96311dc --- /dev/null +++ b/groupy/gconv/pytorch_gconv/splitgconv2d.py @@ -0,0 +1,109 @@ +import torch.nn as nn +from torch.nn import Parameter +import torch.nn.functional as F +import torch +import math +from torch.nn.modules.utils import _pair +from groupy.gconv.make_gconv_indices import * + +make_indices_functions = {(1, 4): make_c4_z2_indices, + (4, 4): make_c4_p4_indices, + (1, 8): make_d4_z2_indices, + (8, 8): make_d4_p4m_indices} + + +def trans_filter(w, inds): + inds_reshape = inds.reshape((-1, inds.shape[-1])).astype(np.int64) + w_indexed = w[:, :, inds_reshape[:, 0].tolist(), inds_reshape[:, 1].tolist(), inds_reshape[:, 2].tolist()] + w_indexed = w_indexed.view(w_indexed.size()[0], w_indexed.size()[1], + inds.shape[0], inds.shape[1], inds.shape[2], inds.shape[3]) + w_transformed = w_indexed.permute(0, 2, 1, 3, 4, 5) + return w_transformed.contiguous() + + +class SplitGConv2D(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, bias=True, input_stabilizer_size=1, output_stabilizer_size=4): + super(SplitGConv2D, self).__init__() + assert (input_stabilizer_size, output_stabilizer_size) in make_indices_functions.keys() + self.ksize = kernel_size + + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.input_stabilizer_size = input_stabilizer_size + self.output_stabilizer_size = output_stabilizer_size + + self.weight = Parameter(torch.Tensor( + out_channels, in_channels, self.input_stabilizer_size, *kernel_size)) + if bias: + self.bias = Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + self.inds = self.make_transformation_indices() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.uniform_(-stdv, stdv) + + def make_transformation_indices(self): + return make_indices_functions[(self.input_stabilizer_size, self.output_stabilizer_size)](self.ksize) + + def forward(self, input): + tw = trans_filter(self.weight, self.inds) + tw_shape = (self.out_channels * self.output_stabilizer_size, + self.in_channels * self.input_stabilizer_size, + self.ksize, self.ksize) + tw = tw.view(tw_shape) + + input_shape = input.size() + input = input.view(input_shape[0], self.in_channels*self.input_stabilizer_size, input_shape[-2], input_shape[-1]) + + y = F.conv2d(input, weight=tw, bias=None, stride=self.stride, + padding=self.padding) + batch_size, _, ny_out, nx_out = y.size() + y = y.view(batch_size, self.out_channels, self.output_stabilizer_size, ny_out, nx_out) + + if self.bias is not None: + bias = self.bias.view(1, self.out_channels, 1, 1, 1) + y = y + bias + + return y + + +class P4ConvZ2(SplitGConv2D): + + def __init__(self, *args, **kwargs): + super(P4ConvZ2, self).__init__(input_stabilizer_size=1, output_stabilizer_size=4, *args, **kwargs) + + +class P4ConvP4(SplitGConv2D): + + def __init__(self, *args, **kwargs): + super(P4ConvP4, self).__init__(input_stabilizer_size=4, output_stabilizer_size=4, *args, **kwargs) + + +class P4MConvZ2(SplitGConv2D): + + def __init__(self, *args, **kwargs): + super(P4MConvZ2, self).__init__(input_stabilizer_size=1, output_stabilizer_size=8, *args, **kwargs) + + +class P4MConvP4M(SplitGConv2D): + + def __init__(self, *args, **kwargs): + super(P4MConvP4M, self).__init__(input_stabilizer_size=8, output_stabilizer_size=8, *args, **kwargs) \ No newline at end of file diff --git a/setup.py b/setup.py index eda2a5b..bb84b68 100644 --- a/setup.py +++ b/setup.py @@ -4,9 +4,9 @@ setup( name='GrouPy', - version='0.1.1', + version='0.1.2', description='Group equivariant convolutional neural networks', author='Taco S. Cohen', author_email='taco.cohen@gmail.com', - packages=['groupy', 'groupy.garray', 'groupy.gconv', 'groupy.gconv.chainer_gconv', 'groupy.gconv.theano_gconv', 'groupy.gconv.tensorflow_gconv', 'groupy.gfunc', 'groupy.gfunc.plot'], + packages=['groupy', 'groupy.garray', 'groupy.gconv', 'groupy.gconv.chainer_gconv', 'groupy.gconv.theano_gconv', 'groupy.gconv.tensorflow_gconv', 'groupy.gconv.pytorch_gconv', 'groupy.gfunc', 'groupy.gfunc.plot'], )