[X86] [NNVM] [TOPI] Implement NCHWc Winograd convolutions #2111

Closed · wants to merge 1 commit
14 changes: 11 additions & 3 deletions nnvm/include/nnvm/top/nn.h
@@ -174,10 +174,18 @@ struct Conv2DParam : public dmlc::Parameter<Conv2DParam> {

struct WinogradWeightTransformParam : public dmlc::Parameter<WinogradWeightTransformParam> {
int tile_size;

std::string kernel_layout;
DMLC_DECLARE_PARAMETER(WinogradWeightTransformParam) {
DMLC_DECLARE_FIELD(tile_size)
.describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
DMLC_DECLARE_FIELD(tile_size).describe("Tile size of winograd. E.g. 2 "
"for F(2x2, 3x3) and 4 for F(4x4, "
"3x3)");
DMLC_DECLARE_FIELD(kernel_layout)
.set_default("OIHW")
.describe(
"Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc. "
"'O', 'I', 'H', 'W' stand for num_filter, input_channel, "
"height, and width dimensions respectively.");
}

static const constexpr int kWeight = 0;
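
For context on `tile_size`: Winograd F(m×m, r×r) produces an m×m output tile from an (m+r-1)×(m+r-1) transformed tile, so the transformed kernel's spatial extent is alpha = m + r - 1. A minimal sketch of the arithmetic (the helper name is hypothetical, not part of this diff):

def winograd_alpha(tile_size, kernel_size=3):
    # alpha = m + r - 1: spatial extent of the transformed Winograd tile
    return tile_size + kernel_size - 1

assert winograd_alpha(2) == 4  # F(2x2, 3x3) uses 4x4 transformed tiles
assert winograd_alpha(4) == 6  # F(4x4, 3x3) uses 6x6 transformed tiles
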
48 changes: 48 additions & 0 deletions nnvm/python/nnvm/top/nn.py
@@ -204,6 +204,54 @@ def schedule_contrib_conv2d_NCHWc(attrs, outs, target):

reg.register_pattern("_contrib_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("_contrib_conv2d_NCHWc_winograd_weight_transform")
def compute_contrib_conv2d_NCHWc_winograd_weight_transform(attrs, inputs, _):
return topi.nn.conv2d_NCHWc_winograd_weight_transform(
inputs[0], attrs.get_int('tile_size'), attrs.get_string("kernel_layout"))

@reg.register_schedule("_contrib_conv2d_NCHWc_winograd_weight_transform")
def schedule_contrib_conv2d_NCHWc_winograd_weight_transform(attrs, outs, target):
with tvm.target.create(target):
return topi.generic.schedule_conv2d_NCHWc_winograd_weight_transform(outs)

reg.register_pattern(
"_contrib_conv2d_NCHWc_winograd_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE
)

@reg.register_compute("_contrib_conv2d_NCHWc_winograd_without_weight_transform")
def compute_contrib_conv2d_NCHWc_winograd_without_weight_transform(attrs, inputs, _):
"""Compute definition of conv2d NCHWc"""
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout")
out_dtype = attrs.get_string("out_dtype")
tile_size = attrs.get_int("tile_size")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert dilation == (1, 1), "Do not support dilation for now"
assert groups == 1, "Do not support arbitrary group number"

# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_NCHWc_winograd_without_weight_transform(
inputs[0], inputs[1], strides, padding, dilation, layout, out_layout,
out_dtype, tile_size)

if attrs.get_bool("use_bias"):
bias = inputs[2]
bias = topi.expand_dims(bias, axis=1, num_newaxis=2)
out = topi.add(out, bias)
return out

@reg.register_schedule("_contrib_conv2d_NCHWc_winograd_without_weight_transform")
def schedule_contrib_conv2d_NCHWc_winograd_without_weight_transform(attrs, outs, target):
with tvm.target.create(target):
return topi.generic.schedule_conv2d_NCHWc_winograd_without_weight_transform(outs)

reg.register_pattern("_contrib_conv2d_NCHWc_winograd_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)
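
A rough sketch of how the two symbols are meant to compose in a graph. The `sym.contrib.*` entry points and attribute spellings here are assumptions inferred from the registration names above, not verified API:

import nnvm.symbol as sym

# Hypothetical graph construction; entry-point and attribute names are
# assumptions based on the op registrations above.
data = sym.Variable("data")      # NCHWc input, e.g. (1, 2, 56, 56, 8)
weight = sym.Variable("weight")  # packed 6D kernel, e.g. OIHW8i8o

# Precomputable part: folded away by the precompute pass at build time.
tw = sym.contrib.conv2d_NCHWc_winograd_weight_transform(
    weight, tile_size=4, kernel_layout="OIHW8i8o")

# Per-inference part: consumes the pre-transformed weight.
out = sym.contrib.conv2d_NCHWc_winograd_without_weight_transform(
    data, tw, channels=16, kernel_size=(3, 3), strides=(1, 1),
    padding=(1, 1), layout="NCHW8c", out_layout="NCHW8c",
    tile_size=4, use_bias=False)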

@reg.register_compute("_contrib_conv2d_winograd_weight_transform")
def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, _):
Expand Down
78 changes: 77 additions & 1 deletion nnvm/src/top/nn/convolution.cc
@@ -414,6 +414,83 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)

DMLC_REGISTER_PARAMETER(WinogradConv2DParam);

NNVM_REGISTER_OP(_contrib_conv2d_NCHWc_winograd_weight_transform)
[Review comment] @merrymercy (Member), Nov 25, 2018:
I think it is better to support the NCHWc layout in the original symbol rather than creating a new symbol, as discussed by @masahi. Registering and maintaining new symbols is harder than adding an "if" branch in the old symbol.

.describe(
R"code(Weight transformation of winograd fast convolution algorithm.
Separate this into another nnvm symbol in order to enable Precompute Pass to compute the
weight transformation in advance.
- **weight**: (Packed weight matrix)
)code" NNVM_ADD_FILELINE)
.add_argument("weight", "6D Tensor", "Packed weight tensor.")
.add_arguments(WinogradWeightTransformParam::__FIELDS__())
.set_attr_parser(ParamParser<WinogradWeightTransformParam>)
.set_attr<FGetAttrDict>("FGetAttrDict",
ParamGetAttrDict<WinogradWeightTransformParam>)
.set_attr<FInferShape>(
"FInferShape",
[](const nnvm::NodeAttrs &attrs, std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
const auto &param =
nnvm::get<WinogradWeightTransformParam>(attrs.parsed);
const TShape &wshape = (*in_shape)[0];

CHECK_EQ(wshape.ndim(), 6)
<< "Packed weight should be a 6-dimensional tensor";

// Input kernel layout is COO, CII, KH, KW, CIII, COOO.
// The transformed kernel layout is COO, CII, CIII, ALPHA_H, ALPHA_W, COOO,
// where ALPHA = tile_size + KH - 1.
TShape oshape({wshape[0], wshape[1], wshape[4],
param.tile_size + wshape[2] - 1,
param.tile_size + wshape[3] - 1, wshape[5]});
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true;
})
.set_attr<FCorrectLayout>("FCorrectLayout",
[](const NodeAttrs &attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
const auto &param =
nnvm::get<WinogradWeightTransformParam>(
attrs.parsed);
Layout kernel_layout(param.kernel_layout);
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, kernel_layout);
NNVM_ASSIGN_LAYOUT(*olayouts, 0, kernel_layout);
return true;
})
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(5);
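
To make the shape rule above concrete, a small worked example in plain Python (values are illustrative only):

# Worked example of the FInferShape rule above.
tile_size = 4
wshape = (2, 2, 3, 3, 8, 8)          # packed kernel: COO, CII, KH, KW, CIII, COOO
alpha_h = tile_size + wshape[2] - 1  # 4 + 3 - 1 = 6
alpha_w = tile_size + wshape[3] - 1  # 6
oshape = (wshape[0], wshape[1], wshape[4], alpha_h, alpha_w, wshape[5])
assert oshape == (2, 2, 8, 6, 6, 8)  # COO, CII, CIII, ALPHA_H, ALPHA_W, COOO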

NNVM_REGISTER_OP(_contrib_conv2d_NCHWc_winograd_without_weight_transform)
.describe(R"code(Compute conv2d with winograd algorithm.
- **data**: Input is 5 array of shape (batch_size, in_channel_outer, height, width, in_channel_inner)
- **weight**: Any shape
We do not check shape for this input tensor.
- **bias**: (channels,)
- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
)code" NNVM_ADD_FILELINE)
.add_argument("data", "5D Tensor", "Input data.")
.add_argument("weight", "6D Tensor", "Transformed weight tensor.")
.add_argument("bias", "1D Tensor", "Bias parameter.")
.add_arguments(WinogradConv2DParam::__FIELDS__())
.set_attr_parser(ParamParser<WinogradConv2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict",
ParamGetAttrDict<WinogradConv2DParam>)
.set_attr<FListInputNames>("FListInputNames",
UseBiasListInputNames<WinogradConv2DParam>)
.set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape)
.set_attr<FInferType>("FInferType", Conv2DInferType<WinogradConv2DParam>)
.set_attr<FCorrectLayout>("FCorrectLayout",
Conv2DCorrectLayout<WinogradConv2DParam>)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<WinogradConv2DParam>)
.set_support_level(5);

NNVM_REGISTER_OP(_conv2d_grad)
.describe(R"code(2D convolution grad.
@@ -441,7 +518,6 @@ NNVM_REGISTER_OP(_conv2d_grad)
.set_attr<FInferType>("FInferType", ElemwiseType<3, -1>)
.set_attr<TIsBackward>("TIsBackward", true);


DMLC_REGISTER_PARAMETER(Conv2DTransposeParam);

inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
57 changes: 4 additions & 53 deletions topi/python/topi/arm_cpu/conv2d.py
@@ -4,14 +4,13 @@

import warnings

import numpy as np

import tvm
from tvm import autotvm

from ..generic import schedule_conv2d_nchw, schedule_conv2d_winograd_without_weight_transform
from ..util import traverse_inline, get_const_tuple, const_matrix
from ..util import traverse_inline, get_const_tuple
from ..nn import dilate, pad, conv2d, conv2d_alter_layout, conv2d_winograd_without_weight_transform
from ..nn.winograd_util import winograd_transform_matrices
from ..nn.util import get_const_int, get_pad_tuple

@autotvm.register_topi_compute(conv2d, 'arm_cpu', ['direct'])
@@ -304,53 +303,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
assert layout == 'NCHW'
assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")

if tile_size == 4:
G_data = np.array([
[1 / 4.0, 0, 0],
[-1 / 6.0, -1 / 6.0, -1 / 6.0],
[-1 / 6.0, 1 / 6.0, -1 / 6.0],
[1 / 24.0, 1 / 12.0, 1 / 6.0],
[1 / 24.0, -1 / 12.0, 1 / 6.0],
[0, 0, 1]], dtype=np.float32)

B_data = np.array([
[4, 0, 0, 0, 0, 0],
[0, -4, 4, -2, 2, 4],
[-5, -4, -4, -1, -1, 0],
[0, 1, -1, 2, -2, -5],
[1, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]], out_dtype)

A_data = np.array([
[1, 0, 0, 0],
[1, 1, 1, 1],
[1, -1, 1, -1],
[1, 2, 4, 8],
[1, -2, 4, -8],
[0, 0, 0, 1]], out_dtype)
elif tile_size == 2:
G_data = np.array([
[1, 0, 0],
[1.0/2, 1.0/2, 1.0/2],
[1.0/2, -1.0/2, 1.0/2],
[0, 0, 1]], np.float32)

B_data = np.array([
[1, 0, 0, 0],
[0, 1, -1, 1],
[-1, 1, 1, 0],
[0, 0, 0, -1]], out_dtype)

A_data = np.array([
[1, 0],
[1, 1],
[1, -1],
[0, -1]], out_dtype)
else:
raise ValueError("Unsupported tile size for winograd: " + str(tile_size))

m = A_data.shape[1]
A, B, G = winograd_transform_matrices(tile_size, out_dtype)
m = tile_size
r = 3
alpha = m + r - 1
K = CO
@@ -377,15 +331,13 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
if pre_computed:
U = kernel
else:
G = const_matrix(G_data, 'G')
r_kh = tvm.reduce_axis((0, KH), 'r_kh')
r_kw = tvm.reduce_axis((0, KW), 'r_kw')
U = tvm.compute((alpha, alpha, K // VK, C, VK), lambda eps, nu, k, c, kk:
tvm.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) *
G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U')

# transform image
B = const_matrix(B_data, 'B')
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
V = tvm.compute((alpha, alpha, P // VP, C, VP), lambda eps, nu, b, c, bb:
@@ -399,7 +351,6 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
V[eps][nu][b // VP][c][b % VP], axis=c), name='M')

# inverse transform
A = const_matrix(A_data, 'A')
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
Y = tvm.compute((K, P, m, m), lambda k, b, vh, vw:
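
The hard-coded constants deleted above are exactly the F(2x2, 3x3) transform matrices; the refactor moves them behind `winograd_transform_matrices`. As a self-contained consistency check (a sketch in plain NumPy, independent of this diff), those matrices satisfy the Winograd identity Y = A^T[(G g G^T) * (B^T d B)]A used by the compute definitions above:

import numpy as np

# F(2x2, 3x3) transform matrices, as previously hard-coded above.
G = np.array([[1.0, 0.0, 0.0],
              [0.5, 0.5, 0.5],
              [0.5, -0.5, 0.5],
              [0.0, 0.0, 1.0]])
B = np.array([[1, 0, 0, 0],
              [0, 1, -1, 1],
              [-1, 1, 1, 0],
              [0, 0, 0, -1]], dtype=np.float64)
A = np.array([[1, 0],
              [1, 1],
              [1, -1],
              [0, -1]], dtype=np.float64)

rng = np.random.default_rng(0)
g = rng.standard_normal((3, 3))   # 3x3 kernel
d = rng.standard_normal((4, 4))   # 4x4 input tile

# Winograd F(2,3): the elementwise product in the transformed domain equals
# direct 2x2 (valid, correlation-style) convolution of d with g.
Y = A.T @ ((G @ g @ G.T) * (B.T @ d @ B)) @ A
direct = np.array([[np.sum(d[i:i + 3, j:j + 3] * g) for j in range(2)]
                   for i in range(2)])
assert np.allclose(Y, direct)
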
59 changes: 4 additions & 55 deletions topi/python/topi/cuda/conv2d_winograd.py
@@ -1,16 +1,13 @@
# pylint: disable=invalid-name,unused-variable,unused-argument
"""Winograd template for cuda backend"""

import numpy as np

import tvm
from tvm import autotvm

from .. import nn
from ..nn import conv2d, conv2d_winograd_without_weight_transform
from ..util import get_const_int, get_const_tuple, const_matrix, traverse_inline
from ..util import get_const_int, get_const_tuple, traverse_inline
from ..generic import schedule_conv2d_winograd_without_weight_transform

from ..nn.winograd_util import winograd_transform_matrices

def _infer_tile_size(data, kernel):
N, CI, H, W = get_const_tuple(data.shape)
@@ -48,53 +45,8 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
_, _, CI, CO = get_const_tuple(kernel.shape)

data_pad = nn.pad(data, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad")

if tile_size == 4:
G_data = np.array([
[1 / 4.0, 0, 0],
[-1 / 6.0, -1 / 6.0, -1 / 6.0],
[-1 / 6.0, 1 / 6.0, -1 / 6.0],
[1 / 24.0, 1 / 12.0, 1 / 6.0],
[1 / 24.0, -1 / 12.0, 1 / 6.0],
[0, 0, 1]], dtype=np.float32)

B_data = np.array([
[4, 0, 0, 0, 0, 0],
[0, -4, 4, -2, 2, 4],
[-5, -4, -4, -1, -1, 0],
[0, 1, -1, 2, -2, -5],
[1, 1, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]], out_dtype)

A_data = np.array([
[1, 0, 0, 0],
[1, 1, 1, 1],
[1, -1, 1, -1],
[1, 2, 4, 8],
[1, -2, 4, -8],
[0, 0, 0, 1]], out_dtype)
elif tile_size == 2:
G_data = np.array([
[1, 0, 0],
[1.0/2, 1.0/2, 1.0/2],
[1.0/2, -1.0/2, 1.0/2],
[0, 0, 1]], np.float32)

B_data = np.array([
[1, 0, 0, 0],
[0, 1, -1, 1],
[-1, 1, 1, 0],
[0, 0, 0, -1]], out_dtype)

A_data = np.array([
[1, 0],
[1, 1],
[1, -1],
[0, -1]], out_dtype)
else:
raise ValueError("Unsupported tile size for winograd: " + str(tile_size))

m = A_data.shape[1]
A, B, G = winograd_transform_matrices(tile_size, out_dtype)
m = tile_size
r = 3
alpha = m + r - 1
H = (H + 2 * HPAD - KH) // HSTR + 1
@@ -104,7 +56,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty

# transform kernel
if not pre_computed:
G = const_matrix(G_data, 'G')
r_kh = tvm.reduce_axis((0, KH), name='r_kh')
r_kw = tvm.reduce_axis((0, KW), name='r_kw')
kernel_pack = tvm.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co:
@@ -120,7 +71,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
[p % nW * m + nu], name='d')

# transform data
B = const_matrix(B_data)
r_a = tvm.reduce_axis((0, alpha), 'r_a')
r_b = tvm.reduce_axis((0, alpha), 'r_b')
data_pack = tvm.compute((alpha, alpha, CI, P), lambda eps, nu, ci, p:
@@ -135,7 +85,6 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
axis=[ci]), name='bgemm')

# inverse transform
A = const_matrix(A_data)
r_a = tvm.reduce_axis((0, alpha), 'r_a')
r_b = tvm.reduce_axis((0, alpha), 'r_b')
inverse = tvm.compute((CO, P, m, m), lambda co, p, vh, vw:
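
For the data-packing step, each m×m output tile corresponds to one entry along the P axis. A rough sketch of the tile bookkeeping implied by the indexing above (`p % nW * m + nu`), with illustrative values, since the defining lines are collapsed in this diff:

# Tile bookkeeping implied by the data_pack indexing above (illustrative).
N, H, W = 1, 56, 56            # output spatial size after padding/striding
m = 4                          # tile_size
nH, nW = (H + m - 1) // m, (W + m - 1) // m
P = N * nH * nW                # number of tiles: the 'P' axis above
p = 37                         # an arbitrary tile index
n = p // (nH * nW)             # batch element
th = p % (nH * nW) // nW       # tile row
tw = p % nW                    # tile column
# The alpha x alpha input window for tile p starts at (th * m, tw * m)
# in data_pad, matching data_pad[...][p % nW * m + nu] above.
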
9 changes: 9 additions & 0 deletions topi/python/topi/generic/nn.py
@@ -120,6 +120,15 @@ def schedule_conv2d_winograd_without_weight_transform(outs):
"""
return _default_schedule(outs, False)

@tvm.target.generic_func
def schedule_conv2d_NCHWc_winograd_weight_transform(outs):
    """Schedule for conv2d_NCHWc_winograd_weight_transform."""
    return _default_schedule(outs, False)


@tvm.target.generic_func
def schedule_conv2d_NCHWc_winograd_without_weight_transform(outs):
    """Schedule for conv2d_NCHWc_winograd_without_weight_transform."""
    return _default_schedule(outs, False)
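
These generic stubs fall back to the default schedule; a target backend is expected to override them. A hedged sketch of how a CPU backend might register its own implementation (the override body here is a placeholder, not the actual x86 schedule):

# Hedged sketch: overriding the generic stub for a CPU target.
import tvm
from topi import generic

@generic.schedule_conv2d_NCHWc_winograd_weight_transform.register(["cpu"])
def _schedule_weight_transform_cpu(outs):
    # A real backend would parallelize/vectorize here; this placeholder
    # just builds a plain schedule over the output ops.
    return tvm.create_schedule([x.op for x in outs])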


@tvm.target.generic_func
def schedule_conv2d_transpose_nchw(outs):