From e64cc9e9a0572a4e4238c57432e433cff1fbfda3 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 19 Jan 2024 10:21:12 -0300 Subject: [PATCH] fix --- native/candlex/src/ops.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/native/candlex/src/ops.rs b/native/candlex/src/ops.rs index a5fd4f2..0b0f72b 100644 --- a/native/candlex/src/ops.rs +++ b/native/candlex/src/ops.rs @@ -256,12 +256,13 @@ macro_rules! custom_unary_bool_op { use candle_core::{backend::BackendStorage, DType}; let device = storage.device(); + let dtype = storage.dtype(); let command_buffer = device.command_buffer()?; let elem_count = layout.shape().elem_count(); let output_buffer = device.new_buffer(elem_count, DType::U8, stringify!($name))?; if (layout.is_contiguous() && layout.start_offset() == 0) { - let kernel_name = match storage.dtype() { + let kernel_name = match dtype { DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT, DType::I64 => metal_kernels::custom_unary::contiguous::$name::I64, DType::U32 => metal_kernels::custom_unary::contiguous::$name::U32, @@ -280,7 +281,7 @@ macro_rules! custom_unary_bool_op { &output_buffer, ).unwrap(); } else { - let kernel_name = match storage.dtype() { + let kernel_name = match dtype { DType::F32 => metal_kernels::custom_unary::strided::$name::FLOAT, DType::I64 => metal_kernels::custom_unary::strided::$name::I64, DType::U32 => metal_kernels::custom_unary::strided::$name::U32,