diff --git a/native/candlex/src/metal_kernels/custom_unary.metal b/native/candlex/src/metal_kernels/custom_unary.metal index 7f7a7dd..1285d7f 100644 --- a/native/candlex/src/metal_kernels/custom_unary.metal +++ b/native/candlex/src/metal_kernels/custom_unary.metal @@ -44,17 +44,29 @@ kernel void FN_NAME##_strided( \ output[tid] = OUT_TYPE(FN(IN_TYPE(input[get_strided_index(tid, num_dims, dims, strides)]))); \ } -CUSTOM_UNARY(float, float, acos_f32, acos) -CUSTOM_UNARY(float, float, acosh_f32, acosh) -CUSTOM_UNARY(float, float, asin_f32, asin) -CUSTOM_UNARY(float, float, asinh_f32, asinh) -CUSTOM_UNARY(float, float, atan_f32, atan) -CUSTOM_UNARY(float, float, atanh_f32, atanh) -CUSTOM_UNARY(float, float, cosh_f32, cosh) -CUSTOM_UNARY(float, float, sign_f32, sign) +#define CUSTOM_UNARY_OP(FN_NAME, FN) \ +CUSTOM_UNARY(float, float, FN_NAME##_f32, FN);\ +CUSTOM_UNARY(half, half, FN_NAME##_f16, FN); + +#define CUSTOM_UNARY_BOOL_OP(FN_NAME, FN) \ +CUSTOM_UNARY(float, uint8_t, FN_NAME##_f32, FN);\ +CUSTOM_UNARY(half, uint8_t, FN_NAME##_f16, FN); + +CUSTOM_UNARY_OP(acos, acos) +CUSTOM_UNARY_OP(acosh, acosh) +CUSTOM_UNARY_OP(asin, asin) +CUSTOM_UNARY_OP(asinh, asinh) +CUSTOM_UNARY_OP(atan, atan) +CUSTOM_UNARY_OP(atanh, atanh) +CUSTOM_UNARY_OP(cosh, cosh) +CUSTOM_UNARY_OP(sign, sign) +CUSTOM_UNARY_OP(sinh, sinh) +CUSTOM_UNARY_OP(tan, tan) + +CUSTOM_UNARY_BOOL_OP(is_inf, isinf) +CUSTOM_UNARY_BOOL_OP(is_nan, isnan) + CUSTOM_UNARY(int64_t, int64_t, sign_i64, sign) -CUSTOM_UNARY(float, float, sinh_f32, sinh) -CUSTOM_UNARY(float, float, tan_f32, tan) CUSTOM_UNARY(uint64_t, uint64_t, bit_not_i64, not) CUSTOM_UNARY(uint8_t, uint8_t, bit_not_u8, not) @@ -65,6 +77,3 @@ CUSTOM_UNARY(uint8_t, uint8_t, bit_not_u8, not) /* expm1 */ /* ln_1p */ /* sigmoid */ - -CUSTOM_UNARY(float, uint8_t, is_inf_f32, isinf) -CUSTOM_UNARY(float, uint8_t, is_nan_f32, isnan)