Skip to content

Commit

Permalink
gpu: ocl: enable blocked int4 reorder
Browse files Browse the repository at this point in the history
  • Loading branch information
kealan-barbieri committed May 3, 2024
1 parent 64f06d6 commit 54748c0
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 25 deletions.
4 changes: 4 additions & 0 deletions src/gpu/intel/ocl/ref_reorder.cl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ __kernel void ref_reorder(__global SRC_DATA_T *restrict src,
int pad_d4 = NDIMS > 4 && d4 >= SRC_D4;
int pad_d5 = NDIMS > 5 && d5 >= SRC_D5;
if (pad_d0 || pad_d1 || pad_d2 || pad_d3 || pad_d4 || pad_d5) {
#if TO_I4
SET_DOUBLE_HALF_BYTE(dst, dst_off, 0);
#else
dst[dst_off] = 0;
#endif
continue;
}
#endif
Expand Down
19 changes: 8 additions & 11 deletions src/gpu/intel/ocl/ref_reorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@ status_t ref_reorder_t::pd_t::init_conf(engine_t *engine) {

status_t status = status::success;

if (!IMPLICATION(
utils::one_of(src_mdw.data_type(), data_type::s4, data_type::u4)
|| utils::one_of(dst_mdw.data_type(), data_type::s4,
data_type::u4),
dst_mdw.is_plain() && src_mdw.is_plain()))
return status::unimplemented;

const auto &padded_dims = dst_mdw.padded_dims();
conf.src_quant = {attr(), src_mdw, DNNL_ARG_SRC};
conf.dst_quant = {attr(), dst_mdw, DNNL_ARG_DST};
Expand Down Expand Up @@ -116,12 +109,16 @@ status_t ref_reorder_t::pd_t::init_kernel_ctx(
auto &dst_blk = dst_md()->format_desc.blocking;

int dst_contig_dim = -1;
if (dst_blk.inner_nblks > 0)
dst_contig_dim = dst_blk.inner_idxs[0];
else
if (dst_blk.inner_nblks > 0) {
for (int i = dst_md()->ndims; i >= 0; i--)
if (dst_blk.inner_idxs[i] != 0) {
dst_contig_dim = dst_blk.inner_idxs[i];
break;
}
} else {
for (int i = 0; i < dst_md()->ndims; i++)
if (dst_blk.strides[i] == 1) dst_contig_dim = i;

}
// TODO: also check that innermost block or dimension has even size
if (dst_contig_dim < 0) return status::unimplemented;

Expand Down
30 changes: 22 additions & 8 deletions tests/benchdnn/inputs/reorder/test_reorder_int4
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,29 @@
--reset
--sdt=f32 --ddt=u4,s4
--stag=bax,abx
--dtag=abx,bax
2x64x14x14 2x56x14x14
2x64x64x3x3 2x56x56x3x3
4x16x16x3x3 2x16x6x3x2 2x2x10x2x3
--dtag=abx,bax 2x64x14x14 2x56x14x14
--dtag=abx,bax 2x64x64x3x3 2x56x56x3x3
--dtag=abx,bax 4x16x16x3x3 2x16x6x3x2 2x2x10x2x3
--dtag=aBx16b 2x64x14x14 2x56x14x14
--dtag=gOIhw16i16o 2x64x64x3x3 2x56x56x3x3
--dtag=gOIhw8i16o2i 2x64x64x3x3 2x56x56x3x3
--dtag=gOIhw8o16i2o 2x64x64x3x3 2x56x56x3x3
--dtag=gOIhw2i4o2i 4x16x16x3x3 2x16x6x3x2 2x2x10x2x3
--dtag=gOIhw2o4i2o 4x16x16x3x3 2x16x6x3x2 2x2x10x2x3
--dtag=gOIhw4i8o2i 4x16x16x3x3 2x16x6x3x2 2x2x10x2x3
--dtag=gOIhw4o8i2o 4x16x16x3x3 2x16x6x3x2 2x2x10x2x3

--reset
--sdt=u4,s4 --ddt=f32
--dtag=abx,bax
--stag=bax,abx
2x64x14x14 2x56x14x14
2x64x64x3x3 2x56x56x3x3
4x16x16x3x3 2x16x6x3x2 2x2x10x2x3
--stag=bax,abx 2x64x14x14 2x56x14x14
--stag=bax,abx 2x64x64x3x3 2x56x56x3x3
--stag=bax,abx 4x16x16x3x3 2x16x6x3x2 2x2x10x2x3
--stag=aBx16b 2x64x14x14 2x56x14x14
--stag=gOIhw16i16o 2x64x64x3x3 2x56x56x3x3
--stag=gOIhw8i16o2i 2x64x64x3x3 2x56x56x3x3
--stag=gOIhw8o16i2o 2x64x64x3x3 2x56x56x3x3
--stag=aBCde2b4c2b 4x16x16x3x3 2x16x6x3x2 2x2x10x2x3
--stag=aBCde2c4b2c 4x16x16x3x3 2x16x6x3x2 2x2x10x2x3
--stag=aBCde4b8c2b 4x16x16x3x3 2x16x6x3x2 2x2x10x2x3
--stag=aBCde4c8b2c 4x16x16x3x3 2x16x6x3x2 2x2x10x2x3
13 changes: 7 additions & 6 deletions tests/benchdnn/reorder/reorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,14 @@ void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
}
}

// Int4 reorder support is limited on all platforms.
if (sdt == dnnl_s4 || ddt == dnnl_s4 || sdt == dnnl_u4 || ddt == dnnl_u4) {
res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
return;
}

if (is_cpu()) {
// Int4 reorder support is limited on CPU.
if (sdt == dnnl_s4 || ddt == dnnl_s4 || sdt == dnnl_u4
|| ddt == dnnl_u4) {
res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
return;
}

// CPU reorder doesn't support (xf8,xf16)<-->s32 combinations.
const bool s32_src_ok = IMPLICATION(sdt == dnnl_s32,
ddt != dnnl_f8_e5m2 && ddt != dnnl_f8_e4m3 && ddt != dnnl_bf16
Expand Down

0 comments on commit 54748c0

Please sign in to comment.