Skip to content

Commit

Permalink
benchdnn: brgemm: ukernel: fix f16 crashes with odd K
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin committed Jul 12, 2024
1 parent 2a5526b commit 5897aa0
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions tests/benchdnn/brgemm/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ struct kernel_args_t {
#else
brgemm_(nullptr)
, brgemm_pack_B_(nullptr)
, need_pack_(0)
#endif
, scratchpad_size_(0)
, generate_skip_accumulation_(false)
Expand All @@ -291,6 +292,7 @@ struct kernel_args_t {
#else
dnnl_brgemm_t brgemm_;
dnnl_brgemm_pack_B_t brgemm_pack_B_;
int need_pack_; // `int` to match C API
#endif
size_t scratchpad_size_;
bool generate_skip_accumulation_;
Expand Down Expand Up @@ -407,12 +409,11 @@ int init_kernel(kernel_args_t &kernel_args) {
SAFE(check_dnnl_status(st, prb, res), WARN);
if (res->state == SKIPPED) return OK;

int need_pack = 0;
st = dnnl_brgemm_pack_B_need_pack(brgemm_pack_B, &need_pack);
st = dnnl_brgemm_pack_B_need_pack(brgemm_pack_B, &kernel_args.need_pack_);
SAFE(check_dnnl_status(st, prb, res), WARN);
if (res->state == SKIPPED) return OK;

if (need_pack) {
if (kernel_args.need_pack_) {
st = dnnl_brgemm_pack_B_generate(brgemm_pack_B);
SAFE(check_dnnl_status(st, prb, res), WARN);
if (res->state == SKIPPED) return OK;
Expand Down Expand Up @@ -660,7 +661,8 @@ void init_memory_args(

// Needed to have enough memory after packing routine.
// TODO: generalize special f16 case.
const int dt_multiplier = prb->wei_dt() == dnnl_f16
const int dt_multiplier
= prb->wei_dt() == dnnl_f16 && !kernel_args.need_pack_
? 1
: 4 / dnnl_data_type_size(prb->wei_dt());
const dnnl_dim_t k_rounded = dt_multiplier * div_up(prb->k, dt_multiplier);
Expand Down Expand Up @@ -1124,10 +1126,7 @@ int doit(const prb_t *prb, res_t *res) {
? (char *)mem_map.at(DNNL_ARG_SCRATCHPAD)
: nullptr;

int need_pack = 0;
DNN_SAFE(dnnl_brgemm_pack_B_need_pack(brgemm_pack_B, &need_pack), WARN);

if (need_pack) {
if (kernel_args.need_pack_) {
auto st = dnnl_brgemm_pack_B_execute(
brgemm_pack_B, wei_ptr, wei_packed_ptr);
SAFE(check_dnnl_status(st, prb, res), WARN);
Expand Down

0 comments on commit 5897aa0

Please sign in to comment.