Skip to content

Commit

Permalink
[BM] [amx] Test Enhance: add amx_fp16 test support based on gcc version
Browse files Browse the repository at this point in the history
add amx_fp16 support based on gcc version, gcc 13.3 or higher will
include amx_fp16 test support, otherwise bypass it since gcc don't
support it likely and test code compile will fail

[Test Components] amx/tmul
[Test Types] any
[Supported Devices] all-generic

Signed-off-by: Hongyu Ning <[email protected]>
  • Loading branch information
hongyuni committed Jun 14, 2024
1 parent 740bf88 commit 1106850
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 13 deletions.
12 changes: 10 additions & 2 deletions BM/amx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ It will be checked if the result is correct or not.

- How to build:

gcc 11.1 or above is required.
gcc 11.1 or above is required for amx_bf16, amx_int8 test.
gcc 13.3 or above is required for extra amx_fp16 test.

To compile,
To compile,
for amx_bf16, amx_int8 test:
$ make
for extra amx_fp16 test:
$ make fp16

To clean,
$ make clean
Expand Down Expand Up @@ -47,3 +51,7 @@ It will be checked if the result is correct or not.
f. Break sub-thread which is doing TMUL TDPBUUD calculation by futex
$ ./tmul -b 5 -t 10 -c 20 -i 4

g. Break sub-thread which is doing TMUL TDPFP16PS calculation by yield
$ ./tmul -b 1 -t 10 -c 20 -i 5


3 changes: 3 additions & 0 deletions BM/amx/tmul/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ CFILES_AMX = tmul.c
all:
$(CC) $(CFLAG) $(CFILES_AMX) -o $(BIN_AMX) $(LIBS)

fp16: CFLAG += -g -DFP16
fp16: all

clean:
-rm $(BIN_AMX)

Expand Down
243 changes: 232 additions & 11 deletions BM/amx/tmul/tmul.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,25 @@
#define FUTEX_VAL 0x5E5E5E5E

#define DPBD(c, x, y, type1, type2) \
c = c + \
(uint32_t)((type1)(((uint8_t *)(x))[0])) * (uint32_t)((type2)(((uint8_t *)(y))[0])) + \
(uint32_t)((type1)(((uint8_t *)(x))[1])) * (uint32_t)((type2)(((uint8_t *)(y))[1])) + \
(uint32_t)((type1)(((uint8_t *)(x))[2])) * (uint32_t)((type2)(((uint8_t *)(y))[2])) + \
(uint32_t)((type1)(((uint8_t *)(x))[3])) * (uint32_t)((type2)(((uint8_t *)(y))[3]))
{ \
c = (c + \
(uint32_t)((type1)(((uint8_t *)(x))[0])) * (uint32_t)((type2)(((uint8_t *)(y))[0])) + \
(uint32_t)((type1)(((uint8_t *)(x))[1])) * (uint32_t)((type2)(((uint8_t *)(y))[1])) + \
(uint32_t)((type1)(((uint8_t *)(x))[2])) * (uint32_t)((type2)(((uint8_t *)(y))[2])) + \
(uint32_t)((type1)(((uint8_t *)(x))[3])) * (uint32_t)((type2)(((uint8_t *)(y))[3]))); \
}

#define load_tile_reg(tmm_num, tile, stride) \
asm volatile("tileloadd\t(%0,%1,1), %%tmm" #tmm_num \
: : "r" ((void *)(tile)->buf), "r" ((long)stride) : "memory")
{ \
asm volatile("tileloadd\t(%0,%1,1), %%tmm" #tmm_num \
: : "r" ((void *)(tile)->buf), "r" ((long)stride) : "memory"); \
}

#define store_tile_reg(tmm_num, tile, stride) \
asm volatile("tilestored\t%%tmm" #tmm_num ", (%0,%1,1)" \
: : "r" ((void *)(tile)->buf), "r" ((long)stride) : "memory")
{ \
asm volatile("tilestored\t%%tmm" #tmm_num ", (%0,%1,1)" \
: : "r" ((void *)(tile)->buf), "r" ((long)stride) : "memory"); \
}

enum {
BREAK_BY_NOTHING = 0,
Expand All @@ -58,7 +64,12 @@ enum {
INS_TDPBSUD,
INS_TDPBUSD,
INS_TDPBUUD,
#ifdef FP16
INS_TDPFP16PS,
INS_MAX_NUM = INS_TDPFP16PS
#else
INS_MAX_NUM = INS_TDPBUUD
#endif
} ENUM_INSTRUCTION_TYPE;

struct __tile_config {
Expand Down Expand Up @@ -118,6 +129,77 @@ static float convert_bf16_to_fp32(uint16_t bf16)
return *((float *)&u);
}

#ifdef FP16
/*
* convert_fp32_to_fp16() - Convert data format.
* @fp32: A FP32 value.
*
* Covert FP32 to FP16.
*/
static uint16_t convert_fp32_to_fp16(float fp32)
{
uint32_t u = *((uint32_t *)&fp32);
uint16_t sign = (u >> 16) & 0x8000;
uint16_t fraction = (u & 0x007fffff) >> 13;
uint16_t exponent = (((u & 0x7f800000) >> 23) - 127 + 15) << 10;

uint16_t fp16 = sign | exponent | fraction;

return fp16;
}

/*
* convert_fp16_to_fp32() - Convert data format.
* @fp: A FP16 value.
*
* Covert FP16 to FP32.
*/
static float convert_fp16_to_fp32(uint16_t fp16)
{
int shift;
uint32_t u;
float fp32 = 0;
uint32_t sign = (fp16 & 0x8000) << 16;
uint32_t fraction = (fp16 & 0x3ff) << 13;
uint32_t exponent = (fp16 & 0x7c00) >> 10;

if (exponent == 0x1f && fraction == 0) {
if (sign)
fp32 = -INFINITY;
else
fp32 = INFINITY;
} else if (exponent == 0x1f && fraction != 0) {
fp32 = NAN;
} else if (exponent == 0) {
if (fraction == 0) {
/* +0.0, -0.0 */
u = sign;
fp32 = *((float *)&u);
} else {
/* Convert subnormal into normal fp32 number */
for (int i = 22; i >= 12; i--) {
if (fraction & (1 << i)) {
shift = (1 + 22 - i);
fraction = (fraction << shift) & 0x7fffff;
u = sign |
((exponent - 15 - (shift - 1) + 127) << 23) |
fraction;
fp32 = *((float *)&u);

break;
}
}
}
//fp32 = *((float *)&u);
} else {
u = sign | ((exponent - 15 + 127) << 23) | fraction;
fp32 = *((float *)&u);
}

return fp32;
}
#endif

/*
* do_syscall() - Execute syscall instruction.
* @nr: syscall number will be saved rax register.
Expand Down Expand Up @@ -191,6 +273,34 @@ static void init_bf16_tile(struct __tile *tile_ptr, uint8_t rows, uint8_t colsb)
}
}

#ifdef FP16
/*
* init_fp16_tile() - Init buffer.
* @buf: The buffer for saving data.
* @rows: Row number of the matrix.
* @colsb: Column number of the matrix.
*
* Init buffer with chaotic float.
*/
static void init_fp16_tile(struct __tile *tile_ptr, uint8_t rows,
uint8_t colsb)
{
int32_t i, j;
uint16_t *ptr = (uint16_t *)tile_ptr->buf;
int32_t cols = colsb / 2;
float f = 0;

tile_ptr->rows = rows;
tile_ptr->colsb = colsb;

for (i = 0; i < rows; i++)
for (j = 0; j < cols; j++) {
f = 2.718f;
ptr[i * cols + j] = convert_fp32_to_fp16(f);
}
}
#endif

/*
* init_dword_tile() - Init buffer.
* @buf: The buffer for saving data.
Expand Down Expand Up @@ -297,6 +407,41 @@ static void calc_matrix_tdpbf16ps(struct __tile *dst, struct __tile *src1, struc
}
}

#ifdef FP16
/*
* calc_matrix_tdpfp16ps() - Software algorithm for instruction TDPBF16PS.
* @dst: The product of matrix multiplication.
* @src1: The first multiplier.
* @src2: The second multiplier.
*
* Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b,
* accumulating the intermediate single-precision (32-bit) floating-point
* elements with elements in dst,
* and store the 32-bit result back to tile dst.
*/
static void calc_matrix_tdpfp16ps(struct __tile *dst, struct __tile *src1, struct __tile *src2)
{
uint16_t *src1_buf = (uint16_t *)src1->buf;
uint16_t *src2_buf = (uint16_t *)src2->buf;
float *dst_buf = (float *)dst->buf;

int32_t M = src1->rows;
int32_t K = src1->colsb / 4;
int32_t N = src2->colsb / 4;
int32_t m, k, n;

for (m = 0; m < M; m++)
for (k = 0; k < K; k++)
for (n = 0; n < N; n++) {
dst_buf[m * N + n] +=
(convert_fp16_to_fp32(src1_buf[m * K * 2 + k * 2 + 0]) *
convert_fp16_to_fp32(src2_buf[k * N * 2 + n * 2 + 0])) +
(convert_fp16_to_fp32(src1_buf[m * K * 2 + k * 2 + 1]) *
convert_fp16_to_fp32(src2_buf[k * N * 2 + n * 2 + 1]));
}
}
#endif

/*
* calc_matrix_tdpbssd() - Software algorithm for instruction TDPBSSD.
* @dst: The product of matrix multiplication.
Expand Down Expand Up @@ -441,6 +586,16 @@ static void tile_dpbf16ps(void)
asm volatile("tdpbf16ps %tmm2, %tmm1, %tmm0");
}

#ifdef FP16
static void tile_dpfp16ps(void)
{
asm volatile("tdpfp16ps %tmm7, %tmm6, %tmm5");
asm volatile("tdpfp16ps %tmm4, %tmm3, %tmm2");
asm volatile("tdpfp16ps %tmm5, %tmm2, %tmm1");
asm volatile("tdpfp16ps %tmm2, %tmm1, %tmm0");
}
#endif

static void tile_dpbssd(void)
{
asm volatile("tdpbssd %tmm7, %tmm6, %tmm5");
Expand Down Expand Up @@ -511,6 +666,46 @@ static bool check_tile_bf16_register(struct __tile *ref, struct __tile *target)
return true;
}

#ifdef FP16
/*
* check_tile_fp16_register() - check calculation result.
* @ref: The result calculated by AMX/TMUL.
* @target: The result calculated by software.
*
* Check if the difference of the 2 results is small enough.
*
* Return:
* true - OK
* false - Abnormal
*/
static bool check_tile_fp16_register(struct __tile *ref, struct __tile *target)
{
/*
* Tile register should be stored from tmm to
* memory and compare with emulation results.
*/
int32_t rows = target->rows;
int32_t colsb = target->colsb / 4;
uint8_t *rbuf = ref->buf;
uint8_t *tbuf = target->buf;
int32_t i, j, idx;

for (i = 0; i < rows; i++)
for (j = 0; j < colsb; j++) {
idx = i * colsb + j;
if ((((float *)rbuf)[idx] - ((float *)tbuf)[idx]) > (0.5) ||
(((float *)rbuf)[idx] - ((float *)tbuf)[idx]) < (-0.5)) {
printf("Mismatch: idx=%d, ref=%f, target=%f\n", idx,
((float *)rbuf)[idx],
((float *)tbuf)[idx]);
return false;
}
}

return true;
}
#endif

/*
* check_tile_dword_register() - check calculation result.
* @ref: The result calculated by AMX/TMUL.
Expand Down Expand Up @@ -633,9 +828,9 @@ static void thread_break(int32_t reason, uint32_t thread_idx)
* @arg: The index of sub-thread.
* Index from 0 to the total number of threads - 1.
*
* Two results are generated by AMX/TMUL calcultion procedure,
* Two results are generated by AMX/TMUL calculation procedure,
* one is calculated by software, the other is calculated by TMUL.
* Interrupt the AMX/TMUL calcultion procedure by different reasons.
* Interrupt the AMX/TMUL calculation procedure by different reasons.
* These reasons may cause context-switch by Kernel.
* Check if the thread context is saved and restored correctly
* by comparing the two results.
Expand Down Expand Up @@ -663,6 +858,10 @@ static void *worker_thread(void *arg)
/* Init the test data in memory */
if (ins_type == INS_TDPBF16PS)
init_bf16_tile(ptr_tile1, ROW_NUM, COL_NUM);
#ifdef FP16
else if (ins_type == INS_TDPFP16PS)
init_fp16_tile(ptr_tile1, ROW_NUM, COL_NUM);
#endif
else
init_dword_tile(ptr_tile1, ROW_NUM, COL_NUM);

Expand All @@ -675,6 +874,12 @@ static void *worker_thread(void *arg)
calc_matrix_tdpbf16ps(ptr_tile4, ptr_tile3, ptr_tile2);
calc_matrix_tdpbf16ps(ptr_tile3, ptr_tile4, ptr_tile4);
calc_matrix_tdpbf16ps(ptr_tile2, ptr_tile3, ptr_tile4);
#ifdef FP16
} else if (ins_type == INS_TDPFP16PS) {
calc_matrix_tdpfp16ps(ptr_tile4, ptr_tile3, ptr_tile2);
calc_matrix_tdpfp16ps(ptr_tile3, ptr_tile4, ptr_tile4);
calc_matrix_tdpfp16ps(ptr_tile2, ptr_tile3, ptr_tile4);
#endif
} else if (ins_type == INS_TDPBSSD) {
calc_matrix_tdpbssd(ptr_tile4, ptr_tile3, ptr_tile2);
calc_matrix_tdpbssd(ptr_tile3, ptr_tile4, ptr_tile4);
Expand Down Expand Up @@ -719,6 +924,10 @@ static void *worker_thread(void *arg)
/* Step4: Calculate a result by TMUL and store it in TMM0 register */
if (ins_type == INS_TDPBF16PS)
tile_dpbf16ps();
#ifdef FP16
else if (ins_type == INS_TDPFP16PS)
tile_dpfp16ps();
#endif
else if (ins_type == INS_TDPBSSD)
tile_dpbssd();
else if (ins_type == INS_TDPBSUD)
Expand All @@ -743,6 +952,14 @@ static void *worker_thread(void *arg)
ins_type, thread_idx, i);
rtn = false;
}
#ifdef FP16
} else if (ins_type == INS_TDPFP16PS) {
if (!check_tile_fp16_register(ptr_tile3, ptr_tile2)) {
printf("Instruction %d test in Thread %d Cycle %d: failed\n",
ins_type, thread_idx, i);
rtn = false;
}
#endif
} else {
if (!check_tile_dword_register(ptr_tile3, ptr_tile2)) {
printf("Instruction %d test in Thread %d Cycle %d: failed\n",
Expand Down Expand Up @@ -788,7 +1005,11 @@ static void help(void)
" 5: break by futex\n"
" -t, --thread-count [Should not be less than %d]\n"
" -c, --cycle-number [Should not be less than 1]\n"
#ifdef FP16
" -i, --instruction-type [0:TDPBF16PS 1:TDPBSSD 2:TDPBSUD 3:TDPBUSD 4:TDPBUUD 5:TDPFP16PS]\n"
#else
" -i, --instruction-type [0:TDPBF16PS 1:TDPBSSD 2:TDPBSUD 3:TDPBUSD 4:TDPBUUD]\n"
#endif
, progname, progname, BREAK_BY_YIELD, BREAK_REASON_MAX, MIN_THREAD_NUM);
}

Expand Down

0 comments on commit 1106850

Please sign in to comment.