From bedac807851acc75e91f50955a6b2b162a2ffbe6 Mon Sep 17 00:00:00 2001
From: Nikhil Sharma <150424968+nikhilfujitsu@users.noreply.github.com>
Date: Thu, 16 May 2024 03:07:06 +0530
Subject: [PATCH] cpu: aarch64: Expand ARM SVE support in
 jit_uni_batch_normalization

Added sve_256 to the implementation list.

Updated the block size definition to accommodate different ISAs and
added 'OR' conditions to extend support to additional block sizes.
Predicate registers are set according to the ISA vector length. The
ldr and str instructions were changed to ld1w and st1w, respectively,
so that load and store operations honor the configured vector length.
---
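Notes (below the fold, not part of the commit message): the sve_256
path reuses the sve_512 code paths and mostly differs in how many fp32
lanes a vector holds. A minimal C++ sketch of that arithmetic, assuming
only what the patch itself uses (vlen is the vector length in bytes:
64 for sve_512, 32 for sve_256):

    #include <cstddef>

    // fp32 lanes per vector for the configured ISA.
    constexpr size_t simd_w(size_t vlen) { return vlen / sizeof(float); }

    // Lanes covered by ktail_mask when C is padded; mirrors
    // prepare_tail_mask(): tail = C % (vlen / sizeof(float)).
    constexpr size_t tail_lanes(size_t C, size_t vlen) {
        return C % simd_w(vlen);
    }

    // Channel-offset step per loop iteration: a full vector of bytes
    // on SVE, half that on asimd (see the add_imm call sites below).
    constexpr size_t coff_step(bool is_sve, size_t vlen) {
        return is_sve ? vlen : vlen / 2;
    }

For example, with C = 20 on sve_256: simd_w = 8, so the tail predicate
enables 20 % 8 = 4 lanes; sve_512 (simd_w = 16) also ends up with a
4-lane tail (20 % 16 = 4), but over twice the vector width.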
 .../aarch64/jit_uni_batch_normalization.cpp   | 127 +++++++++++-------
 src/cpu/cpu_batch_normalization_list.cpp      |   4 +-
 2 files changed, 82 insertions(+), 49 deletions(-)

diff --git a/src/cpu/aarch64/jit_uni_batch_normalization.cpp b/src/cpu/aarch64/jit_uni_batch_normalization.cpp
index 6e8ce32445d..455305082ce 100644
--- a/src/cpu/aarch64/jit_uni_batch_normalization.cpp
+++ b/src/cpu/aarch64/jit_uni_batch_normalization.cpp
@@ -1,6 +1,6 @@
 /*******************************************************************************
 * Copyright 2020-2022 Intel Corporation
-* Copyright 2020-2022 FUJITSU LIMITED
+* Copyright 2020-2024 FUJITSU LIMITED
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -430,7 +430,7 @@ struct jit_bnorm_t : public jit_generator {
 #undef STR_PARAM
     }
 
-    void prepare_tail_mask_sve_512() {
+    void prepare_tail_mask() {
         if (!is_c_padded()) return;
         const int tail = pd_->C() % (int)(vlen / sizeof(float));
         set_preg(ktail_mask.s, tail, X_TMP_0, X_TMP_1);
@@ -447,18 +447,18 @@ struct jit_bnorm_t : public jit_generator {
         if (with_relu) uni_clear(vzero);
     }
 
-    void fwd_process_relu_sve_512(ZRegS vdst, int offt = 0) {
+    void fwd_process_relu(ZRegS vdst, int offt = 0) {
         const int bits = bit_shift();
         const int offset = offt / (1 << bits);
         XReg r = jbp_->is_nspc_ ? reg_soff_nspc : reg_soff;
         ZRegS zzero = ZRegS(vzero.getIdx());
 
-        assert(isa == sve_512);
+        assert(isa == sve_256 || isa == sve_512);
         assert(bits < 64);
         lsr(r, r, bits);
         fcmlt(kstore_mask.s, P_ALL_ONE / T_z, zzero, vdst);
 
-        sub(X_DEFAULT_ADDR, sp, 8); // sve_512
+        sub(X_DEFAULT_ADDR, sp, 8);
         uzp1(p_tmp0.b, kstore_mask.b, kstore_mask.b);
         uzp1(p_tmp0.b, p_tmp0.b, p_tmp0.b);
         str(p_tmp0, ptr(X_DEFAULT_ADDR));
@@ -472,11 +472,11 @@ struct jit_bnorm_t : public jit_generator {
         lsl(r, r, bit_shift());
     }
 
-    void fwd_process_relu_alpha_sve_512(TRegS vmm_dst) {
+    void fwd_process_relu_alpha(TRegS vmm_dst) {
         ZRegS dst = ZRegS(vmm_dst.getIdx());
         ZRegS z_tmp0 = ZRegS(t_tmp0.getIdx());
 
-        assert(isa == sve_512);
+        assert(isa == sve_256 || isa == sve_512);
         add_imm(X_DEFAULT_ADDR, sp, (int)stack_off_relu_alpha, X_TMP_0);
         ld1rw(ZRegS(t_tmp0.getIdx()), P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR));
@@ -486,7 +486,7 @@ struct jit_bnorm_t : public jit_generator {
         sel(dst, kstore_mask, dst, z_tmp0);
     }
 
-    void bwd_process_relu_sve_512(ZRegS vdiff_dst, int offt = 0) {
+    void bwd_process_relu(ZRegS vdiff_dst, int offt = 0) {
         const int bits = bit_shift();
         const int offset = offt / (1 << bits);
         XReg r = jbp_->is_nspc_ ? reg_soff_nspc : reg_soff;
@@ -498,7 +498,7 @@ struct jit_bnorm_t : public jit_generator {
         if (offset) add_imm(X_DEFAULT_ADDR, X_DEFAULT_ADDR, offset, X_TMP_0);
         ldrh(W_TMP_0, ptr(X_DEFAULT_ADDR));
 
-        sub(X_DEFAULT_ADDR, sp, 8); // sve_512
+        sub(X_DEFAULT_ADDR, sp, 8);
         str(X_TMP_0, ptr(X_DEFAULT_ADDR));
         ldr(kstore_mask, ptr(X_DEFAULT_ADDR));
         zip1(kstore_mask.b, kstore_mask.b, kstore_mask.b);
@@ -512,7 +512,9 @@ struct jit_bnorm_t : public jit_generator {
         ldr(QReg(IDX(v)), ptr(x));
     }
 
-    void uni_load_spat_data(const ZReg &z, const XReg &x) { ldr(z, ptr(x)); }
+    void uni_load_spat_data(const ZReg &z, const XReg &x) {
+        ld1w(z.s, P_ALL_ONE / T_z, ptr(x));
+    }
 
     void uni_store_spat_data(
             const XReg &x, const VReg &v, bool is_nt_store = false) {
@@ -522,7 +524,7 @@ struct jit_bnorm_t : public jit_generator {
 
     void uni_store_spat_data(
             const XReg &x, const ZReg &z, bool is_nt_store = false) {
-        stnt1w(z.s, P_ALL_ONE, ptr(x));
+        stnt1w(z.s, P_ALL_ONE / T_z, ptr(x));
     }
 
     void jump_check(const Label &l_no_mask) {
@@ -541,7 +543,8 @@ struct jit_bnorm_t : public jit_generator {
 
         if (is_c_padded()) {
             jump_check(l_no_mask);
-            if (isa == sve_512) ld1w(ZRegS(IDX(t)), ktail_mask / T_z, ptr(x));
+            assert(isa == sve_256 || isa == sve_512);
+            ld1w(ZRegS(IDX(t)), ktail_mask / T_z, ptr(x));
             b(l_ret);
         }
         L(l_no_mask);
@@ -554,7 +557,8 @@ struct jit_bnorm_t : public jit_generator {
 
         if (is_c_padded()) {
             jump_check(l_no_mask);
-            if (isa == sve_512) st1w(ZRegS(IDX(t)), ktail_mask / T_z, ptr(x));
+            assert(isa == sve_256 || isa == sve_512);
+            st1w(ZRegS(IDX(t)), ktail_mask / T_z, ptr(x));
             b(l_ret);
         }
         L(l_no_mask);
@@ -589,7 +593,9 @@ struct jit_bnorm_t : public jit_generator {
 
     void uni_ldr(const VReg &v, const XReg &x) { ldr(QReg(IDX(v)), ptr(x)); }
 
-    void uni_ldr(const ZReg &z, const XReg &x) { ldr(z, ptr(x)); }
+    void uni_ldr(const ZReg &z, const XReg &x) {
+        ld1w(z.s, P_ALL_ONE / T_z, ptr(x));
+    }
 
     void uni_str(const VReg &v, const XReg &base,
             const XReg &off = XReg(DUMMY_IDX), const int disp = 0) {
@@ -615,7 +621,7 @@ struct jit_bnorm_t : public jit_generator {
 
     void uni_str(const ZReg &z, const XReg &base,
             const XReg &off = XReg(DUMMY_IDX), const int disp = 0) {
-        str(z, ptr(xreg_addr(base, off, disp)));
+        st1w(z.s, P_ALL_ONE / T_z, ptr(xreg_addr(base, off, disp)));
     }
 
     void uni_stnt1w(const ZReg &z, const XReg &base,
@@ -885,12 +891,12 @@ struct jit_bnorm_t : public jit_generator {
                 if (with_relu_inf_only) { // --attr=post_ops='relu'
                     if (pd_->alpha() != 0.f)
-                        fwd_process_relu_alpha_sve_512(vdata);
+                        fwd_process_relu_alpha(vdata);
                     else
                         uni_fmaxnm(vdata, vdata, vzero.s);
                 } else if (with_relu) { // --flags=R
-                    assert(isa == sve_512);
-                    fwd_process_relu_sve_512(
+                    assert(isa == sve_256 || isa == sve_512);
+                    fwd_process_relu(
                             ZRegS(vdata.getIdx()), idx * vlen_spat_data_);
                 }
                 add(X_DEFAULT_ADDR, reg_dst, reg_soff_nspc);
@@ -1004,7 +1010,8 @@ struct jit_bnorm_t : public jit_generator {
         L(zero_rbuf);
         {
             uni_str(TReg(0), reg_rbuf1, reg_coff);
-            add_imm(reg_coff, reg_coff, isa == sve_512 ? vlen : vlen / 2,
+            add_imm(reg_coff, reg_coff,
+                    (isa == sve_256 || isa == sve_512) ? vlen : vlen / 2,
                     X_TMP_0);
             cmp(reg_coff, reg_coff_max);
             b(NE, zero_rbuf);
@@ -1080,13 +1087,13 @@ struct jit_bnorm_t : public jit_generator {
             subs_imm(reg_ctr, reg_ctr, 1, X_TMP_0);
             b(NE, mean_reduction_thrs);
         }
-        if (isa == sve_512)
+        if (isa == sve_256 || isa == sve_512)
             fdiv(ZRegS(1), P_ALL_ONE / T_m, ZRegS(vchan_size.getIdx()));
         else
             fdiv(VReg4S(1), VReg4S(1), VReg4S(vchan_size.getIdx()));
         uni_store_maybe_tail(mean_ptr(), TReg(1));
 
-        if (isa == sve_512)
+        if (isa == sve_256 || isa == sve_512)
             add_imm(reg_coff, reg_coff, vlen, X_TMP_0);
         else
             add_imm(reg_coff, reg_coff, vlen / 2, X_TMP_0);
@@ -1163,13 +1170,13 @@ struct jit_bnorm_t : public jit_generator {
             subs(reg_ctr, reg_ctr, 1);
             b(NE, var_reduction_thrs);
         }
-        if (isa == sve_512)
+        if (isa == sve_256 || isa == sve_512)
             fdiv(ZRegS(1), P_ALL_ONE / T_m, ZRegS(vchan_size.getIdx()));
         else {
             fdiv(VReg4S(1), VReg4S(1), VReg4S(IDX(vchan_size)));
         }
         uni_store_maybe_tail(var_ptr(), TReg(1));
-        if (isa == sve_512)
+        if (isa == sve_256 || isa == sve_512)
             add_imm(reg_coff, reg_coff, vlen, X_TMP_0);
         else
             add_imm(reg_coff, reg_coff, vlen / 2, X_TMP_0);
@@ -1224,12 +1231,12 @@ struct jit_bnorm_t : public jit_generator {
             }
             if (with_relu_inf_only) { // --attr=post_ops='relu'
                 if (pd_->alpha() != 0.f) {
-                    fwd_process_relu_alpha_sve_512(v);
+                    fwd_process_relu_alpha(v);
                 } else
                     uni_fmaxnm(v, v, vzero.s);
             } else if (with_relu) { // --flags=R
-                assert(isa == sve_512);
-                fwd_process_relu_sve_512(ZRegS(v.getIdx()), offt);
+                assert(isa == sve_256 || isa == sve_512);
+                fwd_process_relu(ZRegS(v.getIdx()), offt);
             }
             add(X_DEFAULT_ADDR, reg_dst, reg_soff);
             if (offt)
@@ -1405,8 +1412,8 @@ struct jit_bnorm_t : public jit_generator {
             if (offt) add_imm(X_TMP_0, X_TMP_0, offt, X_TMP_1);
             uni_load_spat_data(t2, X_TMP_0);
             if (with_relu) {
-                assert(isa == sve_512);
-                bwd_process_relu_sve_512(ZRegS(t2.getIdx()), offt);
+                assert(isa == sve_256 || isa == sve_512);
+                bwd_process_relu(ZRegS(t2.getIdx()), offt);
             }
             fsub(t3.s, vmean.s, t1.s);
             if (isa == asimd) {
@@ -1490,8 +1497,8 @@ struct jit_bnorm_t : public jit_generator {
             uni_load_spat_data(vdiff_dst, X_TMP_3);
 
             if (with_relu) {
-                assert(isa == sve_512);
-                bwd_process_relu_sve_512(ZRegS(vdiff_dst.getIdx()), offt);
+                assert(isa == sve_256 || isa == sve_512);
+                bwd_process_relu(ZRegS(vdiff_dst.getIdx()), offt);
             }
 
             fsub(vsrc.s, vsrc.s, vmean.s);
@@ -1603,8 +1610,8 @@ struct jit_bnorm_t : public jit_generator {
                 add_imm(X_DEFAULT_ADDR, X_DEFAULT_ADDR, offt, X_TMP_0);
             uni_load_spat_data(TReg(v.getIdx()), X_DEFAULT_ADDR);
             if (with_relu) {
-                assert(isa == sve_512);
-                bwd_process_relu_sve_512(ZRegS(v.getIdx()), offt);
+                assert(isa == sve_256 || isa == sve_512);
+                bwd_process_relu(ZRegS(v.getIdx()), offt);
             }
             if (!pd_->use_global_stats()) {
                 fsub(v, v, vdiff_beta.s);
@@ -1723,9 +1730,8 @@ struct jit_bnorm_t : public jit_generator {
                             TReg(vdiff_data.getIdx()), X_DEFAULT_ADDR);
 
                     if (with_relu) {
-                        assert(isa == sve_512);
-                        bwd_process_relu_sve_512(
-                                ZRegS(vdiff_data.getIdx()), offt);
+                        assert(isa == sve_256 || isa == sve_512);
+                        bwd_process_relu(ZRegS(vdiff_data.getIdx()), offt);
                     }
 
                     if (!pd_->use_global_stats()) {
@@ -1841,7 +1847,7 @@ struct jit_bnorm_t : public jit_generator {
             uni_str(TReg(0), X_TMP_0);
             add(X_TMP_0, reg_rbuf2, reg_coff);
             uni_str(TReg(0), X_TMP_0);
-            if (isa == sve_512)
+            if (isa == sve_256 || isa == sve_512)
                 add_imm(reg_coff, reg_coff, vlen, X_TMP_0);
             else
                 add_imm(reg_coff, reg_coff, vlen / 2, X_TMP_0);
@@ -1852,7 +1858,7 @@ struct jit_bnorm_t : public jit_generator {
         LDR_ASSERT(reg_src, sp, (int)stack_off_src);
         LDR_ASSERT(reg_diff_dst, sp, (int)stack_off_diff_dst);
         if (with_relu) {
-            assert(isa == sve_512);
+            assert(isa == sve_256 || isa == sve_512);
             LDR_ASSERT(reg_ws, sp, (int)stack_off_ws);
         }
@@ -1935,7 +1941,8 @@ struct jit_bnorm_t : public jit_generator {
             fmul(TRegS(0), TRegS(0), vsqrtvar.s);
             uni_store_maybe_tail(diff_gamma_ptr(), TReg(0));
             uni_store_maybe_tail(diff_beta_ptr(), TReg(1));
-            add_imm(reg_coff, reg_coff, isa == sve_512 ? vlen : vlen / 2,
+            add_imm(reg_coff, reg_coff,
+                    isa == sve_256 || isa == sve_512 ? vlen : vlen / 2,
                     X_TMP_0);
             cmp(reg_coff, reg_coff_max);
             b(NE, sh_reduction_channels);
@@ -1946,7 +1953,7 @@ struct jit_bnorm_t : public jit_generator {
         LDR_ASSERT(reg_diff_src, sp, (int)stack_off_diff_src);
 
         if (with_relu) {
-            assert(isa == sve_512);
+            assert(isa == sve_256 || isa == sve_512);
             LDR_ASSERT(reg_ws, sp, (int)stack_off_ws);
         }
@@ -2003,20 +2010,31 @@ struct jit_bnorm_t : public jit_generator {
 
     jit_bnorm_t(const batch_normalization_pd_t *pd, const jit_bnorm_conf_t *jbp)
         : pd_(pd), jbp_(jbp) {
-        static_assert(isa == asimd || isa == sve_512, "unsupported isa");
+        static_assert(isa == asimd || isa == sve_256 || isa == sve_512,
+                "unsupported isa");
 
         is_bf16_ = pd_->src_md()->data_type == data_type::bf16;
         is_f16_ = pd_->src_md()->data_type == data_type::f16;
         vlen_spat_data_ = vlen / (1 + is_xf16()); // 32B of xF16 -> 64B of FP32
 
-        unroll_blocks = isa == sve_512 && !jbp_->is_spatial_thr_ ? 4 : 1;
-        unroll_regs = isa == sve_512 && !jbp_->is_spatial_thr_ ? 4 : 1;
+        unroll_blocks
+                = (isa == sve_256 || isa == sve_512) && !jbp_->is_spatial_thr_
+                ? 4
+                : 1;
+        unroll_regs
+                = (isa == sve_256 || isa == sve_512) && !jbp_->is_spatial_thr_
+                ? 4
+                : 1;
     }
 
     void generate() override {
         preamble();
 
-        if (isa == sve_512) { prepare_tail_mask_sve_512(); }
+        size_t simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
+        if (simd_w_ != cpu_sveLen / sizeof(float))
+            set_preg(P_ALL_ONE.s, simd_w_, X_TMP_0, X_TMP_1);
+
+        if (isa == sve_256 || isa == sve_512) { prepare_tail_mask(); }
 
         compute_static_strides();
@@ -2281,6 +2299,10 @@ status_t jit_uni_batch_normalization_fwd_t<isa>::pd_t::init(engine_t *engine) {
         if (!src_d.matches_one_of_tag(
                     nCw16c, nChw16c, nCdhw16c, nc, nwc, nhwc, ndhwc))
             return status::unimplemented;
+    } else if (isa == sve_256) {
+        if (!src_d.matches_one_of_tag(
+                    nCw8c, nChw8c, nCdhw8c, nc, nwc, nhwc, ndhwc))
+            return status::unimplemented;
     } else {
         if (!src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c))
             return status::unimplemented;
@@ -2288,14 +2310,15 @@ status_t jit_uni_batch_normalization_fwd_t<isa>::pd_t::init(engine_t *engine) {
 
     if (is_fwd() ? with_relu_post_op(is_training()) || fuse_norm_relu()
                  : fuse_norm_relu())
-        if (isa != sve_512) return status::unimplemented;
+        if (isa != sve_512) return status::unimplemented; // TODO
 
     if (is_training() && fuse_norm_relu()) {
-        if (isa < sve_512) return status::unimplemented;
+        if (isa != sve_256 && isa != sve_512) return status::unimplemented;
         init_default_ws(1);
     }
 
-    if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() && isa < sve_512)
+    if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() && isa != sve_256
+            && isa != sve_512)
         return status::unimplemented;
 
     // Only IC % 16 == 0 is supported for now
@@ -2386,6 +2409,11 @@ status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init(engine_t *engine) {
                 nc, nwc, nCw16c, nhwc, nChw16c, ndhwc, nCdhw16c);
         diff_src_tag = diff_src_d.matches_one_of_tag(
                 nc, nwc, nCw16c, nhwc, nChw16c, ndhwc, nCdhw16c);
+    } else if (isa == sve_256) {
+        src_tag = src_d.matches_one_of_tag(
+                nc, nwc, nCw8c, nhwc, nChw8c, ndhwc, nCdhw8c);
+        diff_src_tag = diff_src_d.matches_one_of_tag(
+                nc, nwc, nCw8c, nhwc, nChw8c, ndhwc, nCdhw8c);
     } else {
         src_tag = src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c);
         diff_src_tag = diff_src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c);
@@ -2394,7 +2422,8 @@ status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init(engine_t *engine) {
             && src_tag == diff_src_tag);
     if (!ok) return status::unimplemented;
 
-    if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() && isa < sve_512)
+    if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() && isa != sve_256
+            && isa != sve_512)
         return status::unimplemented;
 
     // Only IC % 16 == 0 is supported for now
@@ -2404,7 +2433,7 @@ status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init(engine_t *engine) {
     }
 
     if (fuse_norm_relu()) {
-        if (isa < sve_512) return status::unimplemented;
+        if (isa != sve_256 && isa != sve_512) return status::unimplemented;
         init_default_ws(1);
         if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
    }
@@ -2465,6 +2494,8 @@ jit_uni_batch_normalization_bwd_t<isa>::~jit_uni_batch_normalization_bwd_t() {
 
 /* struct instantiation */
 template struct jit_uni_batch_normalization_fwd_t<sve_512>;
 template struct jit_uni_batch_normalization_bwd_t<sve_512>;
+template struct jit_uni_batch_normalization_fwd_t<sve_256>;
+template struct jit_uni_batch_normalization_bwd_t<sve_256>;
 template struct jit_uni_batch_normalization_fwd_t<asimd>;
 template struct jit_uni_batch_normalization_bwd_t<asimd>;
 
diff --git a/src/cpu/cpu_batch_normalization_list.cpp b/src/cpu/cpu_batch_normalization_list.cpp
index dd9aaeb5040..ab093a380f0 100644
--- a/src/cpu/cpu_batch_normalization_list.cpp
+++ b/src/cpu/cpu_batch_normalization_list.cpp
@@ -1,6 +1,6 @@
 /*******************************************************************************
 * Copyright 2019-2022 Intel Corporation
-* Copyright 2021 FUJITSU LIMITED
+* Copyright 2024 FUJITSU LIMITED
 * Copyright 2022 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -58,6 +58,7 @@ const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
             CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_fwd_t<avx2>)
             CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_fwd_t<sse41>)
             CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_fwd_t<sve_512>)
+            CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_fwd_t<sve_256>)
             CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_fwd_t<asimd>)
             CPU_INSTANCE_AARCH64_ACL(acl_batch_normalization_fwd_t)
             CPU_INSTANCE(ncsp_batch_normalization_fwd_t<f32>)
@@ -85,6 +86,7 @@ const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
             CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_bwd_t<avx2>)
             CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_bwd_t<sse41>)
             CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_bwd_t<sve_512>)
+            CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_bwd_t<sve_256>)
             CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_bwd_t<asimd>)
             CPU_INSTANCE(ncsp_batch_normalization_bwd_t<f32>)
             CPU_INSTANCE(ncsp_batch_normalization_bwd_t<bf16>)
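
--
Reviewer note (outside the patch): the ldr/str -> ld1w/st1w change is
what makes the narrower ISA safe on wider hardware. ldr and str on a
ZReg always move the full hardware vector, while ld1w and st1w move
only the lanes enabled by the governing predicate, which generate()
narrows to simd_w_ lanes whenever the configured ISA is shorter than
cpu_sveLen. A small C++ sketch of the zeroing-predicated load
semantics (an illustration of SVE behavior, not oneDNN code):

    #include <cstddef>

    // Emulates ld1w z.s, p/z, [x]: active lanes are loaded from src,
    // inactive lanes are zeroed in dst (the /T_z qualifier).
    void emulate_ld1w_zeroing(float *dst, const float *src,
            const bool *pred, size_t hw_lanes) {
        for (size_t i = 0; i < hw_lanes; ++i)
            dst[i] = pred[i] ? src[i] : 0.0f;
    }

On a 512-bit machine running the sve_256 kernel, hw_lanes is 16 but
only the first 8 predicate bits are set, so a plain ldr would read
(and str would write) 32 bytes past the intended vector.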