diff --git a/native/candlex/src/metal_kernels.rs b/native/candlex/src/metal_kernels.rs index 4b93f1b..78cc0a5 100644 --- a/native/candlex/src/metal_kernels.rs +++ b/native/candlex/src/metal_kernels.rs @@ -49,6 +49,17 @@ macro_rules! ops { } )+ } + pub mod strided { + pub struct Kernel(pub &'static str); + + $( + pub mod $name { + use super::Kernel; + + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided")); + } + )+ + } } } @@ -167,6 +178,62 @@ pub fn call_custom_unary_contiguous( Ok(()) } +pub fn call_custom_unary_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernel_name: custom_unary::strided::Kernel, + shape: &[usize], + strides: &[usize], + input_buffer: &Buffer, + input_offset: usize, + output_buffer: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = + CustomKernels::new().load_pipeline(device, Source::CustomUnary, kernel_name.0)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + let num_dims: usize = shape.len(); + let length: usize = shape.iter().product(); + + encoder.set_bytes( + 0, + core::mem::size_of::<usize>() as u64, + &length as *const usize as *const c_void, + ); + encoder.set_bytes( + 1, + core::mem::size_of::<usize>() as u64, + &num_dims as *const usize as *const c_void, + ); + + encoder.set_bytes( + 2, + core::mem::size_of_val(shape) as u64, + shape.as_ptr() as *const c_void, + ); + + encoder.set_bytes( + 3, + core::mem::size_of_val(strides) as u64, + strides.as_ptr() as *const c_void, + ); + + encoder.set_buffer(4, Some(input_buffer), input_offset as u64); + encoder.set_buffer(5, Some(output_buffer), 0); + + encoder.use_resource(input_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output_buffer, metal::MTLResourceUsage::Write); + + let width: usize = shape.iter().product(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); 
+ + encoder.end_encoding(); + + Ok(()) +} + pub fn call_custom_binary_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -202,3 +269,71 @@ Ok(()) } + +pub fn call_custom_binary_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernel_name: custom_binary::strided::Kernel, + shape: &[usize], + left_buffer: &Buffer, + left_strides: &[usize], + left_offset: usize, + right_buffer: &Buffer, + right_strides: &[usize], + right_offset: usize, + output_buffer: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = + CustomKernels::new().load_pipeline(device, Source::CustomBinary, kernel_name.0)?; + + let num_dims: usize = shape.len(); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + let length: usize = shape.iter().product(); + + encoder.set_bytes( + 0, + core::mem::size_of::<usize>() as u64, + &length as *const usize as *const c_void, + ); + encoder.set_bytes( + 1, + core::mem::size_of::<usize>() as u64, + &num_dims as *const usize as *const c_void, + ); + + encoder.set_bytes( + 2, + core::mem::size_of_val(shape) as u64, + shape.as_ptr() as *const c_void, + ); + + encoder.set_bytes( + 3, + core::mem::size_of_val(left_strides) as u64, + left_strides.as_ptr() as *const c_void, + ); + + encoder.set_bytes( + 4, + core::mem::size_of_val(right_strides) as u64, + right_strides.as_ptr() as *const c_void, + ); + + encoder.set_buffer(5, Some(left_buffer), left_offset as u64); + encoder.set_buffer(6, Some(right_buffer), right_offset as u64); + encoder.set_buffer(7, Some(output_buffer), 0); + + encoder.use_resource(left_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(right_buffer, metal::MTLResourceUsage::Read); + encoder.use_resource(output_buffer, metal::MTLResourceUsage::Write); + + let width: usize = shape.iter().product(); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); + 
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + + encoder.end_encoding(); + + Ok(()) +} diff --git a/native/candlex/src/metal_kernels/custom_binary.metal b/native/candlex/src/metal_kernels/custom_binary.metal index e6c89e9..9ef34ee 100644 --- a/native/candlex/src/metal_kernels/custom_binary.metal +++ b/native/candlex/src/metal_kernels/custom_binary.metal @@ -1,5 +1,20 @@ using namespace metal; +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + #define CUSTOM_BINARY(IN_TYPE, OUT_TYPE, FN_NAME, FN) \ kernel void FN_NAME( \ constant size_t &dim, \ @@ -14,6 +29,24 @@ kernel void FN_NAME( \ IN_TYPE x = left[tid]; \ IN_TYPE y = right[tid]; \ output[tid] = OUT_TYPE(FN); \ +}\ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *left_strides, \ + constant size_t *right_strides, \ + device const IN_TYPE *left, \ + device const IN_TYPE *right, \ + device OUT_TYPE *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + IN_TYPE x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \ + IN_TYPE y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \ + output[tid] = OUT_TYPE(FN); \ } CUSTOM_BINARY(int64_t, int64_t, bit_and_i64, x & y) @@ -21,6 +54,7 @@ CUSTOM_BINARY(int64_t, int64_t, bit_or_i64, x | y) CUSTOM_BINARY(int64_t, int64_t, bit_xor_i64, x ^ y) CUSTOM_BINARY(float, float, atan2_f32, atan2(x, y)) +CUSTOM_BINARY(float, float, pow_f32, pow(x, y)) /* pow */ /* remainder */ diff --git a/native/candlex/src/metal_kernels/custom_unary.metal b/native/candlex/src/metal_kernels/custom_unary.metal index e9e424e..36f07f5 
100644 --- a/native/candlex/src/metal_kernels/custom_unary.metal +++ b/native/candlex/src/metal_kernels/custom_unary.metal @@ -2,6 +2,21 @@ using namespace metal; +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + #define CUSTOM_UNARY(IN_TYPE, OUT_TYPE, FN_NAME, FN) \ kernel void FN_NAME( \ constant size_t &dim, \ @@ -13,6 +28,20 @@ kernel void FN_NAME( \ return; \ } \ output[tid] = OUT_TYPE(FN(IN_TYPE(input[tid]))); \ +}\ +kernel void FN_NAME##_strided( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + device const IN_TYPE *input, \ + device OUT_TYPE *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = OUT_TYPE(FN(IN_TYPE(input[get_strided_index(tid, num_dims, dims, strides)]))); \ } CUSTOM_UNARY(float, float, acos_f32, acos) diff --git a/native/candlex/src/ops.rs b/native/candlex/src/ops.rs index dfbd457..40e17ac 100644 --- a/native/candlex/src/ops.rs +++ b/native/candlex/src/ops.rs @@ -106,30 +106,47 @@ macro_rules! 
custom_unary_op { use crate::metal_kernels; use candle_core::{backend::BackendStorage, DType}; - if !(layout.is_contiguous() && layout.start_offset() == 0) { - candle_core::bail!("Non contiguous not supported"); - } - let device = storage.device(); let command_buffer = device.command_buffer()?; let elem_count = layout.shape().elem_count(); let dtype = storage.dtype(); let output_buffer = device.new_buffer(elem_count, dtype, stringify!($name))?; - let kernel_name = match storage.dtype() { - DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT, - dtype => { - candle_core::bail!("Metal contiguous custom unary $name {dtype:?} not implemented") - } - }; - metal_kernels::call_custom_unary_contiguous( - &device.device(), - &command_buffer, - kernel_name, - elem_count, - storage.buffer(), - &output_buffer, - ).unwrap(); + if (layout.is_contiguous() && layout.start_offset() == 0) { + let kernel_name = match storage.dtype() { + DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT, + dtype => { + candle_core::bail!("Metal contiguous custom unary $name {dtype:?} not implemented") + } + }; + + metal_kernels::call_custom_unary_contiguous( + &device.device(), + &command_buffer, + kernel_name, + elem_count, + storage.buffer(), + &output_buffer, + ).unwrap(); + } else { + let kernel_name = match storage.dtype() { + DType::F32 => metal_kernels::custom_unary::strided::$name::FLOAT, + dtype => { + candle_core::bail!("Metal strided custom unary $name {dtype:?} not implemented") + } + }; + + metal_kernels::call_custom_unary_strided( + &device.device(), + &command_buffer, + kernel_name, + layout.dims(), + layout.stride(), + storage.buffer(), + layout.start_offset() * dtype.size_in_bytes(), + &output_buffer, + ).unwrap(); + } Ok((MetalStorage::new(output_buffer, device.clone(), dtype), layout.shape().clone())) } @@ -368,13 +385,6 @@ macro_rules! 
custom_binary_op { use crate::metal_kernels; use candle_core::{backend::BackendStorage, DType}; - if !(l1.is_contiguous() && l1.start_offset() == 0) { - candle_core::bail!("Non contiguous not supported - l1"); - } - if !(l2.is_contiguous() && l2.start_offset() == 0) { - candle_core::bail!("Non contiguous not supported - l2"); - } - let device = s1.device(); let dtype = s1.dtype(); let shape = l1.shape(); @@ -382,23 +392,46 @@ macro_rules! custom_binary_op { let command_buffer = device.command_buffer()?; let output_buffer = device.new_buffer(elem_count, dtype, stringify!($name))?; - let kernel_name = match dtype { - DType::F32 => metal_kernels::custom_binary::contiguous::$name::FLOAT, - DType::I64 => metal_kernels::custom_binary::contiguous::$name::I64, - dtype => { - candle_core::bail!("Metal contiguous custom binary $name {dtype:?} not implemented") - } - }; - - metal_kernels::call_custom_binary_contiguous( - &device.device(), - &command_buffer, - kernel_name, - elem_count, - &s1.buffer(), - &s2.buffer(), - &output_buffer, - ).unwrap(); + if (l1.is_contiguous() && l1.start_offset() == 0 && l2.is_contiguous() && l2.start_offset() == 0) { + let kernel_name = match dtype { + DType::F32 => metal_kernels::custom_binary::contiguous::$name::FLOAT, + DType::I64 => metal_kernels::custom_binary::contiguous::$name::I64, + dtype => { + candle_core::bail!("Metal contiguous custom binary $name {dtype:?} not implemented") + } + }; + + metal_kernels::call_custom_binary_contiguous( + &device.device(), + &command_buffer, + kernel_name, + elem_count, + &s1.buffer(), + &s2.buffer(), + &output_buffer, + ).unwrap(); + } else { + let kernel_name = match dtype { + DType::F32 => metal_kernels::custom_binary::strided::$name::FLOAT, + dtype => { + candle_core::bail!("Metal strided custom binary $name {dtype:?} not implemented") + } + }; + + metal_kernels::call_custom_binary_strided( + &device.device(), + &command_buffer, + kernel_name, + l1.dims(), + &s1.buffer(), + l1.stride(), + 
l1.start_offset() * s1.dtype().size_in_bytes(), + &s2.buffer(), + l2.stride(), + l2.start_offset() * s2.dtype().size_in_bytes(), + &output_buffer, + ).unwrap(); + } Ok((MetalStorage::new(output_buffer, device.clone(), dtype), l1.shape().clone())) } diff --git a/test/candlex_test.exs b/test/candlex_test.exs index ff44501..ff33535 100644 --- a/test/candlex_test.exs +++ b/test/candlex_test.exs @@ -1428,11 +1428,11 @@ defmodule CandlexTest do t([1.0, 2.0, 3.0]) |> Nx.pow(2) - |> assert_equal(t([1.0, 4.0, 9.0])) + |> assert_close(t([1.0, 4.0, 9.0])) 2 |> Nx.pow(t([1.0, 2.0, 3.0])) - |> assert_equal(t([2.0, 4.0, 8.0])) + |> assert_close(t([2.0, 4.0, 8.0])) # t([[2], [3]]) # |> Nx.pow(t([[4, 5]]))