[X86] [NNVM] [TOPI] [WIP] Implement NCHWc Winograd convolutions
This is the implementation alluded to in
https://discuss.tvm.ai/t/improved-direct-winograd-nchwc-cpu-implementation-with-resnet-50-results/

It is a fairly standard Winograd implementation, adapted to the NCHWc
layout. It achieves reasonable speedups (up to 2x over the current
implementation) on a number of ResNet 3x3 layers on SKL and AVX.

TODO: Parallelization
TODO: Benchmarking results on the full ResNet suite.
TODO: Demonstration in `tune_nnvm_x86.py`
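
For reference, the arithmetic being specialized here is the standard Winograd factorization Y = Aᵀ[(G g Gᵀ) ⊙ (Bᵀ d B)] A. Below is a minimal NumPy sketch of one F(2x2, 3x3) tile, using the same constant matrices that this patch factors out of the ARM/CUDA backends into `winograd_transform_matrices`; the NCHWc-specific part (blocking channels into 16-wide chunks and vectorizing over the inner block) is omitted, so this illustrates the math only, not the committed kernels.

```python
import numpy as np

# F(2x2, 3x3) transform matrices (copied from the tile_size == 2 branch
# removed from topi/python/topi/arm_cpu/conv2d.py below).
G = np.array([[1.0, 0.0, 0.0],
              [0.5, 0.5, 0.5],
              [0.5, -0.5, 0.5],
              [0.0, 0.0, 1.0]])
B = np.array([[1, 0, 0, 0],
              [0, 1, -1, 1],
              [-1, 1, 1, 0],
              [0, 0, 0, -1]], dtype=np.float64)
A = np.array([[1, 0],
              [1, 1],
              [1, -1],
              [0, -1]], dtype=np.float64)

d = np.random.randn(4, 4)   # one padded input tile
g = np.random.randn(3, 3)   # one 3x3 filter

U = G @ g @ G.T             # kernel transform (precomputed by the weight-transform op)
V = B.T @ d @ B             # data transform
Y = A.T @ (U * V) @ A       # elementwise product + inverse transform -> 2x2 output tile

# Direct "valid" 2x2 cross-correlation for comparison.
ref = np.array([[np.sum(d[i:i + 3, j:j + 3] * g) for j in range(2)]
                for i in range(2)])
assert np.allclose(Y, ref)
```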
ajtulloch committed Nov 20, 2018
1 parent 7f420f8 commit 7c33f51
Showing 11 changed files with 1,021 additions and 207 deletions.
14 changes: 11 additions & 3 deletions nnvm/include/nnvm/top/nn.h
@@ -174,10 +174,18 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> {

struct WinogradWeightTransformParam : public dmlc::Parameter<WinogradWeightTransformParam> {
int tile_size;

std::string kernel_layout;
DMLC_DECLARE_PARAMETER(WinogradWeightTransformParam) {
DMLC_DECLARE_FIELD(tile_size)
.describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
DMLC_DECLARE_FIELD(tile_size).describe("Tile size of winograd. E.g. 2 "
"for F(2x2, 3x3) and 4 for F(4x4, "
"3x3)");
DMLC_DECLARE_FIELD(kernel_layout)
.set_default("OIHW")
.describe(
"Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, "
"height, and width"
"dimensions respectively.");
}

static const constexpr int kWeight = 0;
48 changes: 48 additions & 0 deletions nnvm/python/nnvm/top/nn.py
@@ -204,6 +204,54 @@ def schedule_contrib_conv2d_NCHWc(attrs, outs, target):

reg.register_pattern("_contrib_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("_contrib_conv2d_NCHWc_winograd_weight_transform")
def compute_contrib_conv2d_NCHWc_winograd_weight_transform(attrs, inputs, _):
return topi.nn.conv2d_NCHWc_winograd_weight_transform(
inputs[0], attrs.get_int('tile_size'), attrs.get_string("kernel_layout"))

@reg.register_schedule("_contrib_conv2d_NCHWc_winograd_weight_transform")
def schedule_contrib_conv2d_NCHWc_winograd_weight_transform(attrs, outs, target):
with tvm.target.create(target):
return topi.generic.schedule_conv2d_NCHWc_winograd_weight_transform(outs)

reg.register_pattern(
"_contrib_conv2d_NCHWc_winograd_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE
)

@reg.register_compute("_contrib_conv2d_NCHWc_winograd_without_weight_transform")
def compute_contrib_conv2d_NCHWc_winograd_without_weight_transform(attrs, inputs, _):
"""Compute definition of conv2d NCHWc"""
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout")
out_dtype = attrs.get_string("out_dtype")
tile_size = attrs.get_int("tile_size")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert dilation == (1, 1), "Dilation is not supported yet"
assert groups == 1, "Arbitrary group numbers are not supported"

# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_NCHWc_winograd_without_weight_transform(
inputs[0], inputs[1], strides, padding, dilation, layout, out_layout,
out_dtype, tile_size)

if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.add(out, bias)
return out

@reg.register_schedule("_contrib_conv2d_NCHWc_winograd_without_weight_transform")
def schedule_contrib_conv2d_NCHWc_winograd_without_weight_transform(attrs, outs, target):
with tvm.target.create(target):
return topi.generic.schedule_conv2d_NCHWc_winograd_without_weight_transform(outs)

reg.register_pattern("_contrib_conv2d_NCHWc_winograd_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("_contrib_conv2d_winograd_weight_transform")
def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, _):
78 changes: 77 additions & 1 deletion nnvm/src/top/nn/convolution.cc
@@ -414,6 +414,83 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)

DMLC_REGISTER_PARAMETER(WinogradConv2DParam);

NNVM_REGISTER_OP(_contrib_conv2d_NCHWc_winograd_weight_transform)
.describe(
R"code(Weight transformation of winograd fast convolution algorithm.
Separate this into another nnvm symbol in order to enable Precompute Pass to compute the
weight transformation in advance.
- **weight**: (Packed weight matrix)
)code" NNVM_ADD_FILELINE)
.add_argument("weight", "6D Tensor", "Packed weight tensor.")
.add_arguments(WinogradWeightTransformParam::__FIELDS__())
.set_attr_parser(ParamParser<WinogradWeightTransformParam>)
.set_attr<FGetAttrDict>("FGetAttrDict",
ParamGetAttrDict<WinogradWeightTransformParam>)
.set_attr<FInferShape>(
"FInferShape",
[](const nnvm::NodeAttrs &attrs, std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
const auto &param =
nnvm::get<WinogradWeightTransformParam>(attrs.parsed);
const TShape &wshape = (*in_shape)[0];

CHECK_EQ(wshape.ndim(), 6)
<< "Packed Weight should be a 6 dimensional tensor";

// Input kernel layout is essentially COO, CII, KH, KW, CIII, COOO
// Transformed kernel layout is COO, CII, CIII, EPS, NU, COOO,
// where (EPS, NU) = (tile_size + KH - 1, tile_size + KW - 1).
TShape oshape({wshape[0], wshape[1], wshape[4],
param.tile_size + wshape[2] - 1,
param.tile_size + wshape[3] - 1, wshape[5]});
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true;
})
.set_attr<FCorrectLayout>("FCorrectLayout",
[](const NodeAttrs &attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
const auto &param =
nnvm::get<WinogradWeightTransformParam>(
attrs.parsed);
Layout kernel_layout(param.kernel_layout);
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, kernel_layout);
NNVM_ASSIGN_LAYOUT(*olayouts, 0, kernel_layout);
return true;
})
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(5);
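
As a quick check of the FInferShape rule above, consider an illustrative layer whose packed kernel is (COO, CII, KH, KW, CIII, COOO) = (16, 16, 3, 3, 16, 16) with tile_size = 4 (the concrete numbers are hypothetical, not taken from this commit):

```python
# Mirrors the TShape constructed in the FInferShape lambda above.
coo, cii, kh, kw, ciii, cooo = 16, 16, 3, 3, 16, 16
tile_size = 4
oshape = (coo, cii, ciii, tile_size + kh - 1, tile_size + kw - 1, cooo)
assert oshape == (16, 16, 16, 6, 6, 16)
```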

NNVM_REGISTER_OP(_contrib_conv2d_NCHWc_winograd_without_weight_transform)
.describe(R"code(Compute conv2d with winograd algorithm.
- **data**: Input is 5 array of shape (batch_size, in_channel_outer, height, width, in_channel_inner)
- **weight**: Any shape
We do not check shape for this input tensor.
- **bias**: (channels,)
- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
)code" NNVM_ADD_FILELINE)
.add_argument("data", "5D Tensor", "Input data.")
.add_argument("weight", "6D Tensor", "Transformed weight tensor.")
.add_argument("bias", "1D Tensor", "Bias parameter.")
.add_arguments(WinogradConv2DParam::__FIELDS__())
.set_attr_parser(ParamParser<WinogradConv2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict",
ParamGetAttrDict<WinogradConv2DParam>)
.set_attr<FListInputNames>("FListInputNames",
UseBiasListInputNames<WinogradConv2DParam>)
.set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape)
.set_attr<FInferType>("FInferType", Conv2DInferType<WinogradConv2DParam>)
.set_attr<FCorrectLayout>("FCorrectLayout",
Conv2DCorrectLayout<WinogradConv2DParam>)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<WinogradConv2DParam>)
.set_support_level(5);
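
A hedged sketch of how these two ops are meant to pair up at the graph level, mirroring how the existing NCHW Winograd pair is wired from alter_op_layout. The `sym.contrib.*` names are inferred from the `_contrib_` registrations above; the shapes and layout strings are illustrative, not taken from this commit:

```python
import nnvm.symbol as sym

kernel = sym.Variable("kernel")  # packed (COO, CII, KH, KW, CIII, COOO) weight

# Folded away at compile time by the precompute pass:
transformed = sym.contrib.conv2d_NCHWc_winograd_weight_transform(
    kernel, tile_size=4, kernel_layout="OIHW16i16o")

# The transformed kernel is then consumed at runtime by
# sym.contrib.conv2d_NCHWc_winograd_without_weight_transform(data, transformed, ...),
# with the usual conv2d attributes plus tile_size.
```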

NNVM_REGISTER_OP(_conv2d_grad)
.describe(R"code(2D convolution grad.
@@ -441,7 +518,6 @@ NNVM_REGISTER_OP(_conv2d_grad)
.set_attr<FInferType>("FInferType", ElemwiseType<3, -1>)
.set_attr<TIsBackward>("TIsBackward", true);


DMLC_REGISTER_PARAMETER(Conv2DTransposeParam);

inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
57 changes: 4 additions & 53 deletions topi/python/topi/arm_cpu/conv2d.py
@@ -4,14 +4,13 @@

import warnings

import numpy as np

import tvm
from tvm import autotvm

from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform
from ..util import traverse_inline, get_const_tuple, const_matrix
from ..util import traverse_inline, get_const_tuple
from ..nn import dilate, pad, conv2d, conv2d_alter_layout, conv2d_winograd_without_weight_transform
from ..nn.winograd_util import winograd_transform_matrices
from ..nn.util import get_const_int, get_pad_tuple

@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
@@ -304,53 +303,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
assert layout == 'NCHW'
assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

if tile_size == 4:
G_data = np.array([
[1 / 4.0, 0, 0],
[-1 / 6.0, -1 / 6.0, -1 / 6.0],
[-1 / 6.0, 1 / 6.0, -1 / 6.0],
[1 / 24.0, 1 / 12.0, 1 / 6.0],
[1 / 24.0, -1 / 12.0, 1 / 6.0],
[0, 0, 1]], dtype=np.float32)

B_data = np.array([
[4, 0, 0, 0, 0, 0],
[0, -4, 4, -2, 2, 4],
[-5, -4, -4, -1, -1, 0],
[0, 1, -1, 2, -2, -5],
[1, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]], out_dtype)

A_data = np.array([
[1, 0, 0, 0],
[1, 1, 1, 1],
[1, -1, 1, -1],
[1, 2, 4, 8],
[1, -2, 4, -8],
[0, 0, 0, 1]], out_dtype)
elif tile_size == 2:
G_data = np.array([
[1, 0, 0],
[1.0/2, 1.0/2, 1.0/2],
[1.0/2, -1.0/2, 1.0/2],
[0, 0, 1]], np.float32)

B_data = np.array([
[1, 0, 0, 0],
[0, 1, -1, 1],
[-1, 1, 1, 0],
[0, 0, 0, -1]], out_dtype)

A_data = np.array([
[1, 0],
[1, 1],
[1, -1],
[0, -1]], out_dtype)
else:
raise ValueError("Unsupported tile size for winograd: " + str(tile_size))

m = A_data.shape[1]
A, B, G = winograd_transform_matrices(tile_size, out_dtype)
m = tile_size
r = 3
alpha = m + r - 1
K = CO
@@ -377,15 +331,13 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
if pre_computed:
U = kernel
else:
G = const_matrix(G_data, 'G')
r_kh = tvm.reduce_axis((0, KH), 'r_kh')
r_kw = tvm.reduce_axis((0, KW), 'r_kw')
U = tvm.compute((alpha, alpha, K // VK, C, VK), lambda eps, nu, k, c, kk:
tvm.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) *
G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U')

# transform image
B = const_matrix(B_data, 'B')
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
V = tvm.compute((alpha, alpha, P // VP, C, VP), lambda eps, nu, b, c, bb:
@@ -399,7 +351,6 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
V[eps][nu][b // VP][c][b % VP], axis=c), name='M')

# inverse transform
A = const_matrix(A_data, 'A')
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
Y = tvm.compute((K, P, m, m), lambda k, b, vh, vw:
59 changes: 4 additions & 55 deletions topi/python/topi/cuda/conv2d_winograd.py
@@ -1,16 +1,13 @@
# pylint: disable=invalid-name,unused-variable,unused-argument
"""Winograd template for cuda backend"""

import numpy as np

import tvm
from tvm import autotvm

from .. import nn
from ..nn import conv2d, conv2d_winograd_without_weight_transform
from ..util import get_const_int, get_const_tuple, const_matrix, traverse_inline
from ..util import get_const_int, get_const_tuple, traverse_inline
from ..generic import schedule_conv2d_winograd_without_weight_transform

from ..nn.winograd_util import winograd_transform_matrices

def _infer_tile_size(data, kernel):
N, CI, H, W = get_const_tuple(data.shape)
@@ -48,53 +45,8 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
_, _, CI, CO = get_const_tuple(kernel.shape)

data_pad = nn.pad(data, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad")

if tile_size == 4:
G_data = np.array([
[1 / 4.0, 0, 0],
[-1 / 6.0, -1 / 6.0, -1 / 6.0],
[-1 / 6.0, 1 / 6.0, -1 / 6.0],
[1 / 24.0, 1 / 12.0, 1 / 6.0],
[1 / 24.0, -1 / 12.0, 1 / 6.0],
[0, 0, 1]], dtype=np.float32)

B_data = np.array([
[4, 0, 0, 0, 0, 0],
[0, -4, 4, -2, 2, 4],
[-5, -4, -4, -1, -1, 0],
[0, 1, -1, 2, -2, -5],
[1, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]], out_dtype)

A_data = np.array([
[1, 0, 0, 0],
[1, 1, 1, 1],
[1, -1, 1, -1],
[1, 2, 4, 8],
[1, -2, 4, -8],
[0, 0, 0, 1]], out_dtype)
elif tile_size == 2:
G_data = np.array([
[1, 0, 0],
[1.0/2, 1.0/2, 1.0/2],
[1.0/2, -1.0/2, 1.0/2],
[0, 0, 1]], np.float32)

B_data = np.array([
[1, 0, 0, 0],
[0, 1, -1, 1],
[-1, 1, 1, 0],
[0, 0, 0, -1]], out_dtype)

A_data = np.array([
[1, 0],
[1, 1],
[1, -1],
[0, -1]], out_dtype)
else:
raise ValueError("Unsupported tile size for winograd: " + str(tile_size))

m = A_data.shape[1]
A, B, G = winograd_transform_matrices(tile_size, out_dtype)
m = tile_size
r = 3
alpha = m + r - 1
H = (H + 2 * HPAD - KH) // HSTR + 1
@@ -104,7 +56,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty

# transform kernel
if not pre_computed:
G = const_matrix(G_data, 'G')
r_kh = tvm.reduce_axis((0, KH), name='r_kh')
r_kw = tvm.reduce_axis((0, KW), name='r_kw')
kernel_pack = tvm.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co:
@@ -120,7 +71,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
[p % nW * m + nu], name='d')

# transform data
B = const_matrix(B_data)
r_a = tvm.reduce_axis((0, alpha), 'r_a')
r_b = tvm.reduce_axis((0, alpha), 'r_a')
data_pack = tvm.compute((alpha, alpha, CI, P), lambda eps, nu, ci, p:
@@ -135,7 +85,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
axis=[ci]), name='bgemm')

# inverse transform
A = const_matrix(A_data)
r_a = tvm.reduce_axis((0, alpha), 'r_a')
r_b = tvm.reduce_axis((0, alpha), 'r_a')
inverse = tvm.compute((CO, P, m, m), lambda co, p, vh, vw:
9 changes: 9 additions & 0 deletions topi/python/topi/generic/nn.py
@@ -120,6 +120,15 @@ def schedule_conv2d_winograd_without_weight_transform(outs):
"""
return _default_schedule(outs, False)

@tvm.target.generic_func
def schedule_conv2d_NCHWc_winograd_weight_transform(outs):
"""Schedule for conv2d_NCHWc_winograd_weight_transform; falls back to the default schedule."""
return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_conv2d_NCHWc_winograd_without_weight_transform(outs):
"""Schedule for conv2d_NCHWc_winograd_without_weight_transform; falls back to the default schedule."""
return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_conv2d_transpose_nchw(outs):
