diff --git a/native/candlex/src/metal_kernels/custom_unary.metal b/native/candlex/src/metal_kernels/custom_unary.metal index 36f07f5..7f7a7dd 100644 --- a/native/candlex/src/metal_kernels/custom_unary.metal +++ b/native/candlex/src/metal_kernels/custom_unary.metal @@ -52,8 +52,11 @@ CUSTOM_UNARY(float, float, atan_f32, atan) CUSTOM_UNARY(float, float, atanh_f32, atanh) CUSTOM_UNARY(float, float, cosh_f32, cosh) CUSTOM_UNARY(float, float, sign_f32, sign) +CUSTOM_UNARY(int64_t, int64_t, sign_i64, sign) CUSTOM_UNARY(float, float, sinh_f32, sinh) CUSTOM_UNARY(float, float, tan_f32, tan) +CUSTOM_UNARY(uint64_t, uint64_t, bit_not_i64, not) +CUSTOM_UNARY(uint8_t, uint8_t, bit_not_u8, not) /* bit_not */ /* cbrt */ diff --git a/native/candlex/src/ops.rs b/native/candlex/src/ops.rs index 15f14dd..51a9ccc 100644 --- a/native/candlex/src/ops.rs +++ b/native/candlex/src/ops.rs @@ -115,6 +115,8 @@ macro_rules! custom_unary_op { if (layout.is_contiguous() && layout.start_offset() == 0) { let kernel_name = match storage.dtype() { DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT, + DType::I64 => metal_kernels::custom_unary::contiguous::$name::I64, + DType::U8 => metal_kernels::custom_unary::contiguous::$name::U8, dtype => { candle_core::bail!("Metal contiguous custom unary {} {dtype:?} not implemented", stringify!($name)) }