diff --git a/native/candlex/src/ops.rs b/native/candlex/src/ops.rs index a5fd4f2..0b0f72b 100644 --- a/native/candlex/src/ops.rs +++ b/native/candlex/src/ops.rs @@ -256,12 +256,13 @@ macro_rules! custom_unary_bool_op { use candle_core::{backend::BackendStorage, DType}; let device = storage.device(); + let dtype = storage.dtype(); let command_buffer = device.command_buffer()?; let elem_count = layout.shape().elem_count(); let output_buffer = device.new_buffer(elem_count, DType::U8, stringify!($name))?; if (layout.is_contiguous() && layout.start_offset() == 0) { - let kernel_name = match storage.dtype() { + let kernel_name = match dtype { DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT, DType::I64 => metal_kernels::custom_unary::contiguous::$name::I64, DType::U32 => metal_kernels::custom_unary::contiguous::$name::U32, @@ -280,7 +281,7 @@ macro_rules! custom_unary_bool_op { &output_buffer, ).unwrap(); } else { - let kernel_name = match storage.dtype() { + let kernel_name = match dtype { DType::F32 => metal_kernels::custom_unary::strided::$name::FLOAT, DType::I64 => metal_kernels::custom_unary::strided::$name::I64, DType::U32 => metal_kernels::custom_unary::strided::$name::U32,