diff --git a/native/candlex/src/metal_kernels.rs b/native/candlex/src/metal_kernels.rs index 17d68e7..513fcb6 100644 --- a/native/candlex/src/metal_kernels.rs +++ b/native/candlex/src/metal_kernels.rs @@ -45,6 +45,7 @@ macro_rules! ops { pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64")); + pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8")); } )+ @@ -58,6 +59,8 @@ macro_rules! ops { pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided")); + pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided")); + pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided")); } )+ } diff --git a/native/candlex/src/ops.rs b/native/candlex/src/ops.rs index 51a9ccc..a5fd4f2 100644 --- a/native/candlex/src/ops.rs +++ b/native/candlex/src/ops.rs @@ -116,6 +116,7 @@ macro_rules! custom_unary_op { let kernel_name = match storage.dtype() { DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT, DType::I64 => metal_kernels::custom_unary::contiguous::$name::I64, + DType::U32 => metal_kernels::custom_unary::contiguous::$name::U32, DType::U8 => metal_kernels::custom_unary::contiguous::$name::U8, dtype => { candle_core::bail!("Metal contiguous custom unary {} {dtype:?} not implemented", stringify!($name)) @@ -133,6 +134,9 @@ macro_rules! custom_unary_op { } else { let kernel_name = match storage.dtype() { DType::F32 => metal_kernels::custom_unary::strided::$name::FLOAT, + DType::I64 => metal_kernels::custom_unary::strided::$name::I64, + DType::U32 => metal_kernels::custom_unary::strided::$name::U32, + DType::U8 => metal_kernels::custom_unary::strided::$name::U8, dtype => { candle_core::bail!("Metal strided custom unary {} {dtype:?} not implemented", stringify!($name)) } @@ -251,23 +255,52 @@ macro_rules! custom_unary_bool_op { use crate::metal_kernels; use candle_core::{backend::BackendStorage, DType}; - if !(layout.is_contiguous() && layout.start_offset() == 0) { - candle_core::bail!("Non contiguous not supported"); - } - let device = storage.device(); let command_buffer = device.command_buffer()?; let elem_count = layout.shape().elem_count(); let output_buffer = device.new_buffer(elem_count, DType::U8, stringify!($name))?; - metal_kernels::call_custom_unary_contiguous( - &device.device(), - &command_buffer, - metal_kernels::custom_unary::contiguous::$name::FLOAT, - elem_count, - storage.buffer(), - &output_buffer, - ).unwrap(); + if (layout.is_contiguous() && layout.start_offset() == 0) { + let kernel_name = match storage.dtype() { + DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT, + DType::I64 => metal_kernels::custom_unary::contiguous::$name::I64, + DType::U32 => metal_kernels::custom_unary::contiguous::$name::U32, + DType::U8 => metal_kernels::custom_unary::contiguous::$name::U8, + dtype => { + candle_core::bail!("Metal contiguous custom unary {} {dtype:?} not implemented", stringify!($name)) + } + }; + + metal_kernels::call_custom_unary_contiguous( + &device.device(), + &command_buffer, + kernel_name, + elem_count, + storage.buffer(), + &output_buffer, + ).unwrap(); + } else { + let kernel_name = match storage.dtype() { + DType::F32 => metal_kernels::custom_unary::strided::$name::FLOAT, + DType::I64 => metal_kernels::custom_unary::strided::$name::I64, + DType::U32 => metal_kernels::custom_unary::strided::$name::U32, + DType::U8 => metal_kernels::custom_unary::strided::$name::U8, + dtype => { + candle_core::bail!("Metal strided custom unary {} {dtype:?} not implemented", stringify!($name)) + } + }; + + metal_kernels::call_custom_unary_strided( + &device.device(), + &command_buffer, + kernel_name, + layout.dims(), + layout.stride(), + storage.buffer(), + layout.start_offset() * dtype.size_in_bytes(), + &output_buffer, + ).unwrap(); + } Ok((MetalStorage::new(output_buffer, device.clone(), DType::U8), layout.shape().clone())) } @@ -398,6 +431,8 @@ macro_rules! custom_binary_op { let kernel_name = match dtype { DType::F32 => metal_kernels::custom_binary::contiguous::$name::FLOAT, DType::I64 => metal_kernels::custom_binary::contiguous::$name::I64, + DType::U32 => metal_kernels::custom_binary::contiguous::$name::U32, + DType::U8 => metal_kernels::custom_binary::contiguous::$name::U8, dtype => { candle_core::bail!("Metal contiguous custom binary {} {dtype:?} not implemented", stringify!($name)) } @@ -416,6 +451,8 @@ macro_rules! custom_binary_op { let kernel_name = match dtype { DType::F32 => metal_kernels::custom_binary::strided::$name::FLOAT, DType::I64 => metal_kernels::custom_binary::strided::$name::I64, + DType::U32 => metal_kernels::custom_binary::strided::$name::U32, + DType::U8 => metal_kernels::custom_binary::strided::$name::U8, dtype => { candle_core::bail!("Metal strided custom binary {} {dtype:?} not implemented", stringify!($name)) } @@ -560,38 +597,58 @@ macro_rules! custom_binary_bool_op { use crate::metal_kernels; use candle_core::{backend::BackendStorage, DType}; - if !(l1.is_contiguous() && l1.start_offset() == 0) { - candle_core::bail!("Non contiguous not supported - l1"); - } - if !(l2.is_contiguous() && l2.start_offset() == 0) { - candle_core::bail!("Non contiguous not supported - l2"); - } - let device = s1.device(); let shape = l1.shape(); let elem_count = shape.elem_count(); let command_buffer = device.command_buffer()?; let output_buffer = device.new_buffer(elem_count, DType::U8, stringify!($name))?; - let kernel_name = match s1.dtype() { - DType::I64 => metal_kernels::custom_binary::contiguous::$name::I64, - DType::U8 => metal_kernels::custom_binary::contiguous::$name::U8, - dtype => { - candle_core::bail!("Metal contiguous custom binary {} {dtype:?} not implemented", stringify!($name)) - } - }; - - metal_kernels::call_custom_binary_contiguous( - &device.device(), - &command_buffer, - kernel_name, - elem_count, - &s1.buffer(), - &s2.buffer(), - &output_buffer, - ).unwrap(); - - // command_buffer.set_label("binary"); + if (l1.is_contiguous() && l1.start_offset() == 0 && l2.is_contiguous() && l2.start_offset() == 0) { + let kernel_name = match s1.dtype() { + DType::F32 => metal_kernels::custom_binary::contiguous::$name::FLOAT, + DType::I64 => metal_kernels::custom_binary::contiguous::$name::I64, + DType::U32 => metal_kernels::custom_binary::contiguous::$name::U32, + DType::U8 => metal_kernels::custom_binary::contiguous::$name::U8, + dtype => { + candle_core::bail!("Metal contiguous custom binary {} {dtype:?} not implemented", stringify!($name)) + } + }; + + metal_kernels::call_custom_binary_contiguous( + &device.device(), + &command_buffer, + kernel_name, + elem_count, + &s1.buffer(), + &s2.buffer(), + &output_buffer, + ).unwrap(); + } else { + let kernel_name = match s1.dtype() { + DType::F32 => metal_kernels::custom_binary::strided::$name::FLOAT, + DType::I64 => metal_kernels::custom_binary::strided::$name::I64, + DType::U32 => metal_kernels::custom_binary::strided::$name::U32, + DType::U8 => metal_kernels::custom_binary::strided::$name::U8, + dtype => { + candle_core::bail!("Metal strided custom binary {} {dtype:?} not implemented", stringify!($name)) + } + }; + + metal_kernels::call_custom_binary_strided( + &device.device(), + &command_buffer, + kernel_name, + l1.dims(), + &s1.buffer(), + l1.stride(), + l1.start_offset() * s1.dtype().size_in_bytes(), + &s2.buffer(), + l2.stride(), + l2.start_offset() * s2.dtype().size_in_bytes(), + &output_buffer, + ).unwrap(); + } + Ok((MetalStorage::new(output_buffer, device.clone(), DType::U8), l1.shape().clone())) } }