diff --git a/native/candlex/src/metal_kernels/custom_binary.metal b/native/candlex/src/metal_kernels/custom_binary.metal index bb855bb..e7a67e9 100644 --- a/native/candlex/src/metal_kernels/custom_binary.metal +++ b/native/candlex/src/metal_kernels/custom_binary.metal @@ -49,15 +49,19 @@ kernel void FN_NAME##_strided( \ output[tid] = OUT_TYPE(FN); \ } +#define CUSTOM_BINARY_OP(FN_NAME, FN)\ +CUSTOM_BINARY(float, float, FN_NAME##_f32, FN);\ +CUSTOM_BINARY(half, half, FN_NAME##_f16, FN); + +CUSTOM_BINARY_OP(atan2, atan2(x, y)) +CUSTOM_BINARY_OP(pow, pow(x, y)) + CUSTOM_BINARY(int64_t, int64_t, bit_and_i64, x & y) CUSTOM_BINARY(int64_t, int64_t, bit_or_i64, x | y) CUSTOM_BINARY(int64_t, int64_t, bit_xor_i64, x ^ y) CUSTOM_BINARY(int64_t, int64_t, shl_i64, x << y) CUSTOM_BINARY(int64_t, int64_t, shr_i64, x >> y) -CUSTOM_BINARY(float, float, atan2_f32, atan2(x, y)) -CUSTOM_BINARY(float, float, pow_f32, pow(x, y)) - /* pow */ /* remainder */ /* shl */