From 64e62755d0f77a879d4bf842c2f40b987be70ea7 Mon Sep 17 00:00:00 2001
From: Mateusz Belicki
Date: Tue, 12 Sep 2023 08:58:08 +0000
Subject: [PATCH] Adapt large JointMatrix load and store builtins to new mangling.

Adapt large JointMatrix load and store builtins to the new mangling.
Additionally, this patch fixes a problem where only the generic memory
variants of those builtins were present.
---
 .../OpenCL/PreRelease/IBiF_matrix.cl          | 250 +++++++++++-------
 .../address-spaces.ll                         | 214 +++++++++++++++
 2 files changed, 374 insertions(+), 90 deletions(-)
 create mode 100644 IGC/Compiler/tests/JointMatrixFuncsResolutionPass/address-spaces.ll

diff --git a/IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl b/IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl
index d00f03588038..56d26e188820 100644
--- a/IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl
+++ b/IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl
@@ -121,6 +121,7 @@ extern __constant int __JointMatrixLoadStoreOpt;
 #define OUT_VEC2(type) type##2
 #define OUT_VEC1(type) type

+#define OUT_STORE_VEC16(type) type##16
 #define OUT_STORE_VEC8(type) type##8
 #define OUT_STORE_VEC7(type) type##8
 #define OUT_STORE_VEC6(type) type##8
@@ -152,6 +153,7 @@ extern __constant int __JointMatrixLoadStoreOpt;
     wi_contrib[i] = readop((src) + i * (stride));

 // variants for 7,6,5,3 and 1 are only used to make the code compilable
+#define DEFINE_BLOCK_RW_NAME16(rw, us) intel_sub_group_block_##rw##us##16
 #define DEFINE_BLOCK_RW_NAME8(rw, us) intel_sub_group_block_##rw##us##8
 #define DEFINE_BLOCK_RW_NAME7(rw, us) intel_sub_group_block_##rw##us##8
 #define DEFINE_BLOCK_RW_NAME6(rw, us) intel_sub_group_block_##rw##us##8
@@ -775,100 +777,168 @@ INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x16_bf16_bf16_fp32(__priv
     __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a1, b3, c7, d7);
 }

-INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(__private char *dst, char *mem, long stride) {
-    IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR(int, 32, int, 32, 16, 16, 16)
-}
-
-INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_generic_v8i8_pi32_i32(__private char *dst, char *mem, long stride) {
-    IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR(short, 16, short, 16, 16, 16, 16)
-}
+#define DEFINE_LOAD_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, stride_opt, address_space) \
+  INLINE void MANGLE_LOAD_NAME_##address_space(layout, sg, elem_bitwidth, shape, M) (__private char *dst, char *mem, long stride) { \
+      int sg_size = get_sub_group_size(); \
+      if ( __JointMatrixLoadStoreOpt >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8 || M == 16) \
+          && (order == _ROW_MAJOR || order == _VNNI_TX) && address_space == AS_GLOBAL \
+          ) { \
+          /* It seems __builtin_IB_subgroup_block_rw always needs k=16 \
+             Maybe it is number of columns divided by pack factor which always gives 16 on SG16 HW */ \
+          IMPLEMENT_BLOCK2D_LOAD##sg##order(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, 16, stride_opt) \
+      } \
+      contrib_type *ptr = (contrib_type *)mem; \
+      int slid = get_sub_group_local_id(); \
+      int pack_factor = sizeof (contrib_type) / sizeof (element_type); \
+      stride = stride / pack_factor; \
+      int sg_cols = K / pack_factor; \
+      int skip_factor = sg_size / sg_cols; \
+      __private contrib_type *wi_contrib = (__private contrib_type *)dst; \
+      for (int i = 0; i < M; i++) { \
+        if ( (i*skip_factor + slid/sg_cols) < M ) \
+
wi_contrib[i] = ptr[IND##order(slid, stride, skip_factor, i, sg_cols)]; \ + else \ + wi_contrib[i] = 0; /*last even row for matrix with odd number of rows doesn't exist*/ \ + } \ + } -INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x64_i32_generic_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { - __private char *c0 = dst + 0 * 16 * (sizeof (int)); - __private char *c1 = dst + 1 * 16 * (sizeof (int)); - __private char *c2 = dst + 2 * 16 * (sizeof (int)); - __private char *c3 = dst + 3 * 16 * (sizeof (int)); - __private char *c4 = dst + 4 * 16 * (sizeof (int)); - __private char *c5 = dst + 5 * 16 * (sizeof (int)); - __private char *c6 = dst + 6 * 16 * (sizeof (int)); - __private char *c7 = dst + 7 * 16 * (sizeof (int)); - - char *mem0 = mem + 0 * 16 * (sizeof (int)); - char *mem1 = mem + 1 * 16 * (sizeof (int)); - char *mem2 = mem + 2 * 16 * (sizeof (int)); - char *mem3 = mem + 3 * 16 * (sizeof (int)); - char *mem4 = mem + 0 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride; - char *mem5 = mem + 1 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride; - char *mem6 = mem + 2 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride; - char *mem7 = mem + 3 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride; - - __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(c0, mem0, stride); - __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(c1, mem1, stride); - __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(c2, mem2, stride); - __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(c3, mem3, stride); - __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(c4, mem4, stride); - __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(c5, mem5, stride); - __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(c6, mem6, stride); - __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(c7, mem7, stride); -} +#define DEFINE_LOAD_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, stride_opt) \ + DEFINE_LOAD_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, stride_opt, AS_GENERIC) \ + DEFINE_LOAD_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, stride_opt, AS_LOCAL) \ + DEFINE_LOAD_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, stride_opt, AS_GLOBAL) + +DEFINE_LOAD_LARGE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 16, 16, 16x16, ROW_MAJOR, , 16) +DEFINE_LOAD_LARGE(PackedA_RowMajor, _SG16, short, 16, short, 16, 16, 16, 16x16, ROW_MAJOR, , 16) + +#define DEFINE_ACC_ROW_MAJOR_32x64(address_space) \ + INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x64_i32_32_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { \ + __private char *c0 = dst + 0 * 16 * (sizeof (int)); \ + __private char *c1 = dst + 1 * 16 * (sizeof (int)); \ + __private char *c2 = dst + 2 * 16 * (sizeof (int)); \ + __private char *c3 = dst + 3 * 16 * (sizeof (int)); \ + __private char *c4 = dst + 4 * 16 * (sizeof (int)); \ + __private char *c5 = dst + 5 * 16 * (sizeof 
(int)); \ + __private char *c6 = dst + 6 * 16 * (sizeof (int)); \ + __private char *c7 = dst + 7 * 16 * (sizeof (int)); \ +\ + char *mem0 = mem + 0 * 16 * (sizeof (int)); \ + char *mem1 = mem + 1 * 16 * (sizeof (int)); \ + char *mem2 = mem + 2 * 16 * (sizeof (int)); \ + char *mem3 = mem + 3 * 16 * (sizeof (int)); \ + char *mem4 = mem + 0 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride;\ + char *mem5 = mem + 1 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride;\ + char *mem6 = mem + 2 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride;\ + char *mem7 = mem + 3 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride;\ +\ + __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_v8i8_pi32_i32(c0, mem0, stride); \ + __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_v8i8_pi32_i32(c1, mem1, stride); \ + __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_v8i8_pi32_i32(c2, mem2, stride); \ + __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_v8i8_pi32_i32(c3, mem3, stride); \ + __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_v8i8_pi32_i32(c4, mem4, stride); \ + __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_v8i8_pi32_i32(c5, mem5, stride); \ + __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_v8i8_pi32_i32(c6, mem6, stride); \ + __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_v8i8_pi32_i32(c7, mem7, stride); \ + } -INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_32x16_i16_generic_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { - __private char *dst0 = dst; - __private char *dst1 = dst + 16 * (sizeof (short)); +DEFINE_ACC_ROW_MAJOR_32x64(generic) +DEFINE_ACC_ROW_MAJOR_32x64(global) +DEFINE_ACC_ROW_MAJOR_32x64(local) + +#define DEFINE_A_ROW_MAJOR_32x16(address_space) \ + INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_32x16_i16_32_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { \ + __private char *dst0 = dst; \ + __private char *dst1 = dst + 16 * (sizeof (short)); \ +\ + char *mem0 = mem; \ + char *mem1 = mem + 16 * (sizeof (short)) * stride; \ +\ + __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_16_##address_space##_v8i8_pi32_i32(dst0, mem0, stride); \ + __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_16_##address_space##_v8i8_pi32_i32(dst1, mem1, stride); \ + } - char *mem0 = mem; - char *mem1 = mem + 16 * (sizeof (short)) * stride; +DEFINE_A_ROW_MAJOR_32x16(generic) +DEFINE_A_ROW_MAJOR_32x16(global) +DEFINE_A_ROW_MAJOR_32x16(local) + +#define DEFINE_B_B_16x64(address_space) \ + INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x64_i16_8_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { \ + __private char *b0 = dst; \ + __private char *b1 = dst + 1 * 16 * (sizeof (short)); \ + __private char *b2 = dst + 2 * 16 * (sizeof (short)); \ + __private char *b3 = dst + 3 * 16 * (sizeof (short)); \ +\ + char *mem0 = mem + 0 * 16 * (sizeof (int)); \ + char *mem1 = mem + 1 * 16 * (sizeof (int)); \ + char *mem2 = mem + 2 * 16 * (sizeof (int)); \ + char *mem3 = mem + 3 * 16 * (sizeof (int)); \ +\ + 
__builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_##address_space##_v8i8_pi32_i32(b0, mem0, stride); \ + __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_##address_space##_v8i8_pi32_i32(b1, mem1, stride); \ + __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_##address_space##_v8i8_pi32_i32(b2, mem2, stride); \ + __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_##address_space##_v8i8_pi32_i32(b3, mem3, stride); \ + } - __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_generic_v8i8_pi32_i32(dst0, mem0, stride); - __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_generic_v8i8_pi32_i32(dst1, mem1, stride); -} +DEFINE_B_B_16x64(generic) +DEFINE_B_B_16x64(global) +DEFINE_B_B_16x64(local) -INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x64_i16_generic_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { - __private char *b0 = dst; - __private char *b1 = dst + 1 * 16 * (sizeof (short)); - __private char *b2 = dst + 2 * 16 * (sizeof (short)); - __private char *b3 = dst + 3 * 16 * (sizeof (short)); - - char *mem0 = mem + 0 * 16 * (sizeof (int)); - char *mem1 = mem + 1 * 16 * (sizeof (int)); - char *mem2 = mem + 2 * 16 * (sizeof (int)); - char *mem3 = mem + 3 * 16 * (sizeof (int)); - - __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_generic_v8i8_pi32_i32(b0, mem0, stride); - __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_generic_v8i8_pi32_i32(b1, mem1, stride); - __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_generic_v8i8_pi32_i32(b2, mem2, stride); - __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_8_generic_v8i8_pi32_i32(b3, mem3, stride); -} +#define DEFINE_STORE_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, stride_opt, address_space) \ + INLINE void MANGLE_STORE_NAME_##address_space(layout, sg, elem_bitwidth, shape, M) (char *mem, __private char *src, long stride) { \ + int sg_size = get_sub_group_size(); \ + if (__JointMatrixLoadStoreOpt >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8) \ + && order == _ROW_MAJOR && address_space == AS_GLOBAL && elem_bitwidth > 8 \ + ) { \ + IMPLEMENT_BLOCK2D_STORE##sg(element_type, contrib_type, contrib_bitwidth, M, K, src) \ + } \ + contrib_type *ptr = (contrib_type *)mem; \ + int slid = get_sub_group_local_id(); \ + int pack_factor = sizeof (contrib_type) / sizeof (element_type); \ + stride = stride / pack_factor; \ + int sg_cols = K / pack_factor; \ + int skip_factor = sg_size / sg_cols; \ + __private contrib_type *slice = (__private contrib_type *)src; \ + for (int i = 0; i < M; i++) { \ + ptr[IND##order(slid, stride, skip_factor, i, sg_cols)] = slice[i]; \ + } \ + } -INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(char *mem, __private char *src, long stride) { - IMPLEMENT_BLOCK2D_STORE_SG16(int, int, 32, 16, 16, slice) -} +#define DEFINE_STORE_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, stride_opt) \ + DEFINE_STORE_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, stride_opt, AS_GENERIC) \ + DEFINE_STORE_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, stride_opt, AS_LOCAL) \ + 
DEFINE_STORE_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, stride_opt, AS_GLOBAL) + +DEFINE_STORE_LARGE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 16, 16, 16x16, ROW_MAJOR, , 16) + +#define DEFINE_STORE_ACC_ROW_MAJOR_32x64(address_space) \ + INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_32x64_i32_32_##address_space##_pi64_v8i8(char *mem, __private char *src, long stride) { \ + __private char *c0 = src + 0 * 16 * (sizeof (int)); \ + __private char *c1 = src + 1 * 16 * (sizeof (int)); \ + __private char *c2 = src + 2 * 16 * (sizeof (int)); \ + __private char *c3 = src + 3 * 16 * (sizeof (int)); \ + __private char *c4 = src + 4 * 16 * (sizeof (int)); \ + __private char *c5 = src + 5 * 16 * (sizeof (int)); \ + __private char *c6 = src + 6 * 16 * (sizeof (int)); \ + __private char *c7 = src + 7 * 16 * (sizeof (int)); \ +\ + char *mem0 = mem + 0 * 16 * (sizeof (int)); \ + char *mem1 = mem + 1 * 16 * (sizeof (int)); \ + char *mem2 = mem + 2 * 16 * (sizeof (int)); \ + char *mem3 = mem + 3 * 16 * (sizeof (int)); \ + char *mem4 = mem + 0 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride; \ + char *mem5 = mem + 1 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride; \ + char *mem6 = mem + 2 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride; \ + char *mem7 = mem + 3 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride; \ +\ + __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_pi64_v8i8(mem0, c0, stride); \ + __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_pi64_v8i8(mem1, c1, stride); \ + __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_pi64_v8i8(mem2, c2, stride); \ + __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_pi64_v8i8(mem3, c3, stride); \ + __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_pi64_v8i8(mem4, c4, stride); \ + __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_pi64_v8i8(mem5, c5, stride); \ + __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_pi64_v8i8(mem6, c6, stride); \ + __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_##address_space##_pi64_v8i8(mem7, c7, stride); \ + } -INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_32x64_i32_generic_pi64_v8i8(char *mem, __private char *src, long stride) { - __private char *c0 = src + 0 * 16 * (sizeof (int)); - __private char *c1 = src + 1 * 16 * (sizeof (int)); - __private char *c2 = src + 2 * 16 * (sizeof (int)); - __private char *c3 = src + 3 * 16 * (sizeof (int)); - __private char *c4 = src + 4 * 16 * (sizeof (int)); - __private char *c5 = src + 5 * 16 * (sizeof (int)); - __private char *c6 = src + 6 * 16 * (sizeof (int)); - __private char *c7 = src + 7 * 16 * (sizeof (int)); - - char *mem0 = mem + 0 * 16 * (sizeof (int)); - char *mem1 = mem + 1 * 16 * (sizeof (int)); - char *mem2 = mem + 2 * 16 * (sizeof (int)); - char *mem3 = mem + 3 * 16 * (sizeof (int)); - char *mem4 = mem + 0 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride; - char *mem5 = mem + 1 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride; - char *mem6 = mem + 2 * 16 * (sizeof (int)) + 16 * (sizeof (int)) * stride; - char *mem7 = mem + 3 * 16 * (sizeof (int)) + 16 * (sizeof (int)) 
* stride; - - __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem0, c0, stride); - __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem1, c1, stride); - __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem2, c2, stride); - __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem3, c3, stride); - __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem4, c4, stride); - __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem5, c5, stride); - __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem6, c6, stride); - __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem7, c7, stride); -} +DEFINE_STORE_ACC_ROW_MAJOR_32x64(generic) +DEFINE_STORE_ACC_ROW_MAJOR_32x64(global) +DEFINE_STORE_ACC_ROW_MAJOR_32x64(local) diff --git a/IGC/Compiler/tests/JointMatrixFuncsResolutionPass/address-spaces.ll b/IGC/Compiler/tests/JointMatrixFuncsResolutionPass/address-spaces.ll new file mode 100644 index 000000000000..6ab1d271d89a --- /dev/null +++ b/IGC/Compiler/tests/JointMatrixFuncsResolutionPass/address-spaces.ll @@ -0,0 +1,214 @@ +;=========================== begin_copyright_notice ============================ +; +; Copyright (C) 2023 Intel Corporation +; +; SPDX-License-Identifier: MIT +; +;============================ end_copyright_notice ============================= +; +; RUN: igc_opt -igc-joint-matrix-resolution -dce -S 2>&1 < %s | FileCheck %s +; ------------------------------------------------ +; JointMatrixFuncsResolutionPass +; ------------------------------------------------ + +%intel.joint_matrix_packedA_8x16_i32_t = type opaque +%intel.joint_matrix_acc_32x64_f32_t = type opaque + +define spir_kernel void @test_generic(i8* %src, i8* %dst) { + call void @load_store_generic(i8* %src, i8* %dst) + call void @load_store_large_generic(i8* %src, i8* %dst) + ret void +} + +define spir_kernel void @test_global(i8 addrspace(1)* %src, i8 addrspace(1)* %dst) { + call void @load_store_global(i8 addrspace(1)* %src, i8 addrspace(1)* %dst) + call void @load_store_large_global(i8 addrspace(1)* %src, i8 addrspace(1)* %dst) + ret void +} + +define spir_kernel void @test_local(i8 addrspace(3)* %src, i8 addrspace(3)* %dst) { + call void @load_store_local(i8 addrspace(3)* %src, i8 addrspace(3)* %dst) + call void @load_store_large_local(i8 addrspace(3)* %src, i8 addrspace(3)* %dst) + ret void +} + +; CHECK-LABEL: define void @load_store_generic( +define void @load_store_generic(i8* %src, i8* %dst) { + +; Matrix load sequence: +; CHECK: [[PTR:%.*]] = alloca <8 x i32> +; CHECK: [[MATPTR:%.*]] = bitcast <8 x i32>* [[PTR]] to i8* +; CHECK: call void @__builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_8x16_i32_8_generic_v8i8_pi32_i32(i8* [[MATPTR]], i8* %src, i32 16) +; CHECK: [[MATRIX:%.*]] = load <8 x i32>, <8 x i32>* [[PTR]] + + %1 = call spir_func %intel.joint_matrix_packedA_8x16_i32_t* @__builtin_spirv_OpJointMatrixLoadINTEL_generic(i8* %src, i32 16, i32 0) + +; Matrix store sequence: +; CHECK: [[TMP4:%.*]] = alloca <8 x i32> +; CHECK: store <8 x i32> [[MATRIX]], <8 x i32>* [[TMP4]] +; CHECK: [[TMP5:%.*]] = bitcast <8 x i32>* [[TMP4]] to i8* +; CHECK: call void @__builtin_spriv_OpJointMatrixStoreINTEL_PackedA_RowMajor_8x16_i32_8_generic_pi64_v8i8(i8* %dst, i8* [[TMP5]], i32 8) 
+ + call spir_func void @__builtin_spirv_OpJointMatrixStoreINTEL.8x16_generic(i8* %dst, %intel.joint_matrix_packedA_8x16_i32_t* %1, i32 8, i32 0) + +; CHECK: ret void + + ret void +} + +; CHECK-LABEL: define void @load_store_large_generic( +define void @load_store_large_generic(i8* %src, i8* %dst) { + +; Matrix load sequence: +; CHECK: [[PTR:%.*]] = alloca [2 x <32 x i64>] +; CHECK: [[MATPTR:%.*]] = bitcast [2 x <32 x i64>]* [[PTR]] to i8* +; CHECK: call void @__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_32x64_i32_32_generic_v8i8_pi32_i32(i8* [[MATPTR]], i8* %src, i64 16) +; CHECK: [[HALF_PTR_0:%.*]] = bitcast [2 x <32 x i64>]* [[PTR]] to <32 x i64>* +; CHECK: [[HALF_VAL_0:%.*]] = load <32 x i64>, <32 x i64>* [[HALF_PTR_0]] +; CHECK: [[HALF_PTR_1:%.*]] = getelementptr <32 x i64>, <32 x i64>* [[HALF_PTR_0]], i32 1 +; CHECK: [[HALF_VAL_1:%.*]] = load <32 x i64>, <32 x i64>* [[HALF_PTR_1]] +; CHECK: [[MATRIX_PARTIAL:%.*]] = insertvalue [2 x <32 x i64>] undef, <32 x i64> [[HALF_VAL_0]], 0 +; CHECK: [[MATRIX:%.*]] = insertvalue [2 x <32 x i64>] [[MATRIX_PARTIAL]], <32 x i64> [[HALF_VAL_1]], 1 + + %1 = call spir_func %intel.joint_matrix_acc_32x64_f32_t* @__builtin_spirv_OpJointMatrixLoadINTELacc_32x64_f32_p1i8_i64_i32_generic(i8* %src, i64 16, i32 0) + +; Matrix store sequence: +; CHECK: [[TMP4:%.*]] = alloca [2 x <32 x i64>] +; CHECK: store [2 x <32 x i64>] [[MATRIX]], [2 x <32 x i64>]* [[TMP4]] +; CHECK: [[TMP5:%.*]] = bitcast [2 x <32 x i64>]* [[TMP4]] to i8* +; CHECK: call void @__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_32x64_i32_32_generic_pi64_v8i8(i8* %dst, i8* [[TMP5]], i64 8) + + call spir_func void @__builtin_spirv_OpJointMatrixStoreINTELacc_32x64_f32_p1i8_acc_32x64_f32_i64_i32_generic(i8* %dst, %intel.joint_matrix_acc_32x64_f32_t* %1, i64 8, i32 0) + +; CHECK: ret void + + ret void +} + +; CHECK-LABEL: define void @load_store_global( +define void @load_store_global(i8 addrspace(1)* %src, i8 addrspace(1)* %dst) { + +; Matrix load sequence: +; CHECK: [[PTR:%.*]] = alloca <8 x i32> +; CHECK: [[MATPTR:%.*]] = bitcast <8 x i32>* [[PTR]] to i8* +; CHECK: call void @__builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_8x16_i32_8_global_v8i8_pi32_i32(i8* [[MATPTR]], i8 addrspace(1)* %src, i32 16) +; CHECK: [[MATRIX:%.*]] = load <8 x i32>, <8 x i32>* [[PTR]] + + %1 = call spir_func %intel.joint_matrix_packedA_8x16_i32_t* @__builtin_spirv_OpJointMatrixLoadINTEL_global(i8 addrspace(1)* %src, i32 16, i32 0) + +; Matrix store sequence: +; CHECK: [[TMP4:%.*]] = alloca <8 x i32> +; CHECK: store <8 x i32> [[MATRIX]], <8 x i32>* [[TMP4]] +; CHECK: [[TMP5:%.*]] = bitcast <8 x i32>* [[TMP4]] to i8* +; CHECK: call void @__builtin_spriv_OpJointMatrixStoreINTEL_PackedA_RowMajor_8x16_i32_8_global_pi64_v8i8(i8 addrspace(1)* %dst, i8* [[TMP5]], i32 8) + + call spir_func void @__builtin_spirv_OpJointMatrixStoreINTEL.8x16_global(i8 addrspace(1)* %dst, %intel.joint_matrix_packedA_8x16_i32_t* %1, i32 8, i32 0) + +; CHECK: ret void + + ret void +} + +; CHECK-LABEL: define void @load_store_large_global( +define void @load_store_large_global(i8 addrspace(1)* %src, i8 addrspace(1)* %dst) { + +; Matrix load sequence: +; CHECK: [[PTR:%.*]] = alloca [2 x <32 x i64>] +; CHECK: [[MATPTR:%.*]] = bitcast [2 x <32 x i64>]* [[PTR]] to i8* +; CHECK: call void @__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_32x64_i32_32_global_v8i8_pi32_i32(i8* [[MATPTR]], i8 addrspace(1)* %src, i64 16) +; CHECK: [[HALF_PTR_0:%.*]] = bitcast [2 x <32 x i64>]* [[PTR]] to <32 x i64>* +; CHECK: 
[[HALF_VAL_0:%.*]] = load <32 x i64>, <32 x i64>* [[HALF_PTR_0]] +; CHECK: [[HALF_PTR_1:%.*]] = getelementptr <32 x i64>, <32 x i64>* [[HALF_PTR_0]], i32 1 +; CHECK: [[HALF_VAL_1:%.*]] = load <32 x i64>, <32 x i64>* [[HALF_PTR_1]] +; CHECK: [[MATRIX_PARTIAL:%.*]] = insertvalue [2 x <32 x i64>] undef, <32 x i64> [[HALF_VAL_0]], 0 +; CHECK: [[MATRIX:%.*]] = insertvalue [2 x <32 x i64>] [[MATRIX_PARTIAL]], <32 x i64> [[HALF_VAL_1]], 1 + + %1 = call spir_func %intel.joint_matrix_acc_32x64_f32_t* @__builtin_spirv_OpJointMatrixLoadINTELacc_32x64_f32_p1i8_i64_i32_global(i8 addrspace(1)* %src, i64 16, i32 0) + +; Matrix store sequence: +; CHECK: [[TMP4:%.*]] = alloca [2 x <32 x i64>] +; CHECK: store [2 x <32 x i64>] [[MATRIX]], [2 x <32 x i64>]* [[TMP4]] +; CHECK: [[TMP5:%.*]] = bitcast [2 x <32 x i64>]* [[TMP4]] to i8* +; CHECK: call void @__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_32x64_i32_32_global_pi64_v8i8(i8 addrspace(1)* %dst, i8* [[TMP5]], i64 8) + + call spir_func void @__builtin_spirv_OpJointMatrixStoreINTELacc_32x64_f32_p1i8_acc_32x64_f32_i64_i32_global(i8 addrspace(1)* %dst, %intel.joint_matrix_acc_32x64_f32_t* %1, i64 8, i32 0) + +; CHECK: ret void + + ret void +} + +; CHECK-LABEL: define void @load_store_local( +define void @load_store_local(i8 addrspace(3)* %src, i8 addrspace(3)* %dst) { + +; Matrix load sequence: +; CHECK: [[PTR:%.*]] = alloca <8 x i32> +; CHECK: [[MATPTR:%.*]] = bitcast <8 x i32>* [[PTR]] to i8* +; CHECK: call void @__builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_8x16_i32_8_local_v8i8_pi32_i32(i8* [[MATPTR]], i8 addrspace(3)* %src, i32 16) +; CHECK: [[MATRIX:%.*]] = load <8 x i32>, <8 x i32>* [[PTR]] + + %1 = call spir_func %intel.joint_matrix_packedA_8x16_i32_t* @__builtin_spirv_OpJointMatrixLoadINTEL_local(i8 addrspace(3)* %src, i32 16, i32 0) + +; Matrix store sequence: +; CHECK: [[TMP4:%.*]] = alloca <8 x i32> +; CHECK: store <8 x i32> [[MATRIX]], <8 x i32>* [[TMP4]] +; CHECK: [[TMP5:%.*]] = bitcast <8 x i32>* [[TMP4]] to i8* +; CHECK: call void @__builtin_spriv_OpJointMatrixStoreINTEL_PackedA_RowMajor_8x16_i32_8_local_pi64_v8i8(i8 addrspace(3)* %dst, i8* [[TMP5]], i32 8) + + call spir_func void @__builtin_spirv_OpJointMatrixStoreINTEL.8x16_local(i8 addrspace(3)* %dst, %intel.joint_matrix_packedA_8x16_i32_t* %1, i32 8, i32 0) + +; CHECK: ret void + + ret void +} + +; CHECK-LABEL: define void @load_store_large_local( +define void @load_store_large_local(i8 addrspace(3)* %src, i8 addrspace(3)* %dst) { + +; Matrix load sequence: +; CHECK: [[PTR:%.*]] = alloca [2 x <32 x i64>] +; CHECK: [[MATPTR:%.*]] = bitcast [2 x <32 x i64>]* [[PTR]] to i8* +; CHECK: call void @__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_32x64_i32_32_local_v8i8_pi32_i32(i8* [[MATPTR]], i8 addrspace(3)* %src, i64 16) +; CHECK: [[HALF_PTR_0:%.*]] = bitcast [2 x <32 x i64>]* [[PTR]] to <32 x i64>* +; CHECK: [[HALF_VAL_0:%.*]] = load <32 x i64>, <32 x i64>* [[HALF_PTR_0]] +; CHECK: [[HALF_PTR_1:%.*]] = getelementptr <32 x i64>, <32 x i64>* [[HALF_PTR_0]], i32 1 +; CHECK: [[HALF_VAL_1:%.*]] = load <32 x i64>, <32 x i64>* [[HALF_PTR_1]] +; CHECK: [[MATRIX_PARTIAL:%.*]] = insertvalue [2 x <32 x i64>] undef, <32 x i64> [[HALF_VAL_0]], 0 +; CHECK: [[MATRIX:%.*]] = insertvalue [2 x <32 x i64>] [[MATRIX_PARTIAL]], <32 x i64> [[HALF_VAL_1]], 1 + + %1 = call spir_func %intel.joint_matrix_acc_32x64_f32_t* @__builtin_spirv_OpJointMatrixLoadINTELacc_32x64_f32_p1i8_i64_i32_local(i8 addrspace(3)* %src, i64 16, i32 0) + +; Matrix store sequence: +; CHECK: [[TMP4:%.*]] = 
alloca [2 x <32 x i64>] +; CHECK: store [2 x <32 x i64>] [[MATRIX]], [2 x <32 x i64>]* [[TMP4]] +; CHECK: [[TMP5:%.*]] = bitcast [2 x <32 x i64>]* [[TMP4]] to i8* +; CHECK: call void @__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_32x64_i32_32_local_pi64_v8i8(i8 addrspace(3)* %dst, i8* [[TMP5]], i64 8) + + call spir_func void @__builtin_spirv_OpJointMatrixStoreINTELacc_32x64_f32_p1i8_acc_32x64_f32_i64_i32_local(i8 addrspace(3)* %dst, %intel.joint_matrix_acc_32x64_f32_t* %1, i64 8, i32 0) + +; CHECK: ret void + + ret void +} + +declare spir_func %intel.joint_matrix_packedA_8x16_i32_t* @__builtin_spirv_OpJointMatrixLoadINTEL_generic(i8*, i32, i32) +declare spir_func %intel.joint_matrix_packedA_8x16_i32_t* @__builtin_spirv_OpJointMatrixLoadINTEL_global(i8 addrspace(1)*, i32, i32) +declare spir_func %intel.joint_matrix_packedA_8x16_i32_t* @__builtin_spirv_OpJointMatrixLoadINTEL_local(i8 addrspace(3)*, i32, i32) +declare spir_func void @__builtin_spirv_OpJointMatrixStoreINTEL.8x16_generic(i8*, %intel.joint_matrix_packedA_8x16_i32_t*, i32, i32) +declare spir_func void @__builtin_spirv_OpJointMatrixStoreINTEL.8x16_global(i8 addrspace(1)*, %intel.joint_matrix_packedA_8x16_i32_t*, i32, i32) +declare spir_func void @__builtin_spirv_OpJointMatrixStoreINTEL.8x16_local(i8 addrspace(3)*, %intel.joint_matrix_packedA_8x16_i32_t*, i32, i32) + +declare %intel.joint_matrix_acc_32x64_f32_t* @__builtin_spirv_OpJointMatrixLoadINTELacc_32x64_f32_p1i8_i64_i32_generic(i8*, i64, i32) +declare %intel.joint_matrix_acc_32x64_f32_t* @__builtin_spirv_OpJointMatrixLoadINTELacc_32x64_f32_p1i8_i64_i32_global(i8 addrspace(1)*, i64, i32) +declare %intel.joint_matrix_acc_32x64_f32_t* @__builtin_spirv_OpJointMatrixLoadINTELacc_32x64_f32_p1i8_i64_i32_local(i8 addrspace(3)*, i64, i32) +declare void @__builtin_spirv_OpJointMatrixStoreINTELacc_32x64_f32_p1i8_acc_32x64_f32_i64_i32_generic(i8*, %intel.joint_matrix_acc_32x64_f32_t *, i64, i32) +declare void @__builtin_spirv_OpJointMatrixStoreINTELacc_32x64_f32_p1i8_acc_32x64_f32_i64_i32_global(i8 addrspace(1)*, %intel.joint_matrix_acc_32x64_f32_t *, i64, i32) +declare void @__builtin_spirv_OpJointMatrixStoreINTELacc_32x64_f32_p1i8_acc_32x64_f32_i64_i32_local(i8 addrspace(3)*, %intel.joint_matrix_acc_32x64_f32_t *, i64, i32) + +!igc.functions = !{!0, !4, !5} +!0 = !{void (i8*, i8*)* @test_generic, !1} +!4 = !{void (i8 addrspace(1)*, i8 addrspace(1)*)* @test_global, !1} +!5 = !{void (i8 addrspace(3)*, i8 addrspace(3)*)* @test_local, !1} +!1 = !{!2, !3} +!2 = !{!"function_type", i32 0} +!3 = !{!"sub_group_size", i32 16}
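
The 32x64 accumulator builtins defined in this patch decompose the matrix into a 2x4 grid of 16x16 tiles and forward each tile to the 16x16 builtin for the same address space. Below is a minimal host-side sketch (plain C, not part of the patch; the stride value is an arbitrary example) that only mirrors the tile-offset arithmetic of the c0..c7 / mem0..mem7 computations in DEFINE_ACC_ROW_MAJOR_32x64 and DEFINE_STORE_ACC_ROW_MAJOR_32x64.

#include <stdio.h>

int main(void) {
    const long stride = 64;              /* example row pitch in elements (assumption) */
    const long elem   = sizeof(int);     /* i32 accumulator element size               */

    for (int tile = 0; tile < 8; ++tile) {
        int row_block = tile / 4;        /* 0..1: upper or lower 16 rows  */
        int col_block = tile % 4;        /* 0..3: which 16-column quarter */

        /* c0..c7: consecutive 16x16 tiles in the private slice */
        long dst_off = tile * 16 * elem;
        /* mem0..mem7: 16 columns per column block, 16 rows (16 * stride elements) per row block */
        long mem_off = col_block * 16 * elem + row_block * 16 * elem * stride;

        printf("tile %d: dst + %ld, mem + %ld\n", tile, dst_off, mem_off);
    }
    return 0;
}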