From b33fbf5e114dbe588a867f6006443365a5507979 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 19 Jan 2024 12:41:31 -0300 Subject: [PATCH] metal binary half kernel --- native/candlex/src/metal_kernels/custom_binary.metal | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/native/candlex/src/metal_kernels/custom_binary.metal b/native/candlex/src/metal_kernels/custom_binary.metal index bb855bb..e7a67e9 100644 --- a/native/candlex/src/metal_kernels/custom_binary.metal +++ b/native/candlex/src/metal_kernels/custom_binary.metal @@ -49,15 +49,19 @@ kernel void FN_NAME##_strided( \ output[tid] = OUT_TYPE(FN); \ } +#define CUSTOM_BINARY_OP(FN_NAME, FN)\ +CUSTOM_BINARY(float, float, FN_NAME##_f32, FN);\ +CUSTOM_BINARY(half, half, FN_NAME##_f16, FN); + +CUSTOM_BINARY_OP(atan2, atan2(x, y)) +CUSTOM_BINARY_OP(pow, pow(x, y)) + CUSTOM_BINARY(int64_t, int64_t, bit_and_i64, x & y) CUSTOM_BINARY(int64_t, int64_t, bit_or_i64, x | y) CUSTOM_BINARY(int64_t, int64_t, bit_xor_i64, x ^ y) CUSTOM_BINARY(int64_t, int64_t, shl_i64, x << y) CUSTOM_BINARY(int64_t, int64_t, shr_i64, x >> y) -CUSTOM_BINARY(float, float, atan2_f32, atan2(x, y)) -CUSTOM_BINARY(float, float, pow_f32, pow(x, y)) - /* pow */ /* remainder */ /* shl */