[inductor][cpp] Add BMM kernel template for autotuning (pytorch#129772)
This PR adds the Cpp template for BMM, covering FP32, FP16, and BF16. See pytorch#125683 for more background.

1. Adds a `CppBmmTemplate` class which inherits from `CppPackedGemmTemplate`. Given a number of worker threads `num_threads` and a batch size `B`, the template executes the GEMM kernel as follows: for the first `B - (B % num_threads)` batch inputs, one sub-gemm problem runs per thread; each of the remaining `B % num_threads` sub-gemms is executed with the parallelized GEMM kernel. To manage this code, the `GEMM_TEMPLATE` from `CppPackedGemmTemplate` is rendered twice, once single-threaded and once with the parallel OMP pragma (a sketch of this dispatch appears after the performance results below).
2. Adapts `CppPackedGemmTemplate` to allow for a child class. The `GEMM_TEMPLATE` is split into separate strings so the child class can render them, slicing/indexing are adapted to handle 3D BMM inputs, and the methods `get_options()` and `_get_params_for_choices()` are added to reduce code duplication.

The BMM within the `dlrm` benchmark uses a single input buffer for both the X and W inputs; that case is not yet supported by this PR.

### Performance
On Granite Rapids/Sapphire Rapids, the cpp_bmm template code uses AMX, which requires an expensive transpose operation, so the BMM op is rarely selected as faster than the existing external bmm kernel. As a result, speedup on SPR is identical with and without the BMM template, and the pass rate matches main exactly.

#### Test Summary on Granite Rapids

| Test Scenario | Comp Item | Config | Compiler | torchbench | huggingface | timm_models |
| -- | -- | -- | -- | -- | -- | -- |
| Single Socket Multi-Threads | Pass Rate | gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | | bmm + gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | Geomean Speedup | gemm autotune | inductor | 2.15x | 1.91x | 2.52x |
| | | bmm + gemm autotune | inductor | 2.15x | 1.96x | 2.53x |
| Single Core Single-Thread | Pass Rate | gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | | bmm + gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | Geomean Speedup | inductor_locally_benchmark_586 | inductor | 2.43x | 1.56x | 2.60x |
| | | inductor_locally_benchmark_585 | inductor | 2.45x | 1.56x | 2.63x |

This is not the case on an older 28-core Skylake Xeon machine: for the BMM ops contained in torchbench models, bmm performance improves by 1.10-2.64x.

#### BF16 28-core Skylake Xeon

| Model | Inductor | GemmAutotune | Gemm+BMM Autotune |
|--------|--------|--------|--------|
| BERT_pytorch | 1.233x | 2.597x | 2.608x |
| hf_DistilBert | 1.128x | 2.242x | 2.368x |
| hf_Reformer | 1.124x | 1.419x | 1.590x |
| hf_T5_base | 1.012x | 1.257x | 1.382x |
| hf_T5_large | 1.085x | 2.228x | 2.345x |
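For reference, the batch-level dispatch described in item 1 is essentially what the generated `cpp_bmm` entry point in the example below reduces to. A minimal sketch, with hypothetical `gemm_one_thread`/`gemm_all_threads` stubs standing in for the two renderings of `GEMM_TEMPLATE` (in the generated code these correspond to `single_thread_mm` and `threaded_mm`), assuming OpenMP:

```
#include <cstdint>
#include <omp.h>

// Hypothetical stand-ins for the two renderings of GEMM_TEMPLATE.
void gemm_one_thread(const float* X, const float* W, float* Y, int64_t b) { /* whole sub-gemm on the calling thread */ }
void gemm_all_threads(const float* X, const float* W, float* Y, int64_t b) { /* one sub-gemm parallelized across threads */ }

// Batch dispatch: the first B - (B % num_threads) entries run one sub-gemm per
// thread; each of the B % num_threads leftover entries uses the parallel kernel.
void bmm_dispatch(const float* X, const float* W, float* Y, int64_t B) {
    const int64_t num_threads = omp_get_max_threads();
    const int64_t B_single_thread_block = (B / num_threads) * num_threads;
    #pragma omp parallel for
    for (int64_t b = 0; b < B_single_thread_block; ++b) {
        gemm_one_thread(X, W, Y, b);   // one batch entry per thread
    }
    for (int64_t b = B_single_thread_block; b < B; ++b) {
        gemm_all_threads(X, W, Y, b);  // all threads cooperate on one entry
    }
}
```

Running whole sub-gemms one per thread avoids cross-thread coordination for the bulk of the batch, while the parallelized kernel keeps all cores busy on the remainder.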
## Example BMM Code

```
#include <c10/util/Unroll.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
template <bool accum>
inline void cpp_bmm_micro_gemm_amx_kernel_32_2(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / 32 * 32;
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, 32 / 16, 2, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig);
    }
    auto load_c = [&]() {
        _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float));
        _tile_loadd(2, C + 16 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(3, C + 16 * ldc + 16, ldc * sizeof(float));
    };
    auto zero_c = [&]() {
        _tile_zero(0);
        _tile_zero(1);
        _tile_zero(2);
        _tile_zero(3);
    };
    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }
    auto compute = [&](int k) {
        _tile_stream_loadd(4, A + 0 * lda + k, lda * sizeof(bfloat16));
        _tile_loadd(6, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(0, 4, 6);
        _tile_loadd(7, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(1, 4, 7);
        _tile_stream_loadd(5, A + 16 * lda + k, lda * sizeof(bfloat16));
        _tile_dpbf16ps(2, 5, 6);
        _tile_dpbf16ps(3, 5, 7);
    };
    #pragma GCC unroll 4
    for (int k = 0; k < last_k_offset; k += 32) {
        compute(k);
    }
    auto store_c = [&]() {
        // store to C
        _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float));
        _tile_stored(2, C + 16 * ldc + 0, ldc * sizeof(float));
        _tile_stored(3, C + 16 * ldc + 16, ldc * sizeof(float));
    };
    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }
    store_c();
}
template <bool accum>
inline void cpp_bmm_micro_gemm_amx_kernel_16_2(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / 32 * 32;
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, 16 / 16, 2, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig);
    }
    auto load_c = [&]() {
        _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float));
    };
    auto zero_c = [&]() {
        _tile_zero(0);
        _tile_zero(1);
    };
    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }
    auto compute = [&](int k) {
        _tile_stream_loadd(2, A + 0 * lda + k, lda * sizeof(bfloat16));
        _tile_loadd(3, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(0, 2, 3);
        _tile_loadd(4, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(1, 2, 4);
    };
    #pragma GCC unroll 4
    for (int k = 0; k < last_k_offset; k += 32) {
        compute(k);
    }
    auto store_c = [&]() {
        // store to C
        _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float));
    };
    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }
    store_c();
}
template <bool accum>
inline void cpp_bmm_micro_gemm(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    AOTI_TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32");
    AOTI_TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
    // TODO(jgong5): loop unroll for M and N
    for (int64_t n = 0; n < N; n += 32) {
        for (int64_t m = 0; m < M; m += 32) {
            int64_t block_m = std::min<int64_t>(M - m, 32);
            int64_t m_tail = m;
            if (block_m >= 32) {
                cpp_bmm_micro_gemm_amx_kernel_32_2<accum>(
                    amx_state,
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc,
                    16
                );
                block_m -= 32;
                m_tail += 32;
            } else if (block_m >= 16) {
                cpp_bmm_micro_gemm_amx_kernel_16_2<accum>(
                    amx_state,
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc,
                    16
                );
                block_m -= 16;
                m_tail += 16;
            }
            if (block_m > 0) {
                cpp_bmm_micro_gemm_amx_kernel_16_2<accum>(
                    amx_state,
                    A + m_tail * lda,
                    B + n,
                    C + m_tail * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc,
                    block_m
                );
            }
        }
    }
}
void threaded_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index) {
    constexpr int64_t num_threads = 48;
    constexpr int64_t N = 64;
    constexpr int64_t K = 96;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
    constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
    constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
    constexpr int64_t M = static_cast<int64_t>(384L);
    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
    constexpr int64_t Mt_blocks = 1;
    constexpr int64_t Nt_blocks = 1;
    constexpr int64_t Kt_blocks = 3;
    constexpr int64_t Mc_blocks = 1;
    constexpr int64_t Nc_blocks = 1;
    constexpr int64_t Kc_blocks = 3;
    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
    constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
    constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
    constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
    // make sure all partitions are assigned
    AOTI_TORCH_CHECK(
        Mt_blocks * Nt_blocks * Kt_blocks * 48 >= Mr_blocks * Nr_blocks * Kr_blocks,
        "Not all partitions are assigned."
    );
    #pragma omp parallel num_threads(48)
    {
        const int tid = omp_get_thread_num();
        const int64_t k_group_id = tid / num_Kt_blocks;
        const int64_t k_slice_id = tid % num_Kt_blocks;
        const int64_t n_group_id = k_group_id / num_Nt_blocks;
        const int64_t n_slice_id = k_group_id % num_Nt_blocks;
        const int64_t k_block_start = k_slice_id * Kt_blocks;
        const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
        const int64_t n_block_start = n_slice_id * Nt_blocks;
        const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
        const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
        const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
        const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
        AMXState amx_state;
        auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr));
        auto local_acc_buf = _local_acc_buf.get();
        for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
            const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
            const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
            const int64_t m_start = mc * Mr;
            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
            const int64_t m_size = m_end - m_start;
            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                const int64_t n_start = nc * Nr;
                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
                const int64_t n_size = n_end - n_start;
                // NB: assume we pad N, nc_block_end won't exceed padded N here.
                const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
                if (_local_acc_buf == nullptr) {
                    _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr));
                    local_acc_buf = _local_acc_buf.get();
                }
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    int64_t k_start = kc * Kr;
                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
                        if (kc == k_block_start) {
                            cpp_bmm_micro_gemm<static_cast<bool>(false)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );
                        } else {
                            cpp_bmm_micro_gemm<static_cast<bool>(true)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );
                        }
                    }
                }
                {
                    {
                        #pragma GCC ivdep
                        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16));
                            }
                            for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                            }
                        }
                    }
                }
            }
        }
        amx_state.release([]() { _tile_release(); });
    }
}
void single_thread_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index) {
    constexpr int64_t num_threads = 1;
    constexpr int64_t N = 64;
    constexpr int64_t K = 96;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
    constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
    constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
    constexpr int64_t M = static_cast<int64_t>(384L);
    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
    constexpr int64_t Mt_blocks = 12;
    constexpr int64_t Nt_blocks = 2;
    constexpr int64_t Kt_blocks = 3;
    constexpr int64_t Mc_blocks = 12;
    constexpr int64_t Nc_blocks = 1;
    constexpr int64_t Kc_blocks = 3;
    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
    constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
    constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
    constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
    // make sure all partitions are assigned
    AOTI_TORCH_CHECK(
        Mt_blocks * Nt_blocks * Kt_blocks * 1 >= Mr_blocks * Nr_blocks * Kr_blocks,
        "Not all partitions are assigned."
    );
    {
        constexpr int tid = 0;
        constexpr int64_t k_group_id = 0;
        constexpr int64_t k_slice_id = 0;
        constexpr int64_t n_group_id = 0;
        constexpr int64_t n_slice_id = 0;
        constexpr int64_t m_block_start = 0;
        constexpr int64_t n_block_start = 0;
        constexpr int64_t n_block_end = Nr_blocks;
        constexpr int64_t k_block_start = 0;
        constexpr int64_t k_block_end = Kr_blocks;
        constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
        constexpr int64_t m_block_end = Mr_blocks;
        AMXState amx_state;
        auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr));
        auto local_acc_buf = _local_acc_buf.get();
        for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
            const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
            const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
            const int64_t m_start = mc * Mr;
            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
            const int64_t m_size = m_end - m_start;
            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                const int64_t n_start = nc * Nr;
                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
                const int64_t n_size = n_end - n_start;
                // NB: assume we pad N, nc_block_end won't exceed padded N here.
                const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
                if (_local_acc_buf == nullptr) {
                    _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr));
                    local_acc_buf = _local_acc_buf.get();
                }
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    int64_t k_start = kc * Kr;
                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
                        if (kc == k_block_start) {
                            cpp_bmm_micro_gemm<static_cast<bool>(false)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );
                        } else {
                            cpp_bmm_micro_gemm<static_cast<bool>(true)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );
                        }
                    }
                }
                {
                    {
                        #pragma GCC ivdep
                        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16));
                            }
                            for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                            }
                        }
                    }
                }
            }
        }
        amx_state.release([]() { _tile_release(); });
    }
}
extern "C"
void cpp_bmm(const bfloat16* X, const bfloat16* W, bfloat16* Y) {
    const int64_t B = static_cast<int64_t>(5L);
    constexpr int64_t num_threads = 48;
    int64_t B_single_thread_block = (B / num_threads) * num_threads;
    #pragma omp parallel for num_threads(48)
    for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) {
        single_thread_mm(X, W, Y, b_start);
    }
    for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) {
        threaded_mm(X, W, Y, b_start);
    }
}
```

Pull Request resolved: pytorch#129772
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel