cpu: aarch64: batch_normalization : Expand ARM SVE support in jit_uni_batch_normalization #1918

Merged · 2 commits · Jun 3, 2024
127 changes: 79 additions & 48 deletions src/cpu/aarch64/jit_uni_batch_normalization.cpp
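The diff below generalizes code paths that were previously guarded by `isa == sve_512` so that they also accept `sve_256`: full-register `ldr`/`str` of Z registers become predicated `ld1w`/`st1w`, the tail and all-true predicates are sized from the runtime vector length, and the blocked memory format checks gain nChw8c-style tags for 256-bit SVE. As orientation only, the sketch below (hypothetical names, not oneDNN's API) illustrates the vector-length dispatch idea these guards rely on.

// Minimal standalone C++ sketch of the ISA / vector-length dispatch pattern
// applied throughout the diff; names are illustrative, not oneDNN's.
enum class isa_t { asimd, sve_256, sve_512 };

constexpr bool is_sve(isa_t isa) {
    return isa == isa_t::sve_256 || isa == isa_t::sve_512;
}

constexpr int vlen_bytes(isa_t isa) {
    // 64 B for 512-bit SVE, 32 B for 256-bit SVE, 16 B for 128-bit ASIMD/NEON.
    return isa == isa_t::sve_512 ? 64 : isa == isa_t::sve_256 ? 32 : 16;
}

// fp32 lanes per vector: 16 (SVE-512), 8 (SVE-256), 4 (ASIMD).
constexpr int simd_w(isa_t isa) {
    return vlen_bytes(isa) / static_cast<int>(sizeof(float));
}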
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
* Copyright 2020-2022 FUJITSU LIMITED
* Copyright 2020-2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -430,7 +430,7 @@ struct jit_bnorm_t : public jit_generator {
#undef STR_PARAM
}

void prepare_tail_mask_sve_512() {
void prepare_tail_mask() {
if (!is_c_padded()) return;
const int tail = pd_->C() % (int)(vlen / sizeof(float));
set_preg(ktail_mask.s, tail, X_TMP_0, X_TMP_1);
@@ -447,18 +447,18 @@
if (with_relu) uni_clear(vzero);
}

void fwd_process_relu_sve_512(ZRegS vdst, int offt = 0) {
void fwd_process_relu(ZRegS vdst, int offt = 0) {
const int bits = bit_shift();
const int offset = offt / (1 << bits);
XReg r = jbp_->is_nspc_ ? reg_soff_nspc : reg_soff;
ZRegS zzero = ZRegS(vzero.getIdx());

assert(isa == sve_512);
assert(isa == sve_256 || isa == sve_512);

assert(bits < 64);
lsr(r, r, bits);
fcmlt(kstore_mask.s, P_ALL_ONE / T_z, zzero, vdst);
sub(X_DEFAULT_ADDR, sp, 8); // sve_512
sub(X_DEFAULT_ADDR, sp, 8);
uzp1(p_tmp0.b, kstore_mask.b, kstore_mask.b);
uzp1(p_tmp0.b, p_tmp0.b, p_tmp0.b);
str(p_tmp0, ptr(X_DEFAULT_ADDR));
@@ -472,11 +472,11 @@
lsl(r, r, bit_shift());
}

void fwd_process_relu_alpha_sve_512(TRegS vmm_dst) {
void fwd_process_relu_alpha(TRegS vmm_dst) {
ZRegS dst = ZRegS(vmm_dst.getIdx());
ZRegS z_tmp0 = ZRegS(t_tmp0.getIdx());

assert(isa == sve_512);
assert(isa == sve_256 || isa == sve_512);

add_imm(X_DEFAULT_ADDR, sp, (int)stack_off_relu_alpha, X_TMP_0);
ld1rw(ZRegS(t_tmp0.getIdx()), P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR));
@@ -486,7 +486,7 @@
sel(dst, kstore_mask, dst, z_tmp0);
}

void bwd_process_relu_sve_512(ZRegS vdiff_dst, int offt = 0) {
void bwd_process_relu(ZRegS vdiff_dst, int offt = 0) {
const int bits = bit_shift();
const int offset = offt / (1 << bits);
XReg r = jbp_->is_nspc_ ? reg_soff_nspc : reg_soff;
@@ -498,7 +498,7 @@
if (offset) add_imm(X_DEFAULT_ADDR, X_DEFAULT_ADDR, offset, X_TMP_0);

ldrh(W_TMP_0, ptr(X_DEFAULT_ADDR));
sub(X_DEFAULT_ADDR, sp, 8); // sve_512
sub(X_DEFAULT_ADDR, sp, 8);
str(X_TMP_0, ptr(X_DEFAULT_ADDR));
ldr(kstore_mask, ptr(X_DEFAULT_ADDR));
zip1(kstore_mask.b, kstore_mask.b, kstore_mask.b);
@@ -512,7 +512,9 @@
ldr(QReg(IDX(v)), ptr(x));
}

void uni_load_spat_data(const ZReg &z, const XReg &x) { ldr(z, ptr(x)); }
void uni_load_spat_data(const ZReg &z, const XReg &x) {
ld1w(z.s, P_ALL_ONE / T_z, ptr(x));
}

void uni_store_spat_data(
const XReg &x, const VReg &v, bool is_nt_store = false) {
@@ -522,7 +524,7 @@

void uni_store_spat_data(
const XReg &x, const ZReg &z, bool is_nt_store = false) {
stnt1w(z.s, P_ALL_ONE, ptr(x));
stnt1w(z.s, P_ALL_ONE / T_z, ptr(x));
}

void jump_check(const Label &l_no_mask) {
@@ -541,7 +543,8 @@

if (is_c_padded()) {
jump_check(l_no_mask);
if (isa == sve_512) ld1w(ZRegS(IDX(t)), ktail_mask / T_z, ptr(x));
assert(isa == sve_256 || isa == sve_512);
ld1w(ZRegS(IDX(t)), ktail_mask / T_z, ptr(x));
b(l_ret);
}
L(l_no_mask);
@@ -554,7 +557,8 @@

if (is_c_padded()) {
jump_check(l_no_mask);
if (isa == sve_512) st1w(ZRegS(IDX(t)), ktail_mask / T_z, ptr(x));
assert(isa == sve_256 || isa == sve_512);
st1w(ZRegS(IDX(t)), ktail_mask / T_z, ptr(x));
b(l_ret);
}
L(l_no_mask);
@@ -589,7 +593,9 @@

void uni_ldr(const VReg &v, const XReg &x) { ldr(QReg(IDX(v)), ptr(x)); }

void uni_ldr(const ZReg &z, const XReg &x) { ldr(z, ptr(x)); }
void uni_ldr(const ZReg &z, const XReg &x) {
ld1w(z.s, P_ALL_ONE / T_z, ptr(x));
}

void uni_str(const VReg &v, const XReg &base,
const XReg &off = XReg(DUMMY_IDX), const int disp = 0) {
@@ -615,7 +621,7 @@

void uni_str(const ZReg &z, const XReg &base,
const XReg &off = XReg(DUMMY_IDX), const int disp = 0) {
str(z, ptr(xreg_addr(base, off, disp)));
st1w(z.s, P_ALL_ONE / T_z, ptr(xreg_addr(base, off, disp)));
}

void uni_stnt1w(const ZReg &z, const XReg &base,
@@ -885,12 +891,12 @@

if (with_relu_inf_only) { // --attr=post_ops='relu'
if (pd_->alpha() != 0.f)
fwd_process_relu_alpha_sve_512(vdata);
fwd_process_relu_alpha(vdata);
else
uni_fmaxnm(vdata, vdata, vzero.s);
} else if (with_relu) { // --flags=R
assert(isa == sve_512);
fwd_process_relu_sve_512(
assert(isa == sve_256 || isa == sve_512);
fwd_process_relu(
ZRegS(vdata.getIdx()), idx * vlen_spat_data_);
}
add(X_DEFAULT_ADDR, reg_dst, reg_soff_nspc);
@@ -1004,7 +1010,8 @@
L(zero_rbuf);
{
uni_str(TReg(0), reg_rbuf1, reg_coff);
add_imm(reg_coff, reg_coff, isa == sve_512 ? vlen : vlen / 2,
add_imm(reg_coff, reg_coff,
(isa == sve_256 || isa == sve_512) ? vlen : vlen / 2,
X_TMP_0);
cmp(reg_coff, reg_coff_max);
b(NE, zero_rbuf);
@@ -1080,13 +1087,13 @@
subs_imm(reg_ctr, reg_ctr, 1, X_TMP_0);
b(NE, mean_reduction_thrs);
}
if (isa == sve_512)
if (isa == sve_256 || isa == sve_512)
fdiv(ZRegS(1), P_ALL_ONE / T_m, ZRegS(vchan_size.getIdx()));
else
fdiv(VReg4S(1), VReg4S(1), VReg4S(vchan_size.getIdx()));
uni_store_maybe_tail(mean_ptr(), TReg(1));

if (isa == sve_512)
if (isa == sve_256 || isa == sve_512)
add_imm(reg_coff, reg_coff, vlen, X_TMP_0);
else
add_imm(reg_coff, reg_coff, vlen / 2, X_TMP_0);
@@ -1163,13 +1170,13 @@
subs(reg_ctr, reg_ctr, 1);
b(NE, var_reduction_thrs);
}
if (isa == sve_512)
if (isa == sve_256 || isa == sve_512)
fdiv(ZRegS(1), P_ALL_ONE / T_m, ZRegS(vchan_size.getIdx()));
else {
fdiv(VReg4S(1), VReg4S(1), VReg4S(IDX(vchan_size)));
}
uni_store_maybe_tail(var_ptr(), TReg(1));
if (isa == sve_512)
if (isa == sve_256 || isa == sve_512)
add_imm(reg_coff, reg_coff, vlen, X_TMP_0);
else
add_imm(reg_coff, reg_coff, vlen / 2, X_TMP_0);
@@ -1224,12 +1231,12 @@
}
if (with_relu_inf_only) { // --attr=post_ops='relu'
if (pd_->alpha() != 0.f) {
fwd_process_relu_alpha_sve_512(v);
fwd_process_relu_alpha(v);
} else
uni_fmaxnm(v, v, vzero.s);
} else if (with_relu) { // --flags=R
assert(isa == sve_512);
fwd_process_relu_sve_512(ZRegS(v.getIdx()), offt);
assert(isa == sve_256 || isa == sve_512);
fwd_process_relu(ZRegS(v.getIdx()), offt);
}
add(X_DEFAULT_ADDR, reg_dst, reg_soff);
if (offt)
@@ -1405,8 +1412,8 @@
if (offt) add_imm(X_TMP_0, X_TMP_0, offt, X_TMP_1);
uni_load_spat_data(t2, X_TMP_0);
if (with_relu) {
assert(isa == sve_512);
bwd_process_relu_sve_512(ZRegS(t2.getIdx()), offt);
assert(isa == sve_256 || isa == sve_512);
bwd_process_relu(ZRegS(t2.getIdx()), offt);
}
fsub(t3.s, vmean.s, t1.s);
if (isa == asimd) {
@@ -1490,8 +1497,8 @@
uni_load_spat_data(vdiff_dst, X_TMP_3);

if (with_relu) {
assert(isa == sve_512);
bwd_process_relu_sve_512(ZRegS(vdiff_dst.getIdx()), offt);
assert(isa == sve_256 || isa == sve_512);
bwd_process_relu(ZRegS(vdiff_dst.getIdx()), offt);
}

fsub(vsrc.s, vsrc.s, vmean.s);
@@ -1603,8 +1610,8 @@
add_imm(X_DEFAULT_ADDR, X_DEFAULT_ADDR, offt, X_TMP_0);
uni_load_spat_data(TReg(v.getIdx()), X_DEFAULT_ADDR);
if (with_relu) {
assert(isa == sve_512);
bwd_process_relu_sve_512(ZRegS(v.getIdx()), offt);
assert(isa == sve_256 || isa == sve_512);
bwd_process_relu(ZRegS(v.getIdx()), offt);
}
if (!pd_->use_global_stats()) {
fsub(v, v, vdiff_beta.s);
@@ -1723,9 +1730,8 @@
TReg(vdiff_data.getIdx()), X_DEFAULT_ADDR);

if (with_relu) {
assert(isa == sve_512);
bwd_process_relu_sve_512(
ZRegS(vdiff_data.getIdx()), offt);
assert(isa == sve_256 || isa == sve_512);
bwd_process_relu(ZRegS(vdiff_data.getIdx()), offt);
}

if (!pd_->use_global_stats()) {
@@ -1841,7 +1847,7 @@
uni_str(TReg(0), X_TMP_0);
add(X_TMP_0, reg_rbuf2, reg_coff);
uni_str(TReg(0), X_TMP_0);
if (isa == sve_512)
if (isa == sve_256 || isa == sve_512)
add_imm(reg_coff, reg_coff, vlen, X_TMP_0);
else
add_imm(reg_coff, reg_coff, vlen / 2, X_TMP_0);
@@ -1852,7 +1858,7 @@
LDR_ASSERT(reg_src, sp, (int)stack_off_src);
LDR_ASSERT(reg_diff_dst, sp, (int)stack_off_diff_dst);
if (with_relu) {
assert(isa == sve_512);
assert(isa == sve_256 || isa == sve_512);
LDR_ASSERT(reg_ws, sp, (int)stack_off_ws);
}

@@ -1935,7 +1941,8 @@
fmul(TRegS(0), TRegS(0), vsqrtvar.s);
uni_store_maybe_tail(diff_gamma_ptr(), TReg(0));
uni_store_maybe_tail(diff_beta_ptr(), TReg(1));
add_imm(reg_coff, reg_coff, isa == sve_512 ? vlen : vlen / 2,
add_imm(reg_coff, reg_coff,
isa == sve_256 || isa == sve_512 ? vlen : vlen / 2,
X_TMP_0);
cmp(reg_coff, reg_coff_max);
b(NE, sh_reduction_channels);
@@ -1946,7 +1953,7 @@

LDR_ASSERT(reg_diff_src, sp, (int)stack_off_diff_src);
if (with_relu) {
assert(isa == sve_512);
assert(isa == sve_256 || isa == sve_512);
LDR_ASSERT(reg_ws, sp, (int)stack_off_ws);
}

@@ -2003,20 +2010,31 @@

jit_bnorm_t(const batch_normalization_pd_t *pd, const jit_bnorm_conf_t *jbp)
: pd_(pd), jbp_(jbp) {
static_assert(isa == asimd || isa == sve_512, "unsupported isa");
static_assert(isa == asimd || isa == sve_256 || isa == sve_512,
"unsupported isa");

is_bf16_ = pd_->src_md()->data_type == data_type::bf16;
is_f16_ = pd_->src_md()->data_type == data_type::f16;
vlen_spat_data_ = vlen / (1 + is_xf16()); // 32B of xF16 -> 64B of FP32

unroll_blocks = isa == sve_512 && !jbp_->is_spatial_thr_ ? 4 : 1;
unroll_regs = isa == sve_512 && !jbp_->is_spatial_thr_ ? 4 : 1;
unroll_blocks
= (isa == sve_256 || isa == sve_512) && !jbp_->is_spatial_thr_
? 4
: 1;
unroll_regs
= (isa == sve_256 || isa == sve_512) && !jbp_->is_spatial_thr_
? 4
: 1;
}

void generate() override {
preamble();

if (isa == sve_512) { prepare_tail_mask_sve_512(); }
size_t simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
if (simd_w_ != cpu_sveLen / sizeof(float))
set_preg(P_ALL_ONE.s, simd_w_, X_TMP_0, X_TMP_1);

if (isa == sve_256 || isa == sve_512) { prepare_tail_mask(); }

compute_static_strides();

@@ -2281,21 +2299,26 @@ status_t jit_uni_batch_normalization_fwd_t<isa>::pd_t::init(engine_t *engine) {
if (!src_d.matches_one_of_tag(
nCw16c, nChw16c, nCdhw16c, nc, nwc, nhwc, ndhwc))
return status::unimplemented;
} else if (isa == sve_256) {
if (!src_d.matches_one_of_tag(
nCw8c, nChw8c, nCdhw8c, nc, nwc, nhwc, ndhwc))
return status::unimplemented;
} else {
if (!src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c))
return status::unimplemented;
}

if (is_fwd() ? with_relu_post_op(is_training()) || fuse_norm_relu()
: fuse_norm_relu())
if (isa != sve_512) return status::unimplemented;
if (isa != sve_512) return status::unimplemented; // TODO

if (is_training() && fuse_norm_relu()) {
if (isa < sve_512) return status::unimplemented;
if (isa != sve_256 && isa != sve_512) return status::unimplemented;
init_default_ws(1);
}

if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() && isa < sve_512)
if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() && isa != sve_256
&& isa != sve_512)
return status::unimplemented;

// Only IC % 16 == 0 is supported for now
@@ -2386,6 +2409,11 @@ status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init(engine_t *engine) {
nc, nwc, nCw16c, nhwc, nChw16c, ndhwc, nCdhw16c);
diff_src_tag = diff_src_d.matches_one_of_tag(
nc, nwc, nCw16c, nhwc, nChw16c, ndhwc, nCdhw16c);
} else if (isa == sve_256) {
src_tag = src_d.matches_one_of_tag(
nc, nwc, nCw8c, nhwc, nChw8c, ndhwc, nCdhw8c);
diff_src_tag = diff_src_d.matches_one_of_tag(
nc, nwc, nCw8c, nhwc, nChw8c, ndhwc, nCdhw8c);
} else {
src_tag = src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c);
diff_src_tag = diff_src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c);
@@ -2394,7 +2422,8 @@
&& src_tag == diff_src_tag);
if (!ok) return status::unimplemented;

if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() && isa < sve_512)
if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() && isa != sve_256
&& isa != sve_512)
return status::unimplemented;

// Only IC % 16 == 0 is supported for now
@@ -2404,7 +2433,7 @@
}

if (fuse_norm_relu()) {
if (isa < sve_512) return status::unimplemented;
if (isa != sve_256 && isa != sve_512) return status::unimplemented;
init_default_ws(1);
if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
}
@@ -2465,6 +2494,8 @@ jit_uni_batch_normalization_bwd_t<isa>::~jit_uni_batch_normalization_bwd_t() {
/* struct instantiation */
template struct jit_uni_batch_normalization_fwd_t<asimd>;
template struct jit_uni_batch_normalization_bwd_t<asimd>;
template struct jit_uni_batch_normalization_fwd_t<sve_256>;
template struct jit_uni_batch_normalization_bwd_t<sve_256>;
template struct jit_uni_batch_normalization_fwd_t<sve_512>;
template struct jit_uni_batch_normalization_bwd_t<sve_512>;

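The load/store changes above (e.g. `ldr(z, ptr(x))` becoming `ld1w(z.s, P_ALL_ONE / T_z, ptr(x))`, and `set_preg(ktail_mask.s, tail, ...)` in `prepare_tail_mask()`) boil down to predicated SVE memory accesses whose governing predicate covers either the full vector or the `C % simd_w` channel tail, so the same JIT path works for both 256-bit and 512-bit vectors. The PR emits these instructions through Xbyak_aarch64; purely as an illustration of the same pattern, here is a small standalone ACLE-intrinsics sketch (not code from this PR), compiled with `-march=armv8-a+sve`:

#include <arm_sve.h>
#include <cstddef>
#include <cstdint>

// Scale C channel values by alpha using whatever SVE vector length the
// hardware provides; the governing predicate masks off the tail lanes.
void scale_channels(float *dst, const float *src, float alpha, std::size_t C) {
    const std::size_t step = svcntw(); // fp32 lanes: 8 on SVE-256, 16 on SVE-512
    for (std::size_t c = 0; c < C; c += step) {
        // Lanes [c, min(c + step, C)) are active; the rest are inactive.
        const svbool_t pg = svwhilelt_b32_u64(
                static_cast<std::uint64_t>(c), static_cast<std::uint64_t>(C));
        svfloat32_t v = svld1_f32(pg, src + c); // ld1w z.s, pg/z, [x]
        v = svmul_n_f32_x(pg, v, alpha);
        svst1_f32(pg, dst + c, v);              // st1w z.s, pg, [x]
    }
}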