Commit

tests: benchdnn: graph: add tests for groupnorm
rongzha1 authored and TaoLv committed Jul 10, 2024
1 parent a593d00 commit 511d36a
Showing 17 changed files with 1,069 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/benchdnn/graph/flex_rewrite.cpp
@@ -481,6 +481,43 @@ void flex_rewrite::infer_output_shape(
merge_ncx(data_format, gi[out0], n, c, x);
break;
// infer_norm_output_shape
case dnnl::graph::op::kind::GroupNorm:
// infer shape for dst.
in0 = aop.in_lts_[0].id_;
out0 = aop.out_lts_[0].id_;
gi[out0] = gi[in0];
// attr `keep_stats` is optional, default is `true`
if (aop.attrs_.find("keep_stats") == aop.attrs_.end()
|| aop.attrs_["keep_stats"].bool_value_) {
// `true` means there are 3 outputs: dst, mean and var.
// need to infer shape for mean and var
if (aop.out_lts_.size() == 3) {
int64_t groups = 0;
if (aop.attrs_.find("groups") == aop.attrs_.end()) {
fprintf(stderr,
"graph: groups is required for "
"GroupNorm!\n");
SAFE_V(FAIL);
} else {
groups = aop.attrs_["groups"].s64_value_;
}
size_t out1 = aop.out_lts_[1].id_;
size_t out2 = aop.out_lts_[2].id_;
gi[out1].clear();
gi[out2].clear();
// mean/var shape is {N, groups}
std::vector<int64_t> mv_shape = {gi[in0][0], groups};
gi[out1] = mv_shape;
gi[out2] = mv_shape;
} else {
fprintf(stderr,
"graph: GroupNorm output number "
"mismatch!\n");
SAFE_V(FAIL);
}
}
break;
// infer_norm_output_shape
case dnnl::graph::op::kind::LayerNorm:
in0 = aop.in_lts_[0].id_;
out0 = aop.out_lts_[0].id_;
@@ -1040,6 +1077,7 @@ void flex_rewrite::update_output_info(
case dnnl::graph::op::kind::Exp:
case dnnl::graph::op::kind::GELU:
case dnnl::graph::op::kind::GELUBackward:
case dnnl::graph::op::kind::GroupNorm:
case dnnl::graph::op::kind::HardSigmoid:
case dnnl::graph::op::kind::HardSigmoidBackward:
case dnnl::graph::op::kind::HardSwish:
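As a worked example of the inference rule above: for a GroupNorm op with src shape {2, 320, 48, 48}, groups = 32, and keep_stats = true, the mean and variance outputs get shape {2, 32}, since statistics are kept per batch item and per group rather than per channel. A minimal standalone sketch of the same rule (the helper name is illustrative, not part of benchdnn):

#include <cstdint>
#include <vector>

// Illustrative helper mirroring the GroupNorm branch above: mean/var take
// shape {N, groups}, independent of the spatial dimensions of src.
std::vector<int64_t> infer_gnorm_stats_shape(
        const std::vector<int64_t> &src_shape_ncx, int64_t groups) {
    return {src_shape_ncx[0], groups};
}

// infer_gnorm_stats_shape({2, 320, 48, 48}, 32) -> {2, 32}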
1 change: 1 addition & 0 deletions tests/benchdnn/graph/ref_primitive.cpp
@@ -60,6 +60,7 @@ ref_primitive_t::ref_primitive_t(const deserialized_op &op) {
CASE_DRIVER(conv); \
CASE_DRIVER(deconv); \
CASE_DRIVER(eltwise); \
CASE_DRIVER(gnorm); \
CASE_DRIVER(lnorm); \
CASE_DRIVER(matmul); \
CASE_DRIVER(pool); \
122 changes: 122 additions & 0 deletions tests/benchdnn/graph/setting_handler.cpp
@@ -1010,6 +1010,128 @@ ::eltwise::settings_t get_setting(

} // namespace eltwise

namespace gnorm {

bool get_gnorm_desc(const deserialized_op &base_op_ref, ::gnorm::desc_t &d) {
auto src_dims = base_op_ref.in_lts_[0].shape_;
if (base_op_ref.has_NXC_format()) {
src_dims = base_op_ref.get_NCX_shape(0, true);
}
d.ndims = static_cast<int>(src_dims.size());

base_op_ref.get_attr_s64(d.g, "groups");
d.mb = src_dims[0];
d.ic = src_dims[1];
d.id = d.ndims >= 5 ? src_dims[d.ndims - 3] : 1;
d.ih = d.ndims >= 4 ? src_dims[d.ndims - 2] : 1;
d.iw = d.ndims >= 3 ? src_dims[d.ndims - 1] : 1;

d.eps = 1e-5f;
base_op_ref.get_attr_f32(d.eps, "eps");

return true;
}

bool get_gnorm_dir(const deserialized_op &base_op_ref, dir_t &dir) {
const auto &op_kind = base_op_ref.kind_;
if (op_kind == "GroupNorm") {
bool keep_stats = false;

base_op_ref.get_attr_bool(keep_stats, "keep_stats");

const size_t out_size = base_op_ref.out_lts_.size();
// output: dst, mean(opt), var(opt)
if (out_size == 1) {
dir = dir_t::FWD_I;
if (keep_stats) return false;
} else if (out_size == 3) {
dir = dir_t::FWD_D;
if (!keep_stats) return false;
} else {
return false;
}
// TODO: GroupNormBackward
} else if (op_kind == "GroupNormBackward") {
assert(!"GroupNormBackward is not supported for now");
} else {
assert(!"unsupported op_kind");
return false;
}
return true;
}

bool get_gnorm_dt(
const deserialized_op &base_op_ref, std::vector<dnnl_data_type_t> &dt) {
auto src_dt = convert_dt(base_op_ref.in_lts_[0].get_data_type());
auto dst_dt = convert_dt(base_op_ref.out_lts_[0].get_data_type());
dt = {src_dt, dst_dt};
return true;
}

bool get_gnorm_flags(
const deserialized_op &base_op_ref, ::bnorm::flags_t &flags) {
bool use_affine = false;
base_op_ref.get_attr_bool(use_affine, "use_affine");
const auto &op_kind = base_op_ref.kind_;
const size_t in_size = base_op_ref.in_lts_.size();
if (op_kind == "GroupNorm") {
// input: src, gamma(opt), beta(opt)
if (use_affine) {
if (in_size == 3) {
flags = ::gnorm::USE_SCALE | ::gnorm::USE_SHIFT;
} else {
return false;
}
} else {
if (in_size == 1) {
flags = ::gnorm::NONE;
} else {
return false;
}
}
// TODO: add GroupNormBackward
} else if (op_kind == "GroupNormBackward") {
assert(!"GroupNormBackward is not supported for now");
return false;
} else {
assert(!"unsupported op_kind");
return false;
}
return true;
}

bool get_gnorm_stag_and_dtag(const deserialized_op &base_op_ref,
std::vector<std::vector<std::string>> &tag) {
// src and dst may have different tags.
std::string stag, dtag;
if (!get_driver_tag_by_idx(base_op_ref, dtag, 0, true)
|| !get_driver_tag_by_idx(base_op_ref, stag, 0, false)) {
return false;
}
assert(!stag.empty() && !dtag.empty());
tag = {{std::move(stag), std::move(dtag)}};
return true;
}

::gnorm::settings_t get_setting(
const deserialized_op &base_op_ref, res_t *res) {
::gnorm::settings_t op_setting;
DNN_GRAPH_CHECK_SETTINGS(get_gnorm_desc(base_op_ref, op_setting.desc), res);
DNN_GRAPH_CHECK_SETTINGS(
gnorm::get_gnorm_dir(base_op_ref, op_setting.dir.front()), res);
DNN_GRAPH_CHECK_SETTINGS(
gnorm::get_gnorm_dt(base_op_ref, op_setting.dt.front()), res);
DNN_GRAPH_CHECK_SETTINGS(
gnorm::get_gnorm_stag_and_dtag(base_op_ref, op_setting.tag), res);
DNN_GRAPH_CHECK_SETTINGS(
gnorm::get_gnorm_flags(base_op_ref, op_setting.flags.front()), res);
DNN_GRAPH_CHECK_SETTINGS(
get_graph_attr(base_op_ref, op_setting.fpmath_mode.front()), res);
return op_setting;
}

} // namespace gnorm

namespace lnorm {

bool get_lnorm_dir(const deserialized_op &base_op_ref, dir_t &dir) {
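Reading the new gnorm helpers against the bf16 test case added later in this commit (src shape {2, 320, 48, 48}, groups = 32, use_affine = true, keep_stats = false), one would expect get_gnorm_desc to fill in ndims = 4, mb = 2, ic = 320, g = 32, id = 1, ih = 48, iw = 48, and the default eps = 1e-5; get_gnorm_dir to pick FWD_I, since the op has a single output and stats are not kept; and get_gnorm_flags to set USE_SCALE | USE_SHIFT, since gamma and beta are supplied as the second and third inputs. These values are read off the code above rather than produced by the tool.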
3 changes: 3 additions & 0 deletions tests/benchdnn/graph/setting_handler.hpp
@@ -26,6 +26,7 @@
#include "custom_driver.hpp"
#include "deconv/deconv.hpp"
#include "eltwise/eltwise.hpp"
#include "gnorm/gnorm.hpp"
#include "lnorm/lnorm.hpp"
#include "matmul/matmul.hpp"
#include "pool/pool.hpp"
@@ -50,6 +51,7 @@ DECLARE_GET_SETTING(conv);
DECLARE_GET_SETTING(custom);
DECLARE_GET_SETTING(deconv);
DECLARE_GET_SETTING(eltwise);
DECLARE_GET_SETTING(gnorm);
DECLARE_GET_SETTING(lnorm);
DECLARE_GET_SETTING(matmul);
DECLARE_GET_SETTING(pool);
@@ -90,6 +92,7 @@ DECLARE_TEMPLATE_GET_SETTING(conv);
DECLARE_TEMPLATE_GET_SETTING(custom);
DECLARE_TEMPLATE_GET_SETTING(deconv);
DECLARE_TEMPLATE_GET_SETTING(eltwise);
DECLARE_TEMPLATE_GET_SETTING(gnorm);
DECLARE_TEMPLATE_GET_SETTING(lnorm);
DECLARE_TEMPLATE_GET_SETTING(matmul);
DECLARE_TEMPLATE_GET_SETTING(pool);
31 changes: 31 additions & 0 deletions tests/benchdnn/graph/utils.cpp
@@ -370,6 +370,7 @@ dnnl::graph::op::kind opstr2kind(const std::string &kind) {
{"Exp", dnnl::graph::op::kind::Exp},
{"GELU", dnnl::graph::op::kind::GELU},
{"GELUBackward", dnnl::graph::op::kind::GELUBackward},
{"GroupNorm", dnnl::graph::op::kind::GroupNorm},
{"HardSigmoid", dnnl::graph::op::kind::HardSigmoid},
{"HardSigmoidBackward", dnnl::graph::op::kind::HardSigmoidBackward},
{"HardSwish", dnnl::graph::op::kind::HardSwish},
@@ -561,6 +562,7 @@ dnnl_driver_t opkind2driver(const dnnl::graph::op::kind &kind) {
{dnnl::graph::op::kind::GELU, dnnl_driver_t::eltwise},
{dnnl::graph::op::kind::GELUBackward,
dnnl_driver_t::eltwise},
{dnnl::graph::op::kind::GroupNorm, dnnl_driver_t::gnorm},
{dnnl::graph::op::kind::HardSigmoid,
dnnl_driver_t::eltwise},
{dnnl::graph::op::kind::HardSigmoidBackward,
@@ -885,6 +887,21 @@ int get_prim_arg_name_from_graph_op_output_offset(
return -1;
}
} break;
case dnnl::graph::op::kind::GroupNorm: {
if (output_offset == 0)
return DNNL_ARG_DST;
else if (output_offset == 1)
return DNNL_ARG_MEAN;
else if (output_offset == 2)
return DNNL_ARG_VARIANCE;
else {
BENCHDNN_PRINT(0, "Error: no matching ARG for offset %d",
static_cast<int>(output_offset));
assert(false);
return -1;
}

} break;
default: {
return DNNL_ARG_DST;
} break;
@@ -1230,6 +1247,20 @@ int get_prim_arg_name_from_graph_op_input_offset(
return -1;
}
} break;
case dnnl::graph::op::kind::GroupNorm: {
if (input_offset == 0)
return DNNL_ARG_SRC;
else if (input_offset == 1)
return DNNL_ARG_SCALE;
else if (input_offset == 2)
return DNNL_ARG_SHIFT;
else {
BENCHDNN_PRINT(0, "Error: no matching ARG for offset %zu",
input_offset);
assert(false);
return -1;
}
} break;
default: {
return DNNL_ARG_SRC;
} break;
1 change: 1 addition & 0 deletions tests/benchdnn/graph/utils.hpp
@@ -157,6 +157,7 @@ enum class dnnl_driver_t {
reorder,
resampling,
softmax,
gnorm,
others
};

106 changes: 106 additions & 0 deletions tests/benchdnn/inputs/graph/op/bf16/gnorm.json
@@ -0,0 +1,106 @@
{
"version": "3.6.0",
"engine_kind": "cpu",
"fpmath_mode": "strict",
"input_ports": [
0,
2,
3
],
"output_ports": [
6
],
"graph": [
{
"id": 6138621056,
"name": "aten::group_norm",
"kind": "GroupNorm",
"attrs": {
"data_format": {
"type": "string",
"value": "NCX"
},
"use_affine": {
"type": "bool",
"value": 1
},
"keep_stats": {
"type": "bool",
"value": 0
},
"epsilon": {
"type": "f32",
"value": 1e-05
},
"groups": {
"type": "s64",
"value": 32
}
},
"inputs": [
{
"id": 0,
"dtype": "bf16",
"shape": [
2,
320,
48,
48
],
"stride": [
737280,
1,
15360,
320
],
"layout_type": "strided",
"property_type": "variable"
},
{
"id": 2,
"dtype": "f32",
"shape": [
320
],
"stride": [
1
],
"layout_type": "strided",
"property_type": "constant"
},
{
"id": 3,
"dtype": "f32",
"shape": [
320
],
"stride": [
1
],
"layout_type": "strided",
"property_type": "constant"
}
],
"outputs": [
{
"id": 6,
"dtype": "bf16",
"shape": [
2,
320,
48,
48
],
"stride": [
737280,
1,
15360,
320
],
"layout_type": "strided",
"property_type": "variable"
}
]
}
]
}
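A note on the test data: for the logical shape {2, 320, 48, 48}, the stride arrays above correspond to a dense channels-last memory layout, while the data_format attribute stays NCX. A small illustrative check of that arithmetic (not benchdnn code):

#include <cstdint>
#include <vector>

// Dense channels-last strides for a logical NCHW shape: memory order is
// N, H, W, C, so the strides per logical dim {N, C, H, W} are {H*W*C, 1, W*C, C}.
std::vector<int64_t> dense_channels_last_strides(const std::vector<int64_t> &nchw) {
    const int64_t C = nchw[1], H = nchw[2], W = nchw[3];
    return {H * W * C, 1, W * C, C};
}

// dense_channels_last_strides({2, 320, 48, 48}) == {737280, 1, 15360, 320},
// matching the "stride" fields in the JSON above.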
