From b380aed2de51595f97b2de9b7a1eb8370d162aef Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Fri, 13 Jan 2023 19:35:03 -0500
Subject: [PATCH] [Op] Group for conv2d, ceil mode for max_pool2d (#358)

This PR adds the groups parameter for conv2d, as well as the ceil_mode
option for max_pool2d. They were omitted in the previous round of PRs
mainly for simplicity, and are introduced now due to their use in ML
frameworks.
---
 include/tvm/relax/attrs/nn.h                 |  8 ++
 python/tvm/relax/op/nn/nn.py                 | 13 ++-
 src/relax/op/nn/convolution.cc               | 32 ++++++--
 src/relax/op/nn/convolution.h                |  5 +-
 src/relax/op/nn/pooling.cc                   | 16 ++--
 src/relax/op/nn/pooling.h                    |  2 +-
 tests/python/relax/test_op_nn_convolution.py | 83 ++++++++++++++++++++
 tests/python/relax/test_op_nn_pooling.py     | 62 +++++++++++++++
 8 files changed, 204 insertions(+), 17 deletions(-)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index e04ff337aa..4b4e0680e2 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -34,6 +34,7 @@ struct Conv2DAttrs : public tvm::AttrsNode {
   Array strides;
   Array padding;
   Array dilation;
+  int groups;
   String data_layout;
   String kernel_layout;
   String out_layout;
@@ -49,6 +50,9 @@ struct Conv2DAttrs : public tvm::AttrsNode {
         "four int : padding width in the order of (top, left, bottom, right)");
     TVM_ATTR_FIELD(dilation).describe(
         "Specifies the dilation rate to use for dilated convolution.");
+    TVM_ATTR_FIELD(groups).describe(
+        "Number of groups to split the input into for grouped convolution. The number of input and "
+        "output channels should be divisible by the number of groups.");
     TVM_ATTR_FIELD(data_layout)
         .describe(
             "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
@@ -76,6 +80,7 @@ struct MaxPool2DAttrs : public tvm::AttrsNode {
   Array strides;
   Array padding;
   Array dilation;
+  bool ceil_mode;
   String layout;
   String out_layout;

@@ -89,6 +94,9 @@ struct MaxPool2DAttrs : public tvm::AttrsNode {
         "one int : same padding used on all sides"
         "two int : bottom, right will use same padding as top, left"
         "four int : padding width in the order of (top, left, bottom, right)");
+    TVM_ATTR_FIELD(ceil_mode).describe(
+        "A boolean indicating whether to use ceil or floor to compute the output shape. "
+        "When ceil is used, every element in the input tensor will be covered by a sliding window.");
     TVM_ATTR_FIELD(layout).describe(
         "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
         "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
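
(Not part of the patch.) A minimal plain-Python sketch of the channel bookkeeping implied by the new groups attribute, assuming NCHW data and OIHW weights; the helper name is illustrative and the shapes match the tests added later in this patch.

def grouped_conv2d_out_channels(data_channels, weight_shape, groups):
    """Validate the grouped-conv2d channel constraints and return the output channel count."""
    out_channels, in_channels_per_group, _, _ = weight_shape
    assert groups > 0, "groups must be positive"
    # The data channels must equal the weight's per-group input channels times groups.
    assert data_channels == in_channels_per_group * groups
    # The output channels must be divisible by the number of groups.
    assert out_channels % groups == 0
    return out_channels

# Data (2, 128, 28, 28) with weight (48, 16, 3, 3) and groups=8: 16 * 8 == 128 and 48 % 8 == 0.
assert grouped_conv2d_out_channels(128, (48, 16, 3, 3), 8) == 48
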
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index d27daff508..a62cc7f997 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -32,6 +32,7 @@ def conv2d(
     strides: Union[PrimExprLike, Tuple[PrimExprLike]] = (1, 1),
     padding: Union[PrimExprLike, Tuple[PrimExprLike]] = (0, 0),
     dilation: Union[PrimExprLike, Tuple[PrimExprLike]] = (1, 1),
+    groups: int = 1,
     data_layout: str = "NCHW",
     kernel_layout: str = "OIHW",
     out_layout: Optional[str] = None,
@@ -81,6 +82,10 @@ def conv2d(
         Specifies the dilation rate to be used for dilated convolution.
         It is required to have length either 1 or 2.

+    groups : int
+        Number of groups to split the input into for grouped convolution.
+        The number of input and output channels should be divisible by the number of groups.
+
     data_layout : str
         Layout of the input.

@@ -111,6 +116,7 @@ def conv2d(
         strides,
         padding,
         dilation,
+        groups,
         data_layout,
         kernel_layout,
         out_layout,
@@ -124,6 +130,7 @@ def max_pool2d(
     strides: Union[PrimExprLike, Tuple[PrimExprLike]] = (1, 1),
     padding: Union[PrimExprLike, Tuple[PrimExprLike]] = (0, 0),
     dilation: Union[PrimExprLike, Tuple[PrimExprLike]] = (1, 1),
+    ceil_mode: bool = False,
     layout: str = "NCHW",
     out_layout: Optional[str] = None,
 ) -> Expr:
@@ -165,6 +172,10 @@ def max_pool2d(
     dilation : Union[PrimExprLike, Tuple[PrimExprLike]]
         The dilation of pooling. It is required to have length either 1 or 2.

+    ceil_mode : bool
+        A boolean indicating whether to use ceil or floor to compute the output shape.
+        When ceil is used, every element in the input tensor will be covered by a sliding window.
+
     layout : str
         Layout of the input.

@@ -186,7 +197,7 @@ def max_pool2d(
         padding = (padding, padding, padding, padding)

     return _ffi_api.max_pool2d(  # type: ignore
-        data, pool_size, strides, padding, dilation, layout, out_layout
+        data, pool_size, strides, padding, dilation, ceil_mode, layout, out_layout
     )


diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index 0fb860f8df..bae63e831f 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -33,7 +33,7 @@ namespace relax {
 TVM_REGISTER_NODE_TYPE(Conv2DAttrs);

 Expr conv2d(Expr data, Expr weight, Array strides, Array padding,
-            Array dilation, String data_layout, String kernel_layout,
+            Array dilation, int groups, String data_layout, String kernel_layout,
             Optional out_layout, DataType out_dtype) {
   padding = GetCompletePadding2D(std::move(padding));
   if (strides.size() == 1) {
@@ -43,13 +43,16 @@ Expr conv2d(Expr data, Expr weight, Array strides, Array pad
     dilation.push_back(dilation[0]);
   }
+  CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, "
+                         "the given number of groups is "
+                      << groups;
   CHECK_EQ(strides.size(), 2)
       << "The input strides length is expected to be 2. However, the given strides is " << strides;
   CHECK_EQ(dilation.size(), 2)
       << "The input dilation length is expected to be 2. However, the given dilation is "
       << dilation;

   return MakeConv(std::move(data), std::move(weight), std::move(strides),
-                  std::move(padding), std::move(dilation), data_layout,
+                  std::move(padding), std::move(dilation), groups, data_layout,
                   std::move(kernel_layout), out_layout.value_or(data_layout), out_dtype,
                   /*op_name=*/"relax.nn.conv2d");
 }
@@ -90,12 +93,25 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) {
   arith::Analyzer* analyzer = ctx->GetAnalyzer();
   PrimExpr input_channel_data = data_NCHW_shape[1];
   PrimExpr input_channel_kernel = weight_OIHW_shape[1];
-  if (analyzer->CanProve(input_channel_data != input_channel_kernel)) {
-    ctx->ReportFatal(Diagnostic::Error(call->span)
-                     << "The channel size of the data should equal to the input channel size of "
-                        "the weight. However, the data channel size is "
-                     << input_channel_data << " while the weight input channel size is "
-                     << input_channel_kernel);
+  if (analyzer->CanProve(input_channel_data != input_channel_kernel * attrs->groups)) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "The channel size of the data should equal the product of the input channel size of "
+           "the weight and the number of groups. However, the data channel size is "
+        << input_channel_data << " while the weight input channel size and number of groups are "
+        << input_channel_kernel << " and " << attrs->groups);
+  } else if (!analyzer->CanProveEqual(input_channel_data, input_channel_kernel * attrs->groups)) {
+    // Todo(relax-team): Trust the input shape at this moment, and revisit
+    // this condition with runtime shape check
+  }
+  if (analyzer->CanProve(floormod(weight_OIHW_shape[0], attrs->groups) != 0)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Conv2d expects the number of output channels to be divisible by the "
+                        "number of groups. However, the number of output channels is "
+                     << weight_OIHW_shape[0] << " while the number of groups is " << attrs->groups);
+  } else if (!analyzer->CanProveEqual(floormod(weight_OIHW_shape[0], attrs->groups), 0)) {
+    // Todo(relax-team): Trust the input shape at this moment, and revisit
+    // this condition with runtime shape check
   }

   PrimExpr input_h = data_NCHW_shape[2];
diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h
index 9c8342d7f8..19acd1bd67 100644
--- a/src/relax/op/nn/convolution.h
+++ b/src/relax/op/nn/convolution.h
@@ -37,12 +37,13 @@ namespace relax {

 template
 inline Expr MakeConv(Expr data, Expr weight, Array strides, Array padding,
-                     Array dilation, String data_layout, String kernel_layout,
+                     Array dilation, int groups, String data_layout, String kernel_layout,
                      String out_layout, DataType out_dtype, std::string op_name) {
   auto attrs = make_object();
   attrs->strides = std::move(strides);
   attrs->padding = std::move(padding);
   attrs->dilation = std::move(dilation);
+  attrs->groups = groups;
   attrs->data_layout = std::move(data_layout);
   attrs->kernel_layout = std::move(kernel_layout);
   attrs->out_layout = std::move(out_layout);
@@ -53,7 +54,7 @@ inline Expr MakeConv(Expr data, Expr weight, Array strides, Array

 /*! \brief 2D convolution */
 Expr conv2d(Expr data, Expr weight, Array strides, Array padding,
-            Array dilation, String data_layout, String kernel_layout,
+            Array dilation, int groups, String data_layout, String kernel_layout,
             Optional out_layout, DataType out_dtype);

 }  // namespace relax
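
(Not part of the patch.) A hypothetical usage sketch of the new groups parameter through the Python API above, mirroring the shapes used in the tests added below; it assumes the usual Relax test imports and that BlockBuilder.normalize populates struct_info the way those tests exercise it.

from tvm import relax
from tvm.script import relax as R

bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
w = relax.Var("w", R.Tensor((48, 16, 3, 3), "float32"))

# groups=8 splits the 128 input channels into 8 groups of 16, matching the
# weight's per-group input-channel size; the output keeps 48 channels.
call = bb.normalize(relax.op.nn.conv2d(x, w, groups=8))
print(call.struct_info)  # expected: R.Tensor((2, 48, 26, 26), "float32")
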
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index ee6d08a62e..a1ac14585b 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -29,7 +29,7 @@ namespace relax {
 TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs);

 Expr max_pool2d(Expr data, Array pool_size, Array strides,
-                Array padding, Array dilation, String layout,
+                Array padding, Array dilation, bool ceil_mode, String layout,
                 Optional out_layout) {
   padding = GetCompletePadding2D(std::move(padding));
   if (pool_size.size() == 1) {
@@ -56,6 +56,7 @@ Expr max_pool2d(Expr data, Array pool_size, Array strides,
   attrs->strides = std::move(strides);
   attrs->padding = std::move(padding);
   attrs->dilation = std::move(dilation);
+  attrs->ceil_mode = ceil_mode;
   attrs->layout = layout;
   attrs->out_layout = out_layout.value_or(layout);
   static const Op& op = Op::Get("relax.nn.max_pool2d");
@@ -95,10 +96,15 @@ StructInfo InferStructInfoMaxPool2D(const Call& call, const BlockBuilder& ctx) {
   out_NCHW_shape.resize(4);
   out_NCHW_shape[0] = data_NCHW_shape[0];
   out_NCHW_shape[1] = data_NCHW_shape[1];
-  out_NCHW_shape[2] = analyzer->Simplify(
-      (input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1) / attrs->strides[0] + 1);
-  out_NCHW_shape[3] = analyzer->Simplify(
-      (input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1) / attrs->strides[1] + 1);
+
+  PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1;
+  PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1;
+  if (attrs->ceil_mode) {
+    numerator_h += attrs->strides[0] - 1;
+    numerator_w += attrs->strides[1] - 1;
+  }
+  out_NCHW_shape[2] = analyzer->Simplify(numerator_h / attrs->strides[0] + 1);
+  out_NCHW_shape[3] = analyzer->Simplify(numerator_w / attrs->strides[1] + 1);

   Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
   return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h
index 770972a5fc..04f342b9f5 100644
--- a/src/relax/op/nn/pooling.h
+++ b/src/relax/op/nn/pooling.h
@@ -34,7 +34,7 @@ namespace relax {

 /*! \brief 2D maximum pooling operator. */
 Expr max_pool2d(Expr data, Array pool_size, Array strides,
-                Array padding, Array dilation, String layout,
+                Array padding, Array dilation, bool ceil_mode, String layout,
                 Optional out_layout);

 /*! \brief 2D adaptive average pooling operator. */
diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py
index d1cfc63f5a..25b144e28f 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -218,6 +218,89 @@ def test_conv2d_infer_struct_info_shape_var():
     )


+def test_conv2d_infer_struct_info_groups():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 8, 28, 28, 16), "float32"))
+    w0 = relax.Var("w", R.Tensor((48, 16, 3, 3), "float32"))
+    w1 = relax.Var("w", R.Tensor((48, 2, 3, 3, 8), "float32"))
+
+    _check_inference(
+        bb, relax.op.nn.conv2d(x0, w0, groups=8), relax.TensorStructInfo((2, 48, 26, 26), "float32")
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w1, kernel_layout="OIHW8i", groups=8),
+        relax.TensorStructInfo((2, 48, 26, 26), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x1, w0, data_layout="NCHW16c", groups=8),
+        relax.TensorStructInfo((2, 3, 26, 26, 16), "float32"),
+    )
+
+
+def test_conv2d_infer_struct_info_symbolic_groups():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    ic = tir.Var("c", "int64")
+    oc = tir.Var("oc", "int64")
+    x = relax.Var("x", R.Tensor((n, ic * 4, 28, 28), "float32"))
+    w0 = relax.Var("w", R.Tensor((oc * 4, ic, 3, 3), "float32"))
+    w1 = relax.Var("w", R.Tensor((oc, ic, 3, 3), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x, w0, groups=4),
+        relax.TensorStructInfo((n, oc * 4, 26, 26), "float32"),
+    )
+    _check_inference(
+        bb, relax.op.nn.conv2d(x, w1, groups=4), relax.TensorStructInfo((n, oc, 26, 26), "float32")
+    )
+
+
+def test_conv2d_infer_struct_info_input_channel_group_incompatible():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    ic = tir.Var("c", "int64")
+    oc = tir.Var("oc", "int64")
+    x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
+    w0 = relax.Var("w", R.Tensor((48, 20, 3, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, ic * 6, 28, 28), "float32"))
+    w1 = relax.Var("w", R.Tensor((oc, ic - 1, 3, 3), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x0, w0, groups=6))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x1, w1, groups=6))
+
+
+def test_conv2d_infer_struct_info_output_channel_group_incompatible():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    ic = tir.Var("c", "int64")
+    oc = tir.Var("oc", "int64")
+    x0 = relax.Var("x", R.Tensor((2, 120, 28, 28), "float32"))
+    w0 = relax.Var("w", R.Tensor((128, 20, 3, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, ic * 6, 28, 28), "float32"))
+    w1 = relax.Var("w", R.Tensor((oc * 6 + 4, ic * 6, 3, 3), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x0, w0, groups=6))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x1, w1, groups=6))
+
+
+def test_conv2d_non_positive_group():
+    x = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
+    w = relax.Var("w", R.Tensor((48, 16, 3, 3), "float32"))
+
+    with pytest.raises(TVMError):
+        relax.op.nn.conv2d(x, w, groups=0)
+    with pytest.raises(TVMError):
+        relax.op.nn.conv2d(x, w, groups=-2)
+
+
 def test_conv2d_infer_struct_info_more_input_dtype():
     bb = relax.BlockBuilder()
     x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16"))
diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py
index 658a174d8a..a70d8af95b 100644
--- a/tests/python/relax/test_op_nn_pooling.py
+++ b/tests/python/relax/test_op_nn_pooling.py
@@ -170,6 +170,68 @@ def test_max_pool2d_infer_struct_info_shape_var():
     )


+def test_max_pool2d_infer_struct_info_ceil_mode():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x, pool_size=3, strides=2, ceil_mode=True),
+        relax.TensorStructInfo((2, 3, 16, 16), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x, pool_size=(5, 3), strides=2, ceil_mode=True),
+        relax.TensorStructInfo((2, 3, 15, 16), "float32"),
+    )
+
+
+def test_max_pool2d_infer_struct_info_ceil_mode_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    ih = tir.Var("ih", "int64")
+    iw = tir.Var("iw", "int64")
+    kh = tir.Var("kh", "int64")
+    kw = tir.Var("kw", "int64")
+    stride_h = tir.Var("stride_h", "int64")
+    stride_w = tir.Var("stride_w", "int64")
+    padding_t = tir.Var("padding_t", "int64")
+    padding_l = tir.Var("padding_l", "int64")
+    padding_b = tir.Var("padding_b", "int64")
+    padding_r = tir.Var("padding_r", "int64")
+    dilation_h = tir.Var("dilation_h", "int64")
+    dilation_w = tir.Var("dilation_w", "int64")
+    x = relax.Var("x", R.Tensor((n, c, ih, iw), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(
+            x,
+            pool_size=(kh, kw),
+            strides=(stride_h, stride_w),
+            padding=(padding_t, padding_l, padding_b, padding_r),
+            dilation=(dilation_h, dilation_w),
+            ceil_mode=True,
+        ),
+        relax.TensorStructInfo(
+            (
+                n,
+                c,
+                tvm.tir.div(
+                    ih + padding_t + padding_b + stride_h - dilation_h * (kh - 1) - 2, stride_h
+                )
+                + 1,
+                tvm.tir.div(
+                    iw + padding_l + padding_r + stride_w - dilation_w * (kw - 1) - 2, stride_w
+                )
+                + 1,
+            ),
+            "float32",
+        ),
+    )
+
+
 def test_max_pool2d_infer_struct_info_more_input_dtype():
     bb = relax.BlockBuilder()
     x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16"))