Skip to content

Commit

Permalink
[X86] [NNVM] [TOPI] [WIP] Implement NCHWc Winograd convolutions
Browse files Browse the repository at this point in the history
This is the implementation alluded to in
https://discuss.tvm.ai/t/improved-direct-winograd-nchwc-cpu-implementation-with-resnet-50-results/

It is a pretty standard Winograd implementation, modified for NCHWc
layout. It achieves reasonable speedups (up to 2x vs current
implementation) on a number of ResNet 3x3 layers on SKL and AVX.

TODO: Parallelization
TODO: Benchmarking suite results on full ResNet suite.
TODO: Demonstration in `tune_nnvm_x86.py`
  • Loading branch information
ajtulloch committed Nov 15, 2018
1 parent 1f2c815 commit a9bdc2e
Show file tree
Hide file tree
Showing 7 changed files with 1,013 additions and 27 deletions.
14 changes: 11 additions & 3 deletions nnvm/include/nnvm/top/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,18 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> {

struct WinogradWeightTransformParam : public dmlc::Parameter<WinogradWeightTransformParam> {
int tile_size;

std::string kernel_layout;
DMLC_DECLARE_PARAMETER(WinogradWeightTransformParam) {
DMLC_DECLARE_FIELD(tile_size)
.describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
DMLC_DECLARE_FIELD(tile_size).describe("Tile size of winograd. E.g. 2 "
"for F(2x2, 3x3) and 4 for F(4x4, "
"3x3)");
DMLC_DECLARE_FIELD(kernel_layout)
.set_default("OIHW")
.describe(
"Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, "
"height, and width"
"dimensions respectively.");
}

static const constexpr int kWeight = 0;
Expand Down
45 changes: 45 additions & 0 deletions nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,51 @@ def schedule_contrib_conv2d_NCHWc(attrs, outs, target):

reg.register_pattern("_contrib_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("_contrib_conv2d_NCHWc_winograd_weight_transform")
def compute_contrib_conv2d_NCHWc_winograd_weight_transform(attrs, inputs, _):
return topi.nn.conv2d_NCHWc_winograd_weight_transform(
inputs[0], attrs.get_int('tile_size'), attrs.get_string("kernel_layout"))

@reg.register_schedule("_contrib_conv2d_NCHWc_winograd_weight_transform")
def schedule_contrib_conv2d_NCHWc_winograd_weight_transform(attrs, outs, target):
with tvm.target.create(target):
return topi.generic.schedule_conv2d_NCHWc_winograd_weight_transform(outs)

reg.register_pattern("_contrib_conv2d_NCHWc_winograd_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("_contrib_conv2d_NCHWc_winograd_without_weight_transform")
def compute_contrib_conv2d_NCHWc_winograd_without_weight_transform(attrs, inputs, _):
"""Compute definition of conv2d NCHWc"""
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout")
out_dtype = attrs.get_string("out_dtype")
tile_size = attrs.get_int("tile_size")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert dilation == (1, 1), "Do not support dilate now"
assert groups == 1, "Do not supoort arbitrary group number"

# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_NCHWc_winograd_without_weight_transform(
inputs[0], inputs[1], strides, padding, dilation, layout, out_layout,
out_dtype, tile_size)

if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.add(out, bias)
return out

@reg.register_schedule("_contrib_conv2d_NCHWc_winograd_without_weight_transform")
def schedule_contrib_conv2d_NCHWc_winograd_without_weight_transform(attrs, outs, target):
with tvm.target.create(target):
return topi.generic.schedule_conv2d_NCHWc_winograd_without_weight_transform(outs)

reg.register_pattern("_contrib_conv2d_NCHWc_winograd_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("_contrib_conv2d_winograd_weight_transform")
def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, _):
Expand Down
78 changes: 77 additions & 1 deletion nnvm/src/top/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,83 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)

DMLC_REGISTER_PARAMETER(WinogradConv2DParam);

NNVM_REGISTER_OP(_contrib_conv2d_NCHWc_winograd_weight_transform)
.describe(
R"code(Weight transformation of winograd fast convolution algorithm.
Separate this into another nnvm symbol in order to enable Precompute Pass to compute the
weight transformation in advance.
- **weight**: (Packed weight matrix)
)code" NNVM_ADD_FILELINE)
.add_argument("weight", "6D Tensor", "Packed weight tensor.")
.add_arguments(WinogradWeightTransformParam::__FIELDS__())
.set_attr_parser(ParamParser<WinogradWeightTransformParam>)
.set_attr<FGetAttrDict>("FGetAttrDict",
ParamGetAttrDict<WinogradWeightTransformParam>)
.set_attr<FInferShape>(
"FInferShape",
[](const nnvm::NodeAttrs &attrs, std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
const auto &param =
nnvm::get<WinogradWeightTransformParam>(attrs.parsed);
const TShape &wshape = (*in_shape)[0];

CHECK_EQ(wshape.ndim(), 6)
<< "Packed Weight should be a 6 dimensional tensor";

// Input kernel layout is essentially COO, CII, KH, KW, CIII, COOO
// Transformed kernel layout is COO, CII, CIII, KH, KW, COOO
TShape oshape({wshape[0], wshape[1], wshape[4],
param.tile_size + wshape[2] - 1,
param.tile_size + wshape[3] - 1, wshape[5]});
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true;
})
.set_attr<FCorrectLayout>("FCorrectLayout",
[](const NodeAttrs &attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
const auto &param =
nnvm::get<WinogradWeightTransformParam>(
attrs.parsed);
Layout kernel_layout(param.kernel_layout);
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, kernel_layout);
NNVM_ASSIGN_LAYOUT(*olayouts, 0, kernel_layout);
return true;
})
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(5);

NNVM_REGISTER_OP(_contrib_conv2d_NCHWc_winograd_without_weight_transform)
.describe(R"code(Compute conv2d with winograd algorithm.
- **data**: Input is 5 array of shape (batch_size, in_channel_outer, height, width, in_channel_inner)
- **weight**: Any shape
We do not check shape for this input tensor.
- **bias**: (channels,)
- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
)code" NNVM_ADD_FILELINE)
.add_argument("data", "5D Tensor", "Input data.")
.add_argument("weight", "6D Tensor", "Transformed weight tensor.")
.add_argument("bias", "1D Tensor", "Bias parameter.")
.add_arguments(WinogradConv2DParam::__FIELDS__())
.set_attr_parser(ParamParser<WinogradConv2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict",
ParamGetAttrDict<WinogradConv2DParam>)
.set_attr<FListInputNames>("FListInputNames",
UseBiasListInputNames<WinogradConv2DParam>)
.set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape)
.set_attr<FInferType>("FInferType", Conv2DInferType<WinogradConv2DParam>)
.set_attr<FCorrectLayout>("FCorrectLayout",
Conv2DCorrectLayout<WinogradConv2DParam>)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<WinogradConv2DParam>)
.set_support_level(5);

NNVM_REGISTER_OP(_conv2d_grad)
.describe(R"code(2D convolution grad.
Expand Down Expand Up @@ -441,7 +518,6 @@ NNVM_REGISTER_OP(_conv2d_grad)
.set_attr<FInferType>("FInferType", ElemwiseType<3, -1>)
.set_attr<TIsBackward>("TIsBackward", true);


DMLC_REGISTER_PARAMETER(Conv2DTransposeParam);

inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
Expand Down
9 changes: 9 additions & 0 deletions topi/python/topi/generic/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,15 @@ def schedule_conv2d_winograd_without_weight_transform(outs):
"""
return _default_schedule(outs, False)

@tvm.target.generic_func
def schedule_conv2d_NCHWc_winograd_weight_transform(outs):
return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_conv2d_NCHWc_winograd_without_weight_transform(outs):
return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_conv2d_transpose_nchw(outs):
Expand Down
122 changes: 122 additions & 0 deletions topi/python/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,3 +480,125 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp
xx * stride_w + rx * dilation_w].astype(out_dtype) *
Filter[ff, rc, ry, rx].astype(out_dtype),
axis=[rc, ry, rx]), tag="conv2d_nchw")


@tvm.target.generic_func
def conv2d_NCHWc_winograd_weight_transform(kernel, tile_size, kernel_layout):
"""Weight transformation for winograd NCHWc
Parameters
----------
kernel: Tensor
6-D with shape
[num_filter_chunk, in_channel_chunk, kernel_height, kernel_width,
in_channel_block, num_filter_block]
tile_size: int
Tile size of winograd transform. e.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)
Returns
-------
output : Tensor
6-D with shape
[num_filter_chunk, in_channel_chunk, in_channel_block, alpha, alpha,
num_filter_block]
"""
COO, CII, KH, KW, CIII, VC = get_const_tuple(kernel.shape)

def get_G(m):
"""
Return the G transform matrix for the tile size `m` as a
`tvm.Expr`.
"""
assert m in (2, 4, 6)
if m == 4:
G_data = np.array(
[
[1 / 4.0, 0, 0],
[-1 / 6.0, -1 / 6.0, -1 / 6.0],
[-1 / 6.0, 1 / 6.0, -1 / 6.0],
[1 / 24.0, 1 / 12.0, 1 / 6.0],
[1 / 24.0, -1 / 12.0, 1 / 6.0],
[0, 0, 1],
],
dtype=np.float32,
)
elif m == 6:
G_data = np.array(
[
[1, 0, 0],
[-2 / 9, -2 / 9, -2 / 9],
[-2 / 9, 2 / 9, -2 / 9],
[1 / 90, 1 / 45, 2 / 45],
[1 / 90, -1 / 45, 2 / 45],
[1 / 45, 1 / 90, 1 / 180],
[1 / 45, -1 / 90, 1 / 180],
[0, 0, 1],
],
dtype=np.float32,
)
elif m == 2:
G_data = np.array(
[
[1, 0, 0],
[1.0 / 2, 1.0 / 2, 1.0 / 2],
[1.0 / 2, -1.0 / 2, 1.0 / 2],
[0, 0, 1],
],
np.float32,
)
return const_matrix(G_data, "G")

G = get_G(tile_size)
# transform kernel

r_kh = tvm.reduce_axis((0, KH), "r_kh")
r_kw = tvm.reduce_axis((0, KW), "r_kw")
alpha = tile_size + 3 - 1
U = tvm.compute(
(COO, CII, CIII, alpha, alpha, VC),
lambda coo, cii, ciii, eps, nu, vc: tvm.sum(
kernel[coo][cii][r_kh][r_kw][ciii][vc]
* G[eps][r_kh]
* G[nu][r_kw],
axis=[r_kh, r_kw],
),
name="U",
)
return U


@tvm.target.generic_func
def conv2d_NCHWc_winograd_without_weight_transform(
input, filter, strides, padding, dilation, layout, out_layout, out_dtype, tile_size):
"""Compute convolution in winograd algorithm. The filter is supposed to be transformed
in advance.
Parameters
----------
input : tvm.Tensor
5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
filter : tvm.Tensor
6-D with shape
[num_filter_chunk, in_channel_chunk, in_channel_block, alpha, alpha, num_filter_block]
strides : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size, or ['VALID', 'SAME']
dilation: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
layout : str
Input data layout
out_layout : str
Output data layout
out_dtype : str
output data type
tile_size: int
Tile size of winograd transform. e.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)
Returns
-------
output : tvm.Tensor
5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
raise ValueError("missing register for topi.nn.conv2d_NCHWc_winograd_without_weight_transform")
Loading

0 comments on commit a9bdc2e

Please sign in to comment.