This repository has been archived by the owner on May 22, 2023. It is now read-only.

[Op] Group for conv2d, ceil mode for max_pool2d (#358)
This PR adds the parameter groups for conv2d, as well as the option ceil_mode for max_pool2d. They were omitted in the previous round of PRs mainly for simplicity, and are introduced now because of their use in ML frameworks.
MasterJH5574 authored Jan 14, 2023
1 parent c041842 commit b380aed
Showing 8 changed files with 204 additions and 17 deletions.
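
The new parameters are exposed through the Relax Python frontend changed below. A minimal usage sketch, assuming the shapes used by the new tests in this commit (variable names are illustrative only, not part of the diff):

from tvm import relax
from tvm.script import relax as R

x = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
w = relax.Var("w", R.Tensor((48, 16, 3, 3), "float32"))  # OIHW; 16 = 128 / groups

# Grouped convolution: 128 input channels split into 8 groups of 16 channels each.
# The new conv test below asserts the inferred struct info is (2, 48, 26, 26).
y = relax.op.nn.conv2d(x, w, groups=8)

# Ceil-mode pooling: the output extent is rounded up instead of down.
z = relax.op.nn.max_pool2d(y, pool_size=3, strides=2, ceil_mode=True)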
8 changes: 8 additions & 0 deletions include/tvm/relax/attrs/nn.h
@@ -34,6 +34,7 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
Array<PrimExpr> strides;
Array<PrimExpr> padding;
Array<PrimExpr> dilation;
int groups;
String data_layout;
String kernel_layout;
String out_layout;
@@ -49,6 +50,9 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
"four int : padding width in the order of (top, left, bottom, right)");
TVM_ATTR_FIELD(dilation).describe(
"Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).describe(
"Number of groups to split the input into for grouped convolution. The number of input and "
"output channels should be divisible by the number of groups.");
TVM_ATTR_FIELD(data_layout)
.describe(
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
@@ -76,6 +80,7 @@ struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
Array<PrimExpr> strides;
Array<PrimExpr> padding;
Array<PrimExpr> dilation;
bool ceil_mode;
String layout;
String out_layout;

@@ -89,6 +94,9 @@ struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
"one int : same padding used on all sides"
"two int : bottom, right will use same padding as top, left"
"four int : padding width in the order of (top, left, bottom, right)");
TVM_ATTR_FIELD(ceil_mode).describe(
"A boolean indicating whether to use ceil or floor to compute the output shape. By using ceil, "
"every element in the input tensor will be covered by a sliding window.");
TVM_ATTR_FIELD(layout).describe(
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
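
The constraints stated in these attribute descriptions can be summarized with a small shape check. A hedged sketch (the helper name is made up for illustration and is not part of this commit):

def check_grouped_conv2d_shapes(in_channels, out_channels, weight_in_channels, groups):
    # groups must be positive (mirrors the CHECK_GT in convolution.cc below).
    assert groups > 0
    # The data channel count must equal the weight's per-group input channels times groups.
    assert in_channels == weight_in_channels * groups
    # The number of output channels must be divisible by the number of groups.
    assert out_channels % groups == 0

# Example matching the new conv2d test: data NCHW with C=128, weight OIHW (48, 16, 3, 3), groups=8.
check_grouped_conv2d_shapes(in_channels=128, out_channels=48, weight_in_channels=16, groups=8)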
13 changes: 12 additions & 1 deletion python/tvm/relax/op/nn/nn.py
@@ -32,6 +32,7 @@ def conv2d(
strides: Union[PrimExprLike, Tuple[PrimExprLike]] = (1, 1),
padding: Union[PrimExprLike, Tuple[PrimExprLike]] = (0, 0),
dilation: Union[PrimExprLike, Tuple[PrimExprLike]] = (1, 1),
groups: int = 1,
data_layout: str = "NCHW",
kernel_layout: str = "OIHW",
out_layout: Optional[str] = None,
@@ -81,6 +82,10 @@ def conv2d(
Specifies the dilation rate to be used for dilated convolution.
It is required to have length either 1 or 2.
groups : int
Number of groups to split the input into for grouped convolution.
The number of input and output channels should be divisible by the number of groups.
data_layout : str
Layout of the input.
@@ -111,6 +116,7 @@
strides,
padding,
dilation,
groups,
data_layout,
kernel_layout,
out_layout,
@@ -124,6 +130,7 @@ def max_pool2d(
strides: Union[PrimExprLike, Tuple[PrimExprLike]] = (1, 1),
padding: Union[PrimExprLike, Tuple[PrimExprLike]] = (0, 0),
dilation: Union[PrimExprLike, Tuple[PrimExprLike]] = (1, 1),
ceil_mode: bool = False,
layout: str = "NCHW",
out_layout: Optional[str] = None,
) -> Expr:
@@ -165,6 +172,10 @@ def max_pool2d(
dilation : Union[PrimExprLike, Tuple[PrimExprLike]]
The dilation of pooling. It is required to have length either 1 or 2.
ceil_mode : bool
A boolean indicating whether to use ceil or floor to compute the output shape.
By using ceil, every element in the input tensor will be covered by a sliding window.
layout : str
Layout of the input.
@@ -186,7 +197,7 @@
padding = (padding, padding, padding, padding)

return _ffi_api.max_pool2d( # type: ignore
data, pool_size, strides, padding, dilation, layout, out_layout
data, pool_size, strides, padding, dilation, ceil_mode, layout, out_layout
)


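
Both operators normalize scalar arguments before reaching the FFI: a scalar padding becomes the four-sided (top, left, bottom, right) form shown in the hunk above, and the C++ entry points below complete length-1 strides, dilation, and pool_size to length 2. A rough, illustrative rendering of that normalization (not the library's actual helper):

def normalize_max_pool2d_args(pool_size, strides, padding, dilation):
    # Length-1 values are duplicated for the H and W dimensions.
    if isinstance(pool_size, int):
        pool_size = (pool_size, pool_size)
    if isinstance(strides, int):
        strides = (strides, strides)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)
    # A scalar padding is applied to all four sides: (top, left, bottom, right).
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding)
    return pool_size, strides, padding, dilation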
32 changes: 24 additions & 8 deletions src/relax/op/nn/convolution.cc
@@ -33,7 +33,7 @@ namespace relax {
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);

Expr conv2d(Expr data, Expr weight, Array<PrimExpr> strides, Array<PrimExpr> padding,
Array<PrimExpr> dilation, String data_layout, String kernel_layout,
Array<PrimExpr> dilation, int groups, String data_layout, String kernel_layout,
Optional<String> out_layout, DataType out_dtype) {
padding = GetCompletePadding2D(std::move(padding));
if (strides.size() == 1) {
@@ -43,13 +43,16 @@ Expr conv2d(Expr data, Expr weight, Array<PrimExpr> strides, Array<PrimExpr> pad
dilation.push_back(dilation[0]);
}

CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, "
"the given number of groups is "
<< groups;
CHECK_EQ(strides.size(), 2)
<< "The input strides length is expected to be 2. However, the given strides is " << strides;
CHECK_EQ(dilation.size(), 2)
<< "The input dilation length is expected to be 2. However, the given dilation is "
<< dilation;
return MakeConv<Conv2DAttrs>(std::move(data), std::move(weight), std::move(strides),
std::move(padding), std::move(dilation), data_layout,
std::move(padding), std::move(dilation), groups, data_layout,
std::move(kernel_layout), out_layout.value_or(data_layout),
out_dtype, /*op_name=*/"relax.nn.conv2d");
}
@@ -90,12 +93,25 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) {
arith::Analyzer* analyzer = ctx->GetAnalyzer();
PrimExpr input_channel_data = data_NCHW_shape[1];
PrimExpr input_channel_kernel = weight_OIHW_shape[1];
if (analyzer->CanProve(input_channel_data != input_channel_kernel)) {
ctx->ReportFatal(Diagnostic::Error(call->span)
<< "The channel size of the data should equal to the input channel size of "
"the weight. However, the data channel size is "
<< input_channel_data << " while the weight input channel size is "
<< input_channel_kernel);
if (analyzer->CanProve(input_channel_data != input_channel_kernel * attrs->groups)) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "The channel size of the data should equal to the product of input channel size of the "
"weight and the number of groups. However, the data channel size is "
<< input_channel_data << " while the weight input channel size and number of groups are "
<< input_channel_kernel << " and " << attrs->groups);
} else if (!analyzer->CanProveEqual(input_channel_data, input_channel_kernel * attrs->groups)) {
// Todo(relax-team): Trust the input shape at this moment, and revisit
// this condition with runtime shape check
}
if (analyzer->CanProve(floormod(weight_OIHW_shape[0], attrs->groups) != 0)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Conv2d expects the number of output channels to be divisible by the "
"number of groups. However, the number of output channels is "
<< weight_OIHW_shape[0] << " while the number of groups is " << attrs->groups);
} else if (!analyzer->CanProveEqual(floormod(weight_OIHW_shape[0], attrs->groups), 0)) {
// Todo(relax-team): Trust the input shape at this moment, and revisit
// this condition with runtime shape check
}

PrimExpr input_h = data_NCHW_shape[2];
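
With the relaxed channel check above, the inferred output shape for a grouped conv2d still follows the usual formula. A numeric sketch (illustration only, not library code) using the same shapes as test_conv2d_infer_struct_info_groups below:

# Shapes from the new grouped-conv test; default strides, padding, and dilation.
data_shape = (2, 128, 28, 28)     # NCHW
weight_shape = (48, 16, 3, 3)     # OIHW, per-group input channels = 16
groups = 8

n, c, h, w = data_shape
o, i, kh, kw = weight_shape
assert c == i * groups            # 128 == 16 * 8
assert o % groups == 0            # 48 % 8 == 0

out_h = (h - kh) // 1 + 1         # 26 (stride 1, no padding, dilation 1)
out_w = (w - kw) // 1 + 1         # 26
out_shape = (n, o, out_h, out_w)  # (2, 48, 26, 26), matching the test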
5 changes: 3 additions & 2 deletions src/relax/op/nn/convolution.h
@@ -37,12 +37,13 @@ namespace relax {

template <typename T>
inline Expr MakeConv(Expr data, Expr weight, Array<PrimExpr> strides, Array<PrimExpr> padding,
Array<PrimExpr> dilation, String data_layout, String kernel_layout,
Array<PrimExpr> dilation, int groups, String data_layout, String kernel_layout,
String out_layout, DataType out_dtype, std::string op_name) {
auto attrs = make_object<T>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
@@ -53,7 +54,7 @@ inline Expr MakeConv(Expr data, Expr weight, Array<PrimExpr> strides, Array<Prim

/*! \brief 2D convolution */
Expr conv2d(Expr data, Expr weight, Array<PrimExpr> strides, Array<PrimExpr> padding,
Array<PrimExpr> dilation, String data_layout, String kernel_layout,
Array<PrimExpr> dilation, int groups, String data_layout, String kernel_layout,
Optional<String> out_layout, DataType out_dtype);

} // namespace relax
16 changes: 11 additions & 5 deletions src/relax/op/nn/pooling.cc
@@ -29,7 +29,7 @@ namespace relax {
TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs);

Expr max_pool2d(Expr data, Array<PrimExpr> pool_size, Array<PrimExpr> strides,
Array<PrimExpr> padding, Array<PrimExpr> dilation, String layout,
Array<PrimExpr> padding, Array<PrimExpr> dilation, bool ceil_mode, String layout,
Optional<String> out_layout) {
padding = GetCompletePadding2D(std::move(padding));
if (pool_size.size() == 1) {
@@ -56,6 +56,7 @@ Expr max_pool2d(Expr data, Array<PrimExpr> pool_size, Array<PrimExpr> strides,
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->ceil_mode = ceil_mode;
attrs->layout = layout;
attrs->out_layout = out_layout.value_or(layout);
static const Op& op = Op::Get("relax.nn.max_pool2d");
@@ -95,10 +96,15 @@ StructInfo InferStructInfoMaxPool2D(const Call& call, const BlockBuilder& ctx) {
out_NCHW_shape.resize(4);
out_NCHW_shape[0] = data_NCHW_shape[0];
out_NCHW_shape[1] = data_NCHW_shape[1];
out_NCHW_shape[2] = analyzer->Simplify(
(input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1) / attrs->strides[0] + 1);
out_NCHW_shape[3] = analyzer->Simplify(
(input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1) / attrs->strides[1] + 1);

PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1;
PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1;
if (attrs->ceil_mode) {
numerator_h += attrs->strides[0] - 1;
numerator_w += attrs->strides[1] - 1;
}
out_NCHW_shape[2] = analyzer->Simplify(numerator_h / attrs->strides[0] + 1);
out_NCHW_shape[3] = analyzer->Simplify(numerator_w / attrs->strides[1] + 1);

Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
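
The ceil behaviour is implemented by adding stride - 1 to the numerator before the floor division, which is equivalent to taking the ceiling of the quotient. A quick numeric check (illustrative only) against the 32x32 case in the new pooling test:

# 32x32 input, pool_size=3, strides=2, no padding, dilation=1.
ih, k, s, pad, dil = 32, 3, 2, 0, 1

numerator = ih + pad - dil * (k - 1) - 1   # 29
floor_out = numerator // s + 1             # 15, ceil_mode=False
ceil_out = (numerator + s - 1) // s + 1    # 16, ceil_mode=True

assert (floor_out, ceil_out) == (15, 16)   # the test below expects (2, 3, 16, 16)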
2 changes: 1 addition & 1 deletion src/relax/op/nn/pooling.h
@@ -34,7 +34,7 @@ namespace relax {

/*! \brief 2D maximum pooling operator. */
Expr max_pool2d(Expr data, Array<PrimExpr> pool_size, Array<PrimExpr> strides,
Array<PrimExpr> padding, Array<PrimExpr> dilation, String layout,
Array<PrimExpr> padding, Array<PrimExpr> dilation, bool ceil_mode, String layout,
Optional<String> out_layout);

/*! \brief 2D adaptive average pooling operator. */
83 changes: 83 additions & 0 deletions tests/python/relax/test_op_nn_convolution.py
@@ -218,6 +218,89 @@ def test_conv2d_infer_struct_info_shape_var():
)


def test_conv2d_infer_struct_info_groups():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
x1 = relax.Var("x", R.Tensor((2, 8, 28, 28, 16), "float32"))
w0 = relax.Var("w", R.Tensor((48, 16, 3, 3), "float32"))
w1 = relax.Var("w", R.Tensor((48, 2, 3, 3, 8), "float32"))

_check_inference(
bb, relax.op.nn.conv2d(x0, w0, groups=8), relax.TensorStructInfo((2, 48, 26, 26), "float32")
)
_check_inference(
bb,
relax.op.nn.conv2d(x0, w1, kernel_layout="OIHW8i", groups=8),
relax.TensorStructInfo((2, 48, 26, 26), "float32"),
)
_check_inference(
bb,
relax.op.nn.conv2d(x1, w0, data_layout="NCHW16c", groups=8),
relax.TensorStructInfo((2, 3, 26, 26, 16), "float32"),
)


def test_conv2d_infer_struct_info_symbolic_groups():
bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
ic = tir.Var("c", "int64")
oc = tir.Var("oc", "int64")
x = relax.Var("x", R.Tensor((n, ic * 4, 28, 28), "float32"))
w0 = relax.Var("w", R.Tensor((oc * 4, ic, 3, 3), "float32"))
w1 = relax.Var("w", R.Tensor((oc, ic, 3, 3), "float32"))

_check_inference(
bb,
relax.op.nn.conv2d(x, w0, groups=4),
relax.TensorStructInfo((n, oc * 4, 26, 26), "float32"),
)
_check_inference(
bb, relax.op.nn.conv2d(x, w1, groups=4), relax.TensorStructInfo((n, oc, 26, 26), "float32")
)


def test_conv2d_infer_struct_info_input_channel_group_incompatible():
bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
ic = tir.Var("c", "int64")
oc = tir.Var("oc", "int64")
x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
w0 = relax.Var("w", R.Tensor((48, 20, 3, 3), "float32"))
x1 = relax.Var("x", R.Tensor((n, ic * 6, 28, 28), "float32"))
w1 = relax.Var("w", R.Tensor((oc, ic - 1, 3, 3), "float32"))

with pytest.raises(TVMError):
bb.normalize(relax.op.nn.conv2d(x0, w0, groups=6))
with pytest.raises(TVMError):
bb.normalize(relax.op.nn.conv2d(x1, w1, groups=6))


def test_conv2d_infer_struct_info_output_channel_group_incompatible():
bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
ic = tir.Var("c", "int64")
oc = tir.Var("oc", "int64")
x0 = relax.Var("x", R.Tensor((2, 120, 28, 28), "float32"))
w0 = relax.Var("w", R.Tensor((128, 20, 3, 3), "float32"))
x1 = relax.Var("x", R.Tensor((n, ic * 6, 28, 28), "float32"))
w1 = relax.Var("w", R.Tensor((oc * 6 + 4, ic * 6, 3, 3), "float32"))

with pytest.raises(TVMError):
bb.normalize(relax.op.nn.conv2d(x0, w0, groups=6))
with pytest.raises(TVMError):
bb.normalize(relax.op.nn.conv2d(x1, w1, groups=6))


def test_conv2d_non_positive_group():
x = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
w = relax.Var("w", R.Tensor((48, 16, 3, 3), "float32"))

with pytest.raises(TVMError):
relax.op.nn.conv2d(x, w, groups=0)
with pytest.raises(TVMError):
relax.op.nn.conv2d(x, w, groups=-2)


def test_conv2d_infer_struct_info_more_input_dtype():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16"))
62 changes: 62 additions & 0 deletions tests/python/relax/test_op_nn_pooling.py
@@ -170,6 +170,68 @@ def test_max_pool2d_infer_struct_info_shape_var():
)


def test_max_pool2d_infer_struct_info_ceil_mode():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))

_check_inference(
bb,
relax.op.nn.max_pool2d(x, pool_size=3, strides=2, ceil_mode=True),
relax.TensorStructInfo((2, 3, 16, 16), "float32"),
)
_check_inference(
bb,
relax.op.nn.max_pool2d(x, pool_size=(5, 3), strides=2, ceil_mode=True),
relax.TensorStructInfo((2, 3, 15, 16), "float32"),
)


def test_max_pool2d_infer_struct_info_ceil_mode_symbolic():
bb = relax.BlockBuilder()
n = tir.Var("n", "int64")
c = tir.Var("c", "int64")
ih = tir.Var("ih", "int64")
iw = tir.Var("iw", "int64")
kh = tir.Var("kh", "int64")
kw = tir.Var("kw", "int64")
stride_h = tir.Var("stride_h", "int64")
stride_w = tir.Var("stride_w", "int64")
padding_t = tir.Var("padding_t", "int64")
padding_l = tir.Var("padding_l", "int64")
padding_b = tir.Var("padding_b", "int64")
padding_r = tir.Var("padding_r", "int64")
dilation_h = tir.Var("dilation_h", "int64")
dilation_w = tir.Var("dilation_w", "int64")
x = relax.Var("x", R.Tensor((n, c, ih, iw), "float32"))

_check_inference(
bb,
relax.op.nn.max_pool2d(
x,
pool_size=(kh, kw),
strides=(stride_h, stride_w),
padding=(padding_t, padding_l, padding_b, padding_r),
dilation=(dilation_h, dilation_w),
ceil_mode=True,
),
relax.TensorStructInfo(
(
n,
c,
tvm.tir.div(
ih + padding_t + padding_b + stride_h - dilation_h * (kh - 1) - 2, stride_h
)
+ 1,
tvm.tir.div(
iw + padding_l + padding_r + stride_w - dilation_w * (kw - 1) - 2, stride_w
)
+ 1,
),
"float32",
),
)


def test_max_pool2d_infer_struct_info_more_input_dtype():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16"))
