From a9bdc2eb812805f594b8ca82ad86efe8cdc0474c Mon Sep 17 00:00:00 2001 From: Andrew Tulloch Date: Wed, 14 Nov 2018 15:38:04 -0800 Subject: [PATCH] [X86] [NNVM] [TOPI] [WIP] Implement NCHWc Winograd convolutions This is the implementation alluded to in https://discuss.tvm.ai/t/improved-direct-winograd-nchwc-cpu-implementation-with-resnet-50-results/ It is a pretty standard Winograd implementation, modified for NCHWc layout. It achieves reasonable speedups (up to 2x vs current implementation) on a number of ResNet 3x3 layers on SKL and AVX. TODO: Parallelization TODO: Benchmarking suite results on full ResNet suite. TODO: Demonstration in `tune_nnvm_x86.py` --- nnvm/include/nnvm/top/nn.h | 14 +- nnvm/python/nnvm/top/nn.py | 45 ++ nnvm/src/top/nn/convolution.cc | 78 ++- topi/python/topi/generic/nn.py | 9 + topi/python/topi/nn/conv2d.py | 122 ++++ topi/python/topi/x86/conv2d.py | 570 +++++++++++++++++- .../python/test_topi_conv2d_NCHWc_winograd.py | 202 +++++++ 7 files changed, 1013 insertions(+), 27 deletions(-) create mode 100644 topi/tests/python/test_topi_conv2d_NCHWc_winograd.py diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h index 143a9548f18ab..f7cbe7b76af4d 100644 --- a/nnvm/include/nnvm/top/nn.h +++ b/nnvm/include/nnvm/top/nn.h @@ -174,10 +174,18 @@ struct Conv2DParam : public dmlc::Parameter { struct WinogradWeightTransformParam : public dmlc::Parameter { int tile_size; - + std::string kernel_layout; DMLC_DECLARE_PARAMETER(WinogradWeightTransformParam) { - DMLC_DECLARE_FIELD(tile_size) - .describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); + DMLC_DECLARE_FIELD(tile_size).describe("Tile size of winograd. E.g. 2 " + "for F(2x2, 3x3) and 4 for F(4x4, " + "3x3)"); + DMLC_DECLARE_FIELD(kernel_layout) + .set_default("OIHW") + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." 
+ "'O', 'I', 'H', 'W' stands for num_filter, input_channel, " + "height, and width" + "dimensions respectively."); } static const constexpr int kWeight = 0; diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index 2069a0a5ad50a..d111a6071e562 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -204,6 +204,51 @@ def schedule_contrib_conv2d_NCHWc(attrs, outs, target): reg.register_pattern("_contrib_conv2d_NCHWc", OpPattern.OUT_ELEMWISE_FUSABLE) +@reg.register_compute("_contrib_conv2d_NCHWc_winograd_weight_transform") +def compute_contrib_conv2d_NCHWc_winograd_weight_transform(attrs, inputs, _): + return topi.nn.conv2d_NCHWc_winograd_weight_transform( + inputs[0], attrs.get_int('tile_size'), attrs.get_string("kernel_layout")) + +@reg.register_schedule("_contrib_conv2d_NCHWc_winograd_weight_transform") +def schedule_contrib_conv2d_NCHWc_winograd_weight_transform(attrs, outs, target): + with tvm.target.create(target): + return topi.generic.schedule_conv2d_NCHWc_winograd_weight_transform(outs) + +reg.register_pattern("_contrib_conv2d_NCHWc_winograd_weight_transform", OpPattern.OUT_ELEMWISE_FUSABLE) + +@reg.register_compute("_contrib_conv2d_NCHWc_winograd_without_weight_transform") +def compute_contrib_conv2d_NCHWc_winograd_without_weight_transform(attrs, inputs, _): + """Compute definition of conv2d NCHWc""" + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + groups = attrs.get_int("groups") + layout = attrs.get_string("layout") + out_layout = attrs.get_string("out_layout") + out_dtype = attrs.get_string("out_dtype") + tile_size = attrs.get_int("tile_size") + out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype + assert dilation == (1, 1), "Do not support dilate now" + assert groups == 1, "Do not supoort arbitrary group number" + + # pylint: disable=assignment-from-no-return + out = topi.nn.conv2d_NCHWc_winograd_without_weight_transform( + inputs[0], inputs[1], strides, padding, dilation, layout, out_layout, + out_dtype, tile_size) + + if attrs.get_bool("use_bias"): + bias = inputs[2] + bias = topi.expand_dims(bias, axis=1, num_newaxis=2) + out = topi.add(out, bias) + return out + +@reg.register_schedule("_contrib_conv2d_NCHWc_winograd_without_weight_transform") +def schedule_contrib_conv2d_NCHWc_winograd_without_weight_transform(attrs, outs, target): + with tvm.target.create(target): + return topi.generic.schedule_conv2d_NCHWc_winograd_without_weight_transform(outs) + +reg.register_pattern("_contrib_conv2d_NCHWc_winograd_without_weight_transform", + OpPattern.OUT_ELEMWISE_FUSABLE) @reg.register_compute("_contrib_conv2d_winograd_weight_transform") def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, _): diff --git a/nnvm/src/top/nn/convolution.cc b/nnvm/src/top/nn/convolution.cc index 8139474921175..769f7cc3cacf2 100644 --- a/nnvm/src/top/nn/convolution.cc +++ b/nnvm/src/top/nn/convolution.cc @@ -414,6 +414,83 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform) DMLC_REGISTER_PARAMETER(WinogradConv2DParam); +NNVM_REGISTER_OP(_contrib_conv2d_NCHWc_winograd_weight_transform) + .describe( + R"code(Weight transformation of winograd fast convolution algorithm. +Separate this into another nnvm symbol in order to enable Precompute Pass to compute the +weight transformation in advance. 
+
+- **weight**: (Packed weight matrix)
+)code" NNVM_ADD_FILELINE)
+    .add_argument("weight", "6D Tensor", "Packed weight tensor.")
+    .add_arguments(WinogradWeightTransformParam::__FIELDS__())
+    .set_attr_parser(ParamParser<WinogradWeightTransformParam>)
+    .set_attr<FGetAttrDict>("FGetAttrDict",
+                            ParamGetAttrDict<WinogradWeightTransformParam>)
+    .set_attr<FInferShape>(
+        "FInferShape",
+        [](const nnvm::NodeAttrs &attrs, std::vector<TShape> *in_shape,
+           std::vector<TShape> *out_shape) {
+          const auto &param =
+              nnvm::get<WinogradWeightTransformParam>(attrs.parsed);
+          const TShape &wshape = (*in_shape)[0];
+
+          CHECK_EQ(wshape.ndim(), 6)
+              << "Packed weight should be a 6-dimensional tensor";
+
+          // Input kernel layout is essentially COO, CII, KH, KW, CIII, COOO
+          // Transformed kernel layout is COO, CII, CIII, KH, KW, COOO
+          TShape oshape({wshape[0], wshape[1], wshape[4],
+                         param.tile_size + wshape[2] - 1,
+                         param.tile_size + wshape[3] - 1, wshape[5]});
+          NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
+          return true;
+        })
+    .set_attr<FCorrectLayout>("FCorrectLayout",
+              [](const NodeAttrs &attrs,
+                 std::vector<Layout> *ilayouts,
+                 const std::vector<Layout> *last_ilayouts,
+                 std::vector<Layout> *olayouts) {
+                const auto &param =
+                    nnvm::get<WinogradWeightTransformParam>(attrs.parsed);
+                Layout kernel_layout(param.kernel_layout);
+                NNVM_ASSIGN_LAYOUT(*ilayouts, 0, kernel_layout);
+                NNVM_ASSIGN_LAYOUT(*olayouts, 0, kernel_layout);
+                return true;
+              })
+    .set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+    .set_num_outputs(1)
+    .set_num_inputs(1)
+    .set_support_level(5);
+
+NNVM_REGISTER_OP(_contrib_conv2d_NCHWc_winograd_without_weight_transform)
+    .describe(R"code(Compute conv2d with winograd algorithm.
+
+- **data**: Input is a 5D array of shape
+            (batch_size, in_channel_outer, height, width, in_channel_inner)
+- **weight**: Any shape
+            We do not check shape for this input tensor.
+
+- **bias**: (channels,)
+- **out**: Output is a 5D array of shape
+           (batch_size, out_channel_outer, out_height, out_width, out_channel_inner)
+)code" NNVM_ADD_FILELINE)
+    .add_argument("data", "5D Tensor", "Input data.")
+    .add_argument("weight", "6D Tensor", "Transformed weight tensor.")
+    .add_argument("bias", "1D Tensor", "Bias parameter.")
+    .add_arguments(WinogradConv2DParam::__FIELDS__())
+    .set_attr_parser(ParamParser<WinogradConv2DParam>)
+    .set_attr<FGetAttrDict>("FGetAttrDict",
+                            ParamGetAttrDict<WinogradConv2DParam>)
+    .set_attr<FListInputNames>("FListInputNames",
+                               UseBiasListInputNames<WinogradConv2DParam>)
+    .set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape)
+    .set_attr<FInferType>("FInferType", Conv2DInferType<WinogradConv2DParam>)
+    .set_attr<FCorrectLayout>("FCorrectLayout",
+                              Conv2DCorrectLayout<WinogradConv2DParam>)
+    .set_num_outputs(1)
+    .set_num_inputs(UseBiasNumInputs<WinogradConv2DParam>)
+    .set_support_level(5);
+
 NNVM_REGISTER_OP(_conv2d_grad)
 .describe(R"code(2D convolution grad.
@@ -441,7 +518,6 @@ NNVM_REGISTER_OP(_conv2d_grad) .set_attr("FInferType", ElemwiseType<3, -1>) .set_attr("TIsBackward", true); - DMLC_REGISTER_PARAMETER(Conv2DTransposeParam); inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs, diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 8c303e5be182d..d00f88bfc1ea4 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -120,6 +120,15 @@ def schedule_conv2d_winograd_without_weight_transform(outs): """ return _default_schedule(outs, False) +@tvm.target.generic_func +def schedule_conv2d_NCHWc_winograd_weight_transform(outs): + return _default_schedule(outs, False) + + +@tvm.target.generic_func +def schedule_conv2d_NCHWc_winograd_without_weight_transform(outs): + return _default_schedule(outs, False) + @tvm.target.generic_func def schedule_conv2d_transpose_nchw(outs): diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index d4b9393c19dda..f913e22841455 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -480,3 +480,125 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp xx * stride_w + rx * dilation_w].astype(out_dtype) * Filter[ff, rc, ry, rx].astype(out_dtype), axis=[rc, ry, rx]), tag="conv2d_nchw") + + +@tvm.target.generic_func +def conv2d_NCHWc_winograd_weight_transform(kernel, tile_size, kernel_layout): + """Weight transformation for winograd NCHWc + + Parameters + ---------- + kernel: Tensor + 6-D with shape + [num_filter_chunk, in_channel_chunk, kernel_height, kernel_width, + in_channel_block, num_filter_block] + tile_size: int + Tile size of winograd transform. e.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3) + + Returns + ------- + output : Tensor + 6-D with shape + [num_filter_chunk, in_channel_chunk, in_channel_block, alpha, alpha, + num_filter_block] + + """ + COO, CII, KH, KW, CIII, VC = get_const_tuple(kernel.shape) + + def get_G(m): + """ + Return the G transform matrix for the tile size `m` as a + `tvm.Expr`. + """ + assert m in (2, 4, 6) + if m == 4: + G_data = np.array( + [ + [1 / 4.0, 0, 0], + [-1 / 6.0, -1 / 6.0, -1 / 6.0], + [-1 / 6.0, 1 / 6.0, -1 / 6.0], + [1 / 24.0, 1 / 12.0, 1 / 6.0], + [1 / 24.0, -1 / 12.0, 1 / 6.0], + [0, 0, 1], + ], + dtype=np.float32, + ) + elif m == 6: + G_data = np.array( + [ + [1, 0, 0], + [-2 / 9, -2 / 9, -2 / 9], + [-2 / 9, 2 / 9, -2 / 9], + [1 / 90, 1 / 45, 2 / 45], + [1 / 90, -1 / 45, 2 / 45], + [1 / 45, 1 / 90, 1 / 180], + [1 / 45, -1 / 90, 1 / 180], + [0, 0, 1], + ], + dtype=np.float32, + ) + elif m == 2: + G_data = np.array( + [ + [1, 0, 0], + [1.0 / 2, 1.0 / 2, 1.0 / 2], + [1.0 / 2, -1.0 / 2, 1.0 / 2], + [0, 0, 1], + ], + np.float32, + ) + return const_matrix(G_data, "G") + + G = get_G(tile_size) + # transform kernel + + r_kh = tvm.reduce_axis((0, KH), "r_kh") + r_kw = tvm.reduce_axis((0, KW), "r_kw") + alpha = tile_size + 3 - 1 + U = tvm.compute( + (COO, CII, CIII, alpha, alpha, VC), + lambda coo, cii, ciii, eps, nu, vc: tvm.sum( + kernel[coo][cii][r_kh][r_kw][ciii][vc] + * G[eps][r_kh] + * G[nu][r_kw], + axis=[r_kh, r_kw], + ), + name="U", + ) + return U + + +@tvm.target.generic_func +def conv2d_NCHWc_winograd_without_weight_transform( + input, filter, strides, padding, dilation, layout, out_layout, out_dtype, tile_size): + """Compute convolution in winograd algorithm. The filter is supposed to be transformed + in advance. 
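+    This generic function has no default implementation (calling it raises
+    ValueError); concrete targets register their own compute, e.g. the x86
+    backend registers an NCHWc Winograd compute for it through AutoTVM.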
+ + Parameters + ---------- + input : tvm.Tensor + 5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block] + filter : tvm.Tensor + 6-D with shape + [num_filter_chunk, in_channel_chunk, in_channel_block, alpha, alpha, num_filter_block] + strides : int or a list/tuple of two ints + Stride size, or [stride_height, stride_width] + padding : int or str + Padding size, or ['VALID', 'SAME'] + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + layout : str + Input data layout + out_layout : str + Output data layout + out_dtype : str + output data type + tile_size: int + Tile size of winograd transform. e.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3) + + Returns + ------- + output : tvm.Tensor + 5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block] + """ + raise ValueError("missing register for topi.nn.conv2d_NCHWc_winograd_without_weight_transform") diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 7e0b90f1db9b2..952e7aa8a855d 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -6,16 +6,24 @@ from tvm.autotvm.task import get_config from .. import generic, tag from .. import nn -from ..util import get_const_tuple -from ..nn.conv2d import conv2d, conv2d_NCHWc, \ - conv2d_alter_layout, _get_workload as _get_conv2d_workload +from ..util import get_const_tuple, const_matrix, get_const_int, traverse_inline +from ..nn.conv2d import ( + conv2d, + conv2d_NCHWc, + conv2d_NCHWc_winograd_weight_transform, + conv2d_NCHWc_winograd_without_weight_transform, + conv2d_alter_layout, + _get_workload as _get_conv2d_workload) + from ..nn.dilate import dilate from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw from ..nn.pad import pad - +from ..nn.util import get_pad_tuple from . 
import conv2d_avx_1x1, conv2d_avx_common + + def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False): """ Get default schedule config for the workload @@ -282,6 +290,38 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs): return s, [new_data, new_kernel, C] +@autotvm.task.register("topi_x86_conv2d_NCHWc_winograd") +def _topi_nn_conv2d_NCHWc_winograd(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + data, kernel, strides, padding, dilation, origin_layout, dtype = deserialize_args(args) + raw_data_shape = get_const_tuple(data.shape) + raw_kernel_shape = get_const_tuple(kernel.shape) + + # get config here + cfg = get_config() + _create_tuning_space(cfg, data, kernel, strides, padding, dilation, origin_layout) + + # change shape with the value in config + ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1], + cfg["tile_ow"].size[-1]) + new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn, + raw_data_shape[2], raw_data_shape[3], ic_bn) + data_layout = "NCHW%dc" % ic_bn + out_layout = "NCHW%dc" % oc_bn + new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn, + raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn) + new_data = tvm.placeholder(new_data_shape, data.dtype) + new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype) + + C = _declaration_conv_NCHWc_winograd_impl( + cfg, new_data, new_kernel, strides, padding, dilation, + data_layout, out_layout, dtype, + transform_kernel=True, tile_size=None) + s = tvm.create_schedule([C.op]) + s = _schedule_conv2d_NCHWc_winograd(cfg, s, C, C) + return s, [new_data, new_kernel, C] + + @conv2d_alter_layout.register("cpu") def _alter_conv2d_layout(attrs, inputs, tinfo): import nnvm.symbol as sym @@ -294,6 +334,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo): out_channel = attrs.get_int("channels") padding = attrs.get_int_tuple("padding") strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") layout = attrs['layout'] kh, kw = attrs.get_int_tuple("kernel_size") @@ -311,10 +352,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfo): target = tvm.target.current_target() # query schedule and fallback if necessary workload = autotvm.task.args_to_workload( - [data, kernel, strides, padding, out_dtype], depthwise_conv2d_nchw) \ + [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw) \ if is_depthwise else \ autotvm.task.args_to_workload( - [data, kernel, strides, padding, layout, out_dtype], conv2d) + [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d) cfg = dispatch_ctx.query(target, workload) if cfg.is_fallback: _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise) @@ -338,7 +379,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfo): new_workload = autotvm.task.args_to_workload( [new_data, new_kernel, strides, padding, new_attrs['layout'], new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc) - else: + dispatch_ctx.update(target, new_workload, cfg) + return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) + + elif cfg.is_fallback or cfg.template_key == "direct": out_channel, _, kh, kw = get_const_tuple(kernel.shape) # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) @@ -349,9 +393,42 @@ def _alter_conv2d_layout(attrs, inputs, tinfo): new_workload = autotvm.task.args_to_workload( [new_data, new_kernel, strides, padding, new_attrs['layout'], new_attrs['out_layout'], 
out_dtype], conv2d_NCHWc) + dispatch_ctx.update(target, new_workload, cfg) + return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) + elif cfg.template_key == "winograd": + tile_size = cfg["tile_size"].val + out_channel, _, kh, kw = get_const_tuple(kernel.shape) + assert (kh, kw) == (3, 3) + # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) + new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) + new_attrs['tile_size'] = tile_size + # Store altered operator's config + new_kernel = tvm.placeholder( + (out_channel//oc_bn, in_channel//ic_bn, ic_bn, + tile_size + 3 - 1, tile_size + 3 - 1, oc_bn), + dtype=kernel.dtype) + + new_kernel_workload = autotvm.task.args_to_workload( + [kernel, new_attrs['kernel_layout'], out_dtype, tile_size], + conv2d_NCHWc_winograd_weight_transform) + + new_kernel_sym = sym.contrib.conv2d_NCHWc_winograd_weight_transform( + copy_inputs[1], + kernel_layout=new_attrs['kernel_layout'], + tile_size=tile_size) + dispatch_ctx.update(target, new_kernel_workload, cfg) + copy_inputs[1] = new_kernel_sym - dispatch_ctx.update(target, new_workload, cfg) - return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, new_attrs['layout'], + new_attrs['out_layout'], out_dtype], + conv2d_NCHWc_winograd_without_weight_transform) + dispatch_ctx.update(target, new_workload, cfg) + return sym.contrib.conv2d_NCHWc_winograd_without_weight_transform( + *copy_inputs, + **new_attrs) + else: + raise RuntimeError("Unknown template: {}".format(cfg.template_key)) @autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct') @@ -421,24 +498,297 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, axis=[ic, kh, kw]), name='conv2d_NCHWc', tag="conv2d_NCHWc") +def get_transform_matrices(m): + """Compute the A, B, and G transform matrices for + the tile size `m` as a `tvm.Expr`. 
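+
+    With r = 3 and alpha = m + r - 1: A has shape (alpha, m) and is used for
+    the inverse (output) transform, B has shape (alpha, alpha) and is used for
+    the data transform, and G has shape (alpha, 3) and is used for the kernel
+    transform.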
+ """ + import numpy as np + + assert m in (2, 4, 6) + if m == 4: + G_data = np.array( + [ + [1 / 4.0, 0, 0], + [-1 / 6.0, -1 / 6.0, -1 / 6.0], + [-1 / 6.0, 1 / 6.0, -1 / 6.0], + [1 / 24.0, 1 / 12.0, 1 / 6.0], + [1 / 24.0, -1 / 12.0, 1 / 6.0], + [0, 0, 1], + ], + dtype=np.float32, + ) + + B_data = np.array( + [ + [4, 0, 0, 0, 0, 0], + [0, -4, 4, -2, 2, 4], + [-5, -4, -4, -1, -1, 0], + [0, 1, -1, 2, -2, -5], + [1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 1], + ], + dtype=np.float32, + ) + + A_data = np.array( + [ + [1, 0, 0, 0], + [1, 1, 1, 1], + [1, -1, 1, -1], + [1, 2, 4, 8], + [1, -2, 4, -8], + [0, 0, 0, 1], + ], + dtype=np.float32, + ) + + elif m == 6: + G_data = np.array( + [ + [1, 0, 0], + [-2 / 9, -2 / 9, -2 / 9], + [-2 / 9, 2 / 9, -2 / 9], + [1 / 90, 1 / 45, 2 / 45], + [1 / 90, -1 / 45, 2 / 45], + [1 / 45, 1 / 90, 1 / 180], + [1 / 45, -1 / 90, 1 / 180], + [0, 0, 1], + ], + dtype=np.float32, + ) + + B_data = np.array( + [ + [1, 0, -21 / 4, 0, 21 / 4, 0, -1, 0], + [0, 1, 1, -17 / 4, -17 / 4, 1, 1, 0], + [0, -1, 1, 17 / 4, -17 / 4, -1, 1, 0], + [0, 1 / 2, 1 / 4, -5 / 2, -5 / 4, 2, 1, 0], + [0, -1 / 2, 1 / 4, 5 / 2, -5 / 4, -2, 1, 0], + [0, 2, 4, -5 / 2, -5, 1 / 2, 1, 0], + [0, -2, 4, 5 / 2, -5, -1 / 2, 1, 0], + [0, -1, 0, 21 / 4, 0, -21 / 4, 0, 1], + ], + dtype=np.float32, + ).T + + A_data = np.array( + [ + [1, 1, 1, 1, 1, 32, 32, 0], + [0, 1, -1, 2, -2, 16, -16, 0], + [0, 1, 1, 4, 4, 8, 8, 0], + [0, 1, -1, 8, -8, 4, -4, 0], + [0, 1, 1, 16, 16, 2, 2, 0], + [0, 1, -1, 32, -32, 1, -1, 1], + ], + dtype=np.float32, + ).T + elif m == 2: + G_data = np.array( + [ + [1, 0, 0], + [1.0 / 2, 1.0 / 2, 1.0 / 2], + [1.0 / 2, -1.0 / 2, 1.0 / 2], + [0, 0, 1], + ], + np.float32, + ) + + B_data = np.array( + [ + [1, 0, 0, 0], + [0, 1, -1, 1], + [-1, 1, 1, 0], + [0, 0, 0, -1], + ], + np.float32, + ) + + A_data = np.array( + [[1, 0], [1, 1], [1, -1], [0, -1]], np.float32 + ) + + return ( + const_matrix(A_data, "A"), + const_matrix(B_data, "B"), + const_matrix(G_data, "G"), + ) + +@autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'winograd') +def _declaration_conv_NCHWc_winograd(cfg, data, kernel, strides, + padding, dilation, layout, out_layout, out_dtype): + return _declaration_conv_NCHWc_winograd_impl( + cfg, data, kernel, strides, padding, dilation, + layout, out_layout, out_dtype, + transform_kernel=True, tile_size=None) + + +def _declaration_conv_NCHWc_winograd_impl( + cfg, data, kernel, strides, + padding, dilation, layout, out_layout, out_dtype, + transform_kernel, tile_size): + out_dtype = out_dtype or data.dtype + N, CII, IH, IW, CIII = get_const_tuple(data.shape) + + if transform_kernel: + COO, CII, KH, KW, CIII_, VC = get_const_tuple(kernel.shape) + else: + COO, CII, CIII_, _, _, VC = get_const_tuple(kernel.shape) + KH = 3 + KW = 3 + + cfg.define_knob("tile_size", [2, 4, 6]) + m = tile_size if tile_size else cfg["tile_size"].val + r = 3 + alpha = m + r - 1 + + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (KH, KW)) + HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) -@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc, 'cpu', ['direct']) + OH = (IH + pad_top + pad_bottom - KH) // HSTR + 1 + OW = (IW + pad_left + pad_right - KW) // WSTR + 1 + data_pad = pad( + data, + [0, 0, pad_top, pad_left, 0], + [0, 0, pad_bottom, pad_right, 0], + name="data_pad" + ) + + A, B, G = get_transform_matrices(m) + + def div_round_up(a, b): + return (a + b - 1) // b + + # assert all(k == 3 for k in (KH, KW)) + assert all(p == 1 for p in (pad_top, pad_left, 
pad_bottom, pad_right)) + assert all(s == 1 for s in (HSTR, WSTR)) + assert OH == IH + assert OW == IW + + OH_M = div_round_up(OH, m) + OW_M = div_round_up(OW, m) + # Layouts: + + # input = (N, CII, IH, IW, CIII) + # -> transpose + ############################################################ + # input_tile_shape = (N, CII, OH // m, OH // m, alpha, alpha, CIII) + # U_shape = (COO, CII, CIII, alpha, alpha, COOO) + # V_shape = (N, CII, OH // m, OW // m, alpha, alpha, CIII) + # M_shape = (N, COO, OH // m, OW // m, alpha, alpha, COOO) + # Y_shape = (N, COO, OH // m, OW // m, m, m, COOO) + ############################################################ + # -> transpose + # O_shape = (N, COO, OH, OW, COOO) + + n, coo, oh, ow, oh_m, ow_m, vc = \ + cfg.axis(N), cfg.axis(COO), cfg.axis(OH), cfg.axis(OW), \ + cfg.axis(OH_M), cfg.axis(OW_M), cfg.axis(VC) + cii, ciii, kh, kw = cfg.reduce_axis(CII), cfg.reduce_axis(CIII), \ + cfg.reduce_axis(KH), cfg.reduce_axis(KW) + + eps, nu = cfg.axis(alpha), cfg.axis(alpha) + vh, vw = cfg.axis(m), cfg.axis(m) + r_eps, r_nu = cfg.axis(alpha), cfg.axis(alpha) + cfg.define_reorder("reorder_M", + [n, coo, oh_m, ow_m, eps, nu, vc, cii, ciii], + policy='candidate', candidate=[ + [n, coo, cii, oh_m, ow_m, eps, ciii, nu, vc], + # [n, coo, cii, oh_m, ow_m, ciii, nu, eps, vc], + # [n, coo, cii, oh_m, ow_m, nu, eps, ciii, vc], + # [n, coo, oh_m, ow_m, nu, eps, cii, ciii, vc], + ]) + + cfg.define_reorder("reorder_V", + [n, cii, oh_m, ow_m, eps, nu, ciii, r_eps, r_nu], + policy='candidate', candidate=[ + [n, cii, oh_m, ow_m, eps, r_eps, r_nu, nu, ciii], + # [n, cii, oh_m, ow_m, eps, nu, r_eps, r_nu, ciii], + # [n, cii, oh_m, ow_m, r_eps, r_nu, eps, nu, ciii], + # [n, cii, oh_m, ow_m, r_eps, r_nu, eps, nu, ciii], + ]) + + cfg.define_reorder("reorder_Y", + [n, coo, oh_m, ow_m, vh, vw, vc, r_eps, r_nu], + policy='candidate', candidate=[ + [n, coo, oh_m, ow_m, vh, r_eps, r_nu, vw, vc], + # [n, coo, oh_m, ow_m, vh, vw, r_eps, r_nu, vc], + # [n, coo, oh_m, ow_m, r_eps, r_nu, vh, vw, vc], + # [n, coo, oh_m, ow_m, r_eps, r_nu, vh, vw, vc], + ]) + + + input_tile = tvm.compute((N, CII, OH_M, OW_M, alpha, alpha, CIII), + lambda n, cii, oh_m, ow_m, eps, nu, ciii: + data_pad[n][cii][oh_m * m + eps][ow_m * m + nu][ciii], + name='input_tile') + + # transform kernel + if transform_kernel: + r_kh = tvm.reduce_axis((0, KH), 'r_kh') + r_kw = tvm.reduce_axis((0, KW), 'r_kw') + U = tvm.compute((COO, CII, CIII, alpha, alpha, VC), + lambda coo, cii, ciii, eps, nu, vc: + tvm.sum(kernel[coo][cii][r_kh][r_kw][ciii][vc].astype(out_dtype) * + G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), + name='U') + else: + U = kernel + + # transform image + r_eps = tvm.reduce_axis((0, alpha), 'r_eps') + r_nu = tvm.reduce_axis((0, alpha), 'r_nu') + V = tvm.compute((N, CII, OH_M, OW_M, alpha, alpha, CIII), + lambda n, cii, oh_m, ow_m, eps, nu, ciii: + tvm.sum(input_tile[n][cii][oh_m][ow_m][r_eps][r_nu][ciii].astype(out_dtype) * + B[r_eps][eps] * B[r_nu][nu], axis=[r_eps, r_nu]), name='V') + cii = tvm.reduce_axis((0, CII), name='cii') + ciii = tvm.reduce_axis((0, CIII), name='ciii') + + # M_shape = (N, COO, OH // m, OW // m, alpha, alpha, COOO) + M = tvm.compute((N, COO, OH_M, OW_M, alpha, alpha, VC), + lambda n, coo, oh_m, ow_m, eps, nu, vc: + tvm.sum(U[coo][cii][ciii][eps][nu][vc] * V[n][cii][oh_m][ow_m][eps][nu][ciii], + axis=[cii, ciii]), + name='M') + + # inverse transform + r_eps = tvm.reduce_axis((0, alpha), 'r_eps') + r_nu = tvm.reduce_axis((0, alpha), 'r_nu') + # Y_shape = (N, COO, OH // m, OW // m, m, m, COOO) + 
Y = tvm.compute((N, COO, OH_M, OW_M, m, m, VC), + lambda n, coo, oh_m, ow_m, vh, vw, vc: + tvm.sum(M[n][coo][oh_m][ow_m][r_eps][r_nu][vc] * A[r_eps][vh] * A[r_nu][vw], + axis=[r_eps, r_nu]), + name='Y') + + output = tvm.compute((N, COO, OH, OW, VC), + lambda n, coo, oh, ow, vc: + Y[n][coo][oh // m][ow // m][oh % m][ow % m][vc], + name='output', tag='conv2d_NCHWc_winograd') + cfg.add_flop(2 * N * COO * VC * OH * OW * KH * KW * CII * CIII) + return output + +@autotvm.register_topi_compute( + conv2d_NCHWc_winograd_without_weight_transform, 'cpu', 'winograd') +def _declaration_conv_NCHWc_winograd_without_weight_transform( + cfg, data, transformed_kernel, strides, + padding, dilation, layout, out_layout, out_dtype, tile_size): + return _declaration_conv_NCHWc_winograd_impl( + cfg, data, transformed_kernel, strides, padding, dilation, + layout, out_layout, out_dtype, transform_kernel=False, tile_size=tile_size) + + +@autotvm.register_topi_schedule( + generic.schedule_conv2d_NCHWc, 'cpu', ['direct', 'winograd']) def _schedule_conv2d_NCHWc(cfg, outs): """Create schedule for tensors""" s = tvm.create_schedule([x.op for x in outs]) scheduled_ops = [] - def traverse(op): - """Traverse operators from computation graph""" - # inline all one-to-one-mapping operators except the last stage (output) - if tag.is_broadcast(op.tag): - if op not in s.outputs: - s[op].compute_inline() - for tensor in op.input_tensors: - if tensor.op.input_tensors and tensor.op not in scheduled_ops: - traverse(tensor.op) - - if 'conv2d_NCHWc' in op.tag: + def _callback(op): + if 'conv2d_NCHWc_winograd' in op.tag: + _schedule_conv2d_NCHWc_winograd(cfg, s, op.output(0), outs[0]) + elif 'conv2d_NCHWc' in op.tag: conv_out = op.output(0) kernel = conv_out.op.input_tensors[1] data_vec = conv_out.op.input_tensors[0] @@ -463,8 +813,182 @@ def traverse(op): conv2d_avx_1x1._schedule_conv_NCHWc(*args) else: conv2d_avx_common._schedule_conv_NCHWc(*args) - scheduled_ops.append(op) - traverse(outs[0].op) + traverse_inline(s, outs[0].op, _callback) + return s + +@autotvm.register_topi_schedule( + generic.schedule_conv2d_NCHWc_winograd_without_weight_transform, + 'cpu', ['winograd']) +def schedule_conv2d_winograd_without_weight_transform_(cfg, outs): + """TOPI schedule callback""" + s = tvm.create_schedule([x.op for x in outs]) + def _callback(op): + if 'conv2d_NCHWc_winograd' in op.tag: + + output = op.output(0) + _schedule_conv2d_NCHWc_winograd(cfg, s, output, outs[0]) + + traverse_inline(s, outs[0].op, _callback) + return s + +def _schedule_conv2d_NCHWc_winograd(cfg, s, output, last): + Y = output.op.input_tensors[0] + M, A = Y.op.input_tensors + U, V = M.op.input_tensors + input_tile, B = V.op.input_tensors + data_pad = input_tile.op.input_tensors[0] + + # Inline the constants. 
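+    # (Producer chain unpacked above: data_pad -> input_tile -> V; kernel -> U
+    #  when the kernel transform is fused, otherwise U is the pre-transformed
+    #  kernel itself; (U, V) -> M -> Y -> output. A, B and G are the small
+    #  constant transform matrices produced by const_matrix.)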
+ s[A].compute_inline() + s[B].compute_inline() + + # transform kernel + if isinstance(U.op, tvm.tensor.ComputeOp): + kernel, G = U.op.input_tensors + s[G].compute_inline() + coo, cii, eps, nu, ciii, vc = s[U].op.axis + if autotvm.GLOBAL_SCOPE.in_tuning: + # kernel transformation will be pre-computed during compilation, so we skip + # this part to make tuning records correct + s[U].pragma(eps, 'debug_skip_region') + else: + pass + # r_kh, r_kw = s[U].op.reduce_axis + # s[U].reorder(k, c, eps, nu, r_kh, r_kw, kk) + # for axis in [eps, nu, r_kh, r_kw]: + # s[U].unroll(axis) + # s[U].vectorize(kk) + # s[U].parallel(k) + + if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + + ############################################################ + # input tile + n, cii, oh_m, ow_m, eps, nu, ciii = s[input_tile].op.axis + # Vectorize the input tile + s[input_tile].vectorize(ciii) + + cfg.define_knob('data_pad_compute_location', [0, 1, 2, 3]) + if cfg['data_pad_compute_location'].val == 0: + s[data_pad].compute_inline() + if cfg['data_pad_compute_location'].val == 1: + s[data_pad].compute_at(s[input_tile], cii) + (_, _, _, _, dpcii) = s[data_pad].op.axis + s[data_pad].vectorize(dpcii) + if cfg['data_pad_compute_location'].val == 2: + s[data_pad].compute_at(s[input_tile], oh_m) + (_, _, _, _, dpcii) = s[data_pad].op.axis + s[data_pad].vectorize(dpcii) + if cfg['data_pad_compute_location'].val == 3: + s[data_pad].compute_at(s[input_tile], ow_m) + (_, _, _, _, dpcii) = s[data_pad].op.axis + s[data_pad].vectorize(dpcii) + + ############################################################ + + ############################################################ + # data_pad + # s[data_pad].compute_inline() + ############################################################ + + ############################################################ + # transform image + n, cii, oh_m, ow_m, eps, nu, ciii = s[V].op.axis + r_eps, r_nu = s[V].op.reduce_axis + + s[V].vectorize(ciii) + # import ipdb; ipdb.set_trace() + cfg["reorder_V"].apply(s, V, [n, cii, oh_m, ow_m, eps, nu, ciii, r_eps, r_nu]) + + cfg.define_annotate("reduce_V", [r_eps, r_nu, eps, nu], + policy='unroll') + cfg['reduce_V'].apply(s, V, [r_eps, r_nu, eps, nu], cfg=cfg) + + + cfg.define_knob('input_tile_compute_location', [0, 1, 2, 3]) + if cfg['input_tile_compute_location'].val == 1: + s[input_tile].compute_at(s[V], cii) + if cfg['input_tile_compute_location'].val == 2: + s[input_tile].compute_at(s[V], oh_m) + if cfg['input_tile_compute_location'].val == 3: + s[input_tile].compute_at(s[V], ow_m) + ############################################################ + + ############################################################ + # batch gemm + n, coo, oh_m, ow_m, eps, nu, vc = s[M].op.axis + cii, ciii = s[M].op.reduce_axis + s[M].vectorize(vc) + + cfg["reorder_M"].apply(s, M, [n, coo, oh_m, ow_m, eps, nu, vc, cii, ciii]) + + cfg.define_annotate("reduce_M", [eps, nu], + policy='try_unroll') + cfg['reduce_M'].apply(s, M, [eps, nu], cfg=cfg) + + cfg.define_knob('V_compute_location', [0, 1, 2, 3]) + if cfg['V_compute_location'].val == 1: + s[V].compute_at(s[M], coo) + if cfg['V_compute_location'].val == 2: + s[V].compute_at(s[M], oh_m) + if cfg['V_compute_location'].val == 3: + s[V].compute_at(s[M], ow_m) + + ############################################################ + + ############################################################ + # inverse transform + s[A].compute_inline() + n, coo, oh_m, ow_m, vh, vw, vc = s[Y].op.axis + 
r_eps, r_nu = s[Y].op.reduce_axis + s[Y].vectorize(vc) + + cfg['reorder_Y'].apply(s, Y, [n, coo, oh_m, ow_m, vh, vw, vc, r_eps, r_nu]) + + cfg.define_annotate("reduce_Y", [r_eps, r_nu, vh, vw], + policy='unroll') + cfg['reduce_Y'].apply(s, Y, [r_eps, r_nu, vh, vw], cfg=cfg) + + cfg.define_knob('M_compute_location', [0, 1, 2, 3]) + if cfg['M_compute_location'].val == 1: + s[M].compute_at(s[Y], coo) + if cfg['M_compute_location'].val == 2: + s[M].compute_at(s[Y], oh_m) + if cfg['M_compute_location'].val == 3: + s[M].compute_at(s[Y], ow_m) + + ############################################################ + + ############################################################ + # output + + if output != last: + s[output].compute_inline() + + n, coo, oh, ow, vc = s[last].op.axis + s[last].vectorize(vc) + + OH = get_const_int(oh.dom.extent) + OW = get_const_int(ow.dom.extent) + mh = get_const_int(vh.dom.extent) + mw = get_const_int(vw.dom.extent) + cfg.define_knob('output_tile', [1]) + cfg.define_annotate('reduce_output', [cfg.axis(mh), cfg.axis(mw)], policy="try_unroll") + if OH % mh == 0 and OW % mw == 0 and cfg['output_tile'].val == 1: + # We can tile in OH + oh, ow, ohi, owi = s[last].tile(oh, ow, mh, mw) + cfg["reduce_output"].apply(s, last, [ohi, owi], cfg=cfg) + + cfg.define_knob('Y_compute_location', [0, 1, 2, 3]) + if cfg['Y_compute_location'].val == 1: + s[Y].compute_at(s[last], coo) + if cfg['Y_compute_location'].val == 2: + s[Y].compute_at(s[last], oh) + if cfg['Y_compute_location'].val == 3: + s[Y].compute_at(s[last], ow) + ############################################################ + return s diff --git a/topi/tests/python/test_topi_conv2d_NCHWc_winograd.py b/topi/tests/python/test_topi_conv2d_NCHWc_winograd.py new file mode 100644 index 0000000000000..780e3c2b0e249 --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_NCHWc_winograd.py @@ -0,0 +1,202 @@ +"""Example code to do convolution.""" + +import numpy as np +import tvm +from tvm import autotvm +from tvm.autotvm.task.space import FallbackConfigEntity +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + + +def _transform_data(data, bn): + # NCHW -> NCHW[x]c + batch_size, channel, height, width = data.shape + data = np.reshape(data, (batch_size, channel//bn, bn, height, width)) + data = np.transpose(data, (0, 1, 3, 4, 2)) + return data + +def _transform_kernel(kernel, ic_bn, oc_bn): + # OIHW -> OIHW[x]i[x]o + out_channel, in_channel, kh, kw = kernel.shape + kernel = np.reshape(kernel, (out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn, kh, kw)) + kernel = np.transpose(kernel, (0, 2, 4, 5, 3, 1)) + return kernel + +def _transform_bias(bias, bn): + # [num_filter, 1, 1] -> [num_filter//bn, 1, 1, bn] + num_filter, h, w = bias.shape + bias = np.reshape(bias, (num_filter//bn, bn, h, w)) + bias = np.transpose(bias, (0, 2, 3, 1)) + return bias + + +def verify_conv2d_NCHWc_winograd( + batch, in_channel, in_size, num_filter, kernel, + stride, padding, dilation=1, add_bias=False, add_relu=False, tile_size=2): + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, ts=%d)" % + (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation, tile_size)) + + in_height = in_width = in_size + + # for testing functionality, + # we choose arbitrary block size that can divide the channel, + # regardless of the performance. 
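+    # (oc_block is the largest divisor of num_filter that is <= 16, and
+    #  ic_block is the largest divisor of in_channel that is <= oc_block;
+    #  e.g. num_filter = in_channel = 64 gives oc_block = ic_block = 16.)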
+ oc_block = 1 + for bn in range(16, 0, -1): + if num_filter % bn == 0: + oc_block = bn + break + + ic_block = 1 + for bn in range(oc_block, 0, -1): + if in_channel % bn == 0: + ic_block = bn + break + + A = tvm.placeholder((batch, in_channel // ic_block, in_height, in_width, ic_block), name='A') + W = tvm.placeholder((num_filter // oc_block, in_channel // ic_block, kernel, kernel, ic_block, oc_block), name='W') + bias = tvm.placeholder((num_filter // oc_block, 1, 1, oc_block), name='bias') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + kernel_layout = \ + "OIHW{ic_block}i{oc_block}o".format(ic_block=ic_block, oc_block=oc_block) + layout = "NCHW{ic_block}c".format(ic_block=ic_block) + out_layout = "NCHW{oc_block}c".format(oc_block=oc_block) + + @memoize("topi.tests.test_topi_conv2d_NCHWc_winograd.verify_conv2d_NCHWc_winograd") + def get_ref_data(): + a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype) + w_np = np.random.uniform(size=(num_filter, in_channel, kernel, kernel)).astype(dtype) * 0.01 + b_np = np.random.uniform(size=(num_filter, 1, 1)).astype(dtype) + c_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding) + if add_bias: + c_np += b_np + if add_relu: + c_np = np.maximum(c_np, 0) + return _transform_data(a_np, ic_block), _transform_kernel(w_np, ic_block, oc_block), \ + _transform_bias(b_np, oc_block), _transform_data(c_np, oc_block) + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_device_without_weight_transform(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + + with tvm.target.create(device): + WT = topi.nn.conv2d_NCHWc_winograd_weight_transform( + W, + tile_size=tile_size, + kernel_layout=kernel_layout + ) + C = topi.nn.conv2d_NCHWc_winograd_without_weight_transform( + A, WT, (stride, stride), (padding, padding), + (dilation, dilation), + layout=layout, + out_layout=out_layout, + out_dtype=dtype, + tile_size=tile_size) + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = topi.generic.schedule_conv2d_NCHWc_winograd_without_weight_transform([C]) + + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + if add_bias: + func = tvm.build(s, [A, W, bias, C], device, name="relu_bias_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + func(a, w, b, c) + else: + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + func(a, w, c) + print(np.max(np.abs(((c.asnumpy() - c_np) / (np.abs(c_np) + 0.001))))) + + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + + def check_device_with_weight_transform(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + + with tvm.target.create(device): + C = topi.nn.conv2d_NCHWc( + A, W, (stride, stride), (padding, padding), + (dilation, dilation), + layout=layout, + out_layout=out_layout, + out_dtype=dtype) + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = topi.generic.schedule_conv2d_NCHWc([C]) + + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = 
tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+        if add_bias:
+            func = tvm.build(s, [A, W, bias, C], device, name="relu_bias_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
+            func(a, w, b, c)
+        else:
+            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
+            func(a, w, c)
+        print(np.max(np.abs(((c.asnumpy() - c_np) / (np.abs(c_np) + 0.001)))))
+
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+    # test llvm only for now since conv2d_NCHWc_winograd is only implemented on this backend.
+    for device in ['llvm']:
+        check_device_with_weight_transform(device)
+        check_device_without_weight_transform(device)
+
+
+class WinogradFallback(autotvm.FallbackContext):
+    def _query_inside(self, target, workload):
+        key = (target, workload)
+        if key in self.memory:
+            return self.memory[key]
+        cfg = FallbackConfigEntity()
+        cfg.template_key = 'winograd'
+        self.memory[key] = cfg
+        return cfg
+
+
+def test_conv2d_nchw():
+    autotvm.DispatchContext.current.silent = True
+
+    with WinogradFallback():
+        # resnet 18 workloads
+        verify_conv2d_NCHWc_winograd(1, 64, 56, 64, 3, 1, 1)
+        verify_conv2d_NCHWc_winograd(1, 128, 28, 128, 3, 1, 1, tile_size=4)
+        verify_conv2d_NCHWc_winograd(1, 256, 14, 256, 3, 1, 1, tile_size=4)
+        verify_conv2d_NCHWc_winograd(1, 512, 7, 512, 3, 1, 1)
+
+        # batch size = 2
+        verify_conv2d_NCHWc_winograd(2, 64, 56, 64, 3, 1, 1)
+
+        # relu, bias
+        verify_conv2d_NCHWc_winograd(2, 64, 56, 64, 3, 1, 1, add_bias=True)
+        verify_conv2d_NCHWc_winograd(2, 64, 56, 64, 3, 1, 1, add_relu=True)
+        verify_conv2d_NCHWc_winograd(2, 64, 56, 64, 3, 1, 1, add_relu=True, add_bias=True)
+
+        # weird workloads
+        verify_conv2d_NCHWc_winograd(1, 1, 1, 1, 3, 1, 1)
+        verify_conv2d_NCHWc_winograd(3, 3, 3, 3, 3, 1, 1)
+        verify_conv2d_NCHWc_winograd(2, 13, 71, 59, 3, 1, 1)
+
+if __name__ == "__main__":
+    test_conv2d_nchw()
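
Appendix (illustrative sketch, not part of the diff above): the U/V/M/Y stages in
_declaration_conv_NCHWc_winograd_impl compute, per output tile, the kernel
transform U = G g G^T, the data transform V = B^T d B, the elementwise product
M = U * V, and the inverse transform Y = A^T M A. The NumPy snippet below checks
only that per-tile math for m = 2, using the same constant matrices as
`get_transform_matrices(2)`; the helper name `winograd_tile_f2x2_3x3` is made up
here, and all NCHWc blocking, padding, and tiling is ignored.

    import numpy as np

    # Constant matrices, identical to the m = 2 entries in get_transform_matrices.
    G = np.array([[1, 0, 0],
                  [0.5, 0.5, 0.5],
                  [0.5, -0.5, 0.5],
                  [0, 0, 1]], dtype=np.float32)
    B = np.array([[1, 0, 0, 0],
                  [0, 1, -1, 1],
                  [-1, 1, 1, 0],
                  [0, 0, 0, -1]], dtype=np.float32)
    A = np.array([[1, 0],
                  [1, 1],
                  [1, -1],
                  [0, -1]], dtype=np.float32)

    def winograd_tile_f2x2_3x3(d, g):
        """d: 4x4 input tile, g: 3x3 kernel -> 2x2 output tile (correlation)."""
        U = G @ g @ G.T      # kernel transform, the U stage above
        V = B.T @ d @ B      # data transform, the V stage above
        M = U * V            # elementwise product, the M stage above
        return A.T @ M @ A   # inverse transform, the Y stage above

    rng = np.random.RandomState(0)
    d = rng.randn(4, 4).astype(np.float32)
    g = rng.randn(3, 3).astype(np.float32)
    # Direct 'valid' correlation, the same convention the conv2d ops above use.
    ref = np.array([[np.sum(d[i:i + 3, j:j + 3] * g) for j in range(2)]
                    for i in range(2)])
    assert np.allclose(winograd_tile_f2x2_3x3(d, g), ref, atol=1e-4)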