Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyTorch implementation #7

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Install scientific python stack + nosetests
$ pip install numpy scipy matplotlib nose
```

Install [chainer](http://chainer.org/) with CUDNN and HDF5 or install [tensorflow](https://www.tensorflow.org/)
Install [chainer](http://chainer.org/) with CUDNN and HDF5, install [tensorflow](https://www.tensorflow.org/) or install [PyTorch](http://pytorch.org/).

Clone the latest GrouPy from github and run setup.py

Expand All @@ -39,7 +39,7 @@ $ nosetests -v

### TensorFlow

```
```python
import numpy as np
import tensorflow as tf
from groupy.gconv.tensorflow_gconv.splitgconv2d import gconv2d, gconv2d_util
Expand Down Expand Up @@ -71,7 +71,7 @@ print y.shape # (10, 9, 9, 512)

### Chainer

```
```python
from chainer import Variable
import cupy as cp
from groupy.gconv.chainer_gconv import P4ConvZ2, P4ConvP4
Expand All @@ -88,6 +88,26 @@ y = C2(C1(x))
print y.data.shape # (10, 64, 4, 9, 9)
```

### Pytorch

```python
import torch
from torch.autograd import Variable
from groupy.gconv.pytorch_gconv import P4ConvZ2, P4ConvP4

# Construct G-Conv layers
C1 = P4ConvZ2(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
C2 = P4ConvP4(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)

# Create 10 images with 3 channels and 9x9 pixels:
x = Variable(torch.randn(10, 3, 9, 9))

# fprop
y = C2(C1(x))
print y.data.shape # (10, 64, 4, 9, 9)
```



## Functionality

Expand Down
1 change: 1 addition & 0 deletions groupy/gconv/pytorch_gconv/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from groupy.gconv.pytorch_gconv.splitgconv2d import P4ConvZ2, P4ConvP4, P4MConvZ2, P4MConvP4M
114 changes: 114 additions & 0 deletions groupy/gconv/pytorch_gconv/check_gconv2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import numpy as np
import torch
from torch.autograd import Variable
from groupy.gconv.pytorch_gconv.splitgconv2d import P4ConvZ2, P4ConvP4, P4MConvZ2, P4MConvP4M


def test_p4_net_equivariance():
from groupy.gfunc import Z2FuncArray, P4FuncArray
import groupy.garray.C4_array as c4a

im = np.random.randn(1, 1, 11, 11).astype('float32')
check_equivariance(
im=im,
layers=[
P4ConvZ2(in_channels=1, out_channels=2, kernel_size=3),
P4ConvP4(in_channels=2, out_channels=3, kernel_size=3)
],
input_array=Z2FuncArray,
output_array=P4FuncArray,
point_group=c4a,
)


def test_p4m_net_equivariance():
from groupy.gfunc import Z2FuncArray, P4MFuncArray
import groupy.garray.D4_array as d4a

im = np.random.randn(1, 1, 11, 11).astype('float32')
check_equivariance(
im=im,
layers=[
P4MConvZ2(in_channels=1, out_channels=2, kernel_size=3),
P4MConvP4M(in_channels=2, out_channels=3, kernel_size=3)
],
input_array=Z2FuncArray,
output_array=P4MFuncArray,
point_group=d4a,
)


def test_g_z2_conv_equivariance():
from groupy.gfunc import Z2FuncArray, P4FuncArray, P4MFuncArray
import groupy.garray.C4_array as c4a
import groupy.garray.D4_array as d4a

im = np.random.randn(1, 1, 11, 11).astype('float32')
check_equivariance(
im=im,
layers=[P4ConvZ2(1, 2, 3)],
input_array=Z2FuncArray,
output_array=P4FuncArray,
point_group=c4a,
)

check_equivariance(
im=im,
layers=[P4MConvZ2(1, 2, 3)],
input_array=Z2FuncArray,
output_array=P4MFuncArray,
point_group=d4a,
)


def test_p4_p4_conv_equivariance():
from groupy.gfunc import P4FuncArray
import groupy.garray.C4_array as c4a

im = np.random.randn(1, 1, 4, 11, 11).astype('float32')
check_equivariance(
im=im,
layers=[P4ConvP4(1, 2, 3)],
input_array=P4FuncArray,
output_array=P4FuncArray,
point_group=c4a,
)


def test_p4m_p4m_conv_equivariance():
from groupy.gfunc import P4MFuncArray
import groupy.garray.D4_array as d4a

im = np.random.randn(1, 1, 8, 11, 11).astype('float32')
check_equivariance(
im=im,
layers=[P4MConvP4M(1, 2, 3)],
input_array=P4MFuncArray,
output_array=P4MFuncArray,
point_group=d4a,
)


def check_equivariance(im, layers, input_array, output_array, point_group):

# Transform the image
f = input_array(im)
g = point_group.rand()
gf = g * f
im1 = gf.v
# Apply layers to both images
im = Variable(torch.Tensor(im))
im1 = Variable(torch.Tensor(im1))

fmap = im
fmap1 = im1
for layer in layers:
fmap = layer(fmap)
fmap1 = layer(fmap1)

# Transform the computed feature maps
fmap1_garray = output_array(fmap1.data.numpy())
r_fmap1_data = (g.inv() * fmap1_garray).v

fmap_data = fmap.data.numpy()
assert np.allclose(fmap_data, r_fmap1_data, rtol=1e-5, atol=1e-3)
104 changes: 104 additions & 0 deletions groupy/gconv/pytorch_gconv/check_transform_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import numpy as np
import tensorflow as tf
import torch

from groupy.gconv.tensorflow_gconv.transform_filter import transform_filter_2d_nchw, transform_filter_2d_nhwc
from groupy.gconv.make_gconv_indices import make_c4_z2_indices, make_c4_p4_indices,\
make_d4_z2_indices, make_d4_p4m_indices, flatten_indices
from groupy.gconv.pytorch_gconv.splitgconv2d import trans_filter as pytorch_trans_filter_

# Comparing tensorflow and pytorch filter transformation


def check_c4_z2():
inds = make_c4_z2_indices(ksize=3)
w = np.random.randn(6, 7, 1, 3, 3)

rt = tf_trans_filter(w, inds)
rp = pytorch_trans_filter(w, inds)
diff = np.abs(rt - rp).sum()
print ('>>>>> DIFFERENCE:', diff)
assert diff == 0


def check_c4_p4():
inds = make_c4_p4_indices(ksize=3)
w = np.random.randn(6, 7, 4, 3, 3)

rt = tf_trans_filter(w, inds)
rp = pytorch_trans_filter(w, inds)

diff = np.abs(rt - rp).sum()
print ('>>>>> DIFFERENCE:', diff)
assert diff == 0


def check_d4_z2():
inds = make_d4_z2_indices(ksize=3)
w = np.random.randn(6, 7, 1, 3, 3)

rt = tf_trans_filter(w, inds)
rp = pytorch_trans_filter(w, inds)

diff = np.abs(rt - rp).sum()
print ('>>>>> DIFFERENCE:', diff)
assert diff == 0


def check_d4_p4m():
inds = make_d4_p4m_indices(ksize=3)
w = np.random.randn(6, 7, 8, 3, 3)

rt = tf_trans_filter(w, inds)
rp = pytorch_trans_filter(w, inds)

diff = np.abs(rt - rp).sum()
print ('>>>>> DIFFERENCE:', diff)
assert diff == 0


def tf_trans_filter(w, inds):

flat_inds = flatten_indices(inds)
no, ni, nti, n, _ = w.shape
shape_info = (no, inds.shape[0], ni, nti, n)

w = w.transpose((3, 4, 2, 1, 0)).reshape((n, n, nti * ni, no))

wt = tf.constant(w)
rwt = transform_filter_2d_nhwc(wt, flat_inds, shape_info)

sess = tf.Session()
rwt = sess.run(rwt)
sess.close()

nto = inds.shape[0]
rwt = rwt.transpose(3, 2, 0, 1).reshape(no, nto, ni, nti, n, n)
return rwt


def tf_trans_filter2(w, inds):

flat_inds = flatten_indices(inds)
no, ni, nti, n, _ = w.shape
shape_info = (no, inds.shape[0], ni, nti, n)

w = w.reshape(no, ni * nti, n, n)

wt = tf.constant(w)
rwt = transform_filter_2d_nchw(wt, flat_inds, shape_info)

sess = tf.Session()
rwt = sess.run(rwt)
sess.close()

nto = inds.shape[0]
rwt = rwt.reshape(no, nto, ni, nti, n, n)
return rwt


def pytorch_trans_filter(w, inds):
w = torch.DoubleTensor(w)
rp = pytorch_trans_filter_(w, inds)
rp = rp.numpy()
return rp
9 changes: 9 additions & 0 deletions groupy/gconv/pytorch_gconv/pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import torch.nn.functional as F


def plane_group_spatial_max_pooling(x, ksize, stride=None, pad=0):
xs = x.size()
x = x.view(xs[0], xs[1] * xs[2], xs[3], xs[4])
x = F.max_pool2d(input=x, kernel_size=ksize, stride=stride, padding=pad)
x = x.view(xs[0], xs[1], xs[2], x.size()[2], x.size()[3])
return x
109 changes: 109 additions & 0 deletions groupy/gconv/pytorch_gconv/splitgconv2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
import torch
import math
from torch.nn.modules.utils import _pair
from groupy.gconv.make_gconv_indices import *

make_indices_functions = {(1, 4): make_c4_z2_indices,
(4, 4): make_c4_p4_indices,
(1, 8): make_d4_z2_indices,
(8, 8): make_d4_p4m_indices}


def trans_filter(w, inds):
inds_reshape = inds.reshape((-1, inds.shape[-1])).astype(np.int64)
w_indexed = w[:, :, inds_reshape[:, 0].tolist(), inds_reshape[:, 1].tolist(), inds_reshape[:, 2].tolist()]
w_indexed = w_indexed.view(w_indexed.size()[0], w_indexed.size()[1],
inds.shape[0], inds.shape[1], inds.shape[2], inds.shape[3])
w_transformed = w_indexed.permute(0, 2, 1, 3, 4, 5)
return w_transformed.contiguous()


class SplitGConv2D(nn.Module):

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, bias=True, input_stabilizer_size=1, output_stabilizer_size=4):
super(SplitGConv2D, self).__init__()
assert (input_stabilizer_size, output_stabilizer_size) in make_indices_functions.keys()
self.ksize = kernel_size

kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)

self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.input_stabilizer_size = input_stabilizer_size
self.output_stabilizer_size = output_stabilizer_size

self.weight = Parameter(torch.Tensor(
out_channels, in_channels, self.input_stabilizer_size, *kernel_size))
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()

self.inds = self.make_transformation_indices()

def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)

def make_transformation_indices(self):
return make_indices_functions[(self.input_stabilizer_size, self.output_stabilizer_size)](self.ksize)

def forward(self, input):
tw = trans_filter(self.weight, self.inds)
tw_shape = (self.out_channels * self.output_stabilizer_size,
self.in_channels * self.input_stabilizer_size,
self.ksize, self.ksize)
tw = tw.view(tw_shape)

input_shape = input.size()
input = input.view(input_shape[0], self.in_channels*self.input_stabilizer_size, input_shape[-2], input_shape[-1])

y = F.conv2d(input, weight=tw, bias=None, stride=self.stride,
padding=self.padding)
batch_size, _, ny_out, nx_out = y.size()
y = y.view(batch_size, self.out_channels, self.output_stabilizer_size, ny_out, nx_out)

if self.bias is not None:
bias = self.bias.view(1, self.out_channels, 1, 1, 1)
y = y + bias

return y


class P4ConvZ2(SplitGConv2D):

def __init__(self, *args, **kwargs):
super(P4ConvZ2, self).__init__(input_stabilizer_size=1, output_stabilizer_size=4, *args, **kwargs)


class P4ConvP4(SplitGConv2D):

def __init__(self, *args, **kwargs):
super(P4ConvP4, self).__init__(input_stabilizer_size=4, output_stabilizer_size=4, *args, **kwargs)


class P4MConvZ2(SplitGConv2D):

def __init__(self, *args, **kwargs):
super(P4MConvZ2, self).__init__(input_stabilizer_size=1, output_stabilizer_size=8, *args, **kwargs)


class P4MConvP4M(SplitGConv2D):

def __init__(self, *args, **kwargs):
super(P4MConvP4M, self).__init__(input_stabilizer_size=8, output_stabilizer_size=8, *args, **kwargs)
Loading