Skip to content

Commit

Permalink
feat: Metal strided kernels infra (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy authored Jan 18, 2024
1 parent b35a285 commit 4c2caef
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 44 deletions.
135 changes: 135 additions & 0 deletions native/candlex/src/metal_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ macro_rules! ops {
}
)+
}
pub mod strided {
// Newtype wrapping the Metal kernel function name for the strided
// (non-contiguous) variant of an op; mirrors the `contiguous` module above.
pub struct Kernel(pub &'static str);

$(
// One submodule per op name passed to the `ops!` macro.
pub mod $name {
use super::Kernel;

// e.g. "cos_f32_strided" — must match the `FN_NAME##_strided`
// kernel emitted by the corresponding .metal macro.
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
}
)+
}
}
}

Expand Down Expand Up @@ -167,6 +178,62 @@ pub fn call_custom_unary_contiguous(
Ok(())
}

/// Encodes a strided (non-contiguous) custom unary kernel onto `command_buffer`.
///
/// Buffer bindings match the `FN_NAME##_strided` kernel signature in
/// `custom_unary.metal`: 0 = element count, 1 = rank, 2 = dims, 3 = strides,
/// 4 = input, 5 = output.
///
/// `input_offset` is a BYTE offset into `input_buffer` (callers pass
/// `start_offset * dtype.size_in_bytes()`). The output is written contiguously
/// starting at offset 0.
///
/// # Errors
/// Returns `MetalKernelError` if the pipeline for `kernel_name` cannot be loaded.
pub fn call_custom_unary_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernel_name: custom_unary::strided::Kernel,
    shape: &[usize],
    strides: &[usize],
    input_buffer: &Buffer,
    input_offset: usize,
    output_buffer: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline =
        CustomKernels::new().load_pipeline(device, Source::CustomUnary, kernel_name.0)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    let num_dims: usize = shape.len();
    // Total element count; also the dispatch width (one thread per output element).
    let length: usize = shape.iter().product();

    encoder.set_bytes(
        0,
        core::mem::size_of::<usize>() as u64,
        &length as *const usize as *const c_void,
    );
    encoder.set_bytes(
        1,
        core::mem::size_of::<usize>() as u64,
        &num_dims as *const usize as *const c_void,
    );

    encoder.set_bytes(
        2,
        core::mem::size_of_val(shape) as u64,
        shape.as_ptr() as *const c_void,
    );

    encoder.set_bytes(
        3,
        core::mem::size_of_val(strides) as u64,
        strides.as_ptr() as *const c_void,
    );

    encoder.set_buffer(4, Some(input_buffer), input_offset as u64);
    encoder.set_buffer(5, Some(output_buffer), 0);

    encoder.use_resource(input_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output_buffer, metal::MTLResourceUsage::Write);

    // Reuse `length` rather than recomputing `shape.iter().product()`.
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);

    encoder.end_encoding();

    Ok(())
}

pub fn call_custom_binary_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
Expand Down Expand Up @@ -202,3 +269,71 @@ pub fn call_custom_binary_contiguous(

Ok(())
}

/// Encodes a strided (non-contiguous) custom binary kernel onto `command_buffer`.
///
/// Buffer bindings match the `FN_NAME##_strided` kernel signature in
/// `custom_binary.metal`: 0 = element count, 1 = rank, 2 = dims,
/// 3 = left strides, 4 = right strides, 5 = left, 6 = right, 7 = output.
///
/// `left_offset` / `right_offset` are BYTE offsets into their buffers
/// (callers pass `start_offset * dtype.size_in_bytes()`). The output is
/// written contiguously starting at offset 0.
///
/// # Errors
/// Returns `MetalKernelError` if the pipeline for `kernel_name` cannot be loaded.
pub fn call_custom_binary_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernel_name: custom_binary::strided::Kernel,
    shape: &[usize],
    left_buffer: &Buffer,
    left_strides: &[usize],
    left_offset: usize,
    right_buffer: &Buffer,
    right_strides: &[usize],
    right_offset: usize,
    output_buffer: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline =
        CustomKernels::new().load_pipeline(device, Source::CustomBinary, kernel_name.0)?;

    let num_dims: usize = shape.len();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    // Total element count; also the dispatch width (one thread per output element).
    let length: usize = shape.iter().product();

    encoder.set_bytes(
        0,
        core::mem::size_of::<usize>() as u64,
        &length as *const usize as *const c_void,
    );
    encoder.set_bytes(
        1,
        core::mem::size_of::<usize>() as u64,
        &num_dims as *const usize as *const c_void,
    );

    encoder.set_bytes(
        2,
        core::mem::size_of_val(shape) as u64,
        shape.as_ptr() as *const c_void,
    );

    encoder.set_bytes(
        3,
        core::mem::size_of_val(left_strides) as u64,
        left_strides.as_ptr() as *const c_void,
    );

    encoder.set_bytes(
        4,
        core::mem::size_of_val(right_strides) as u64,
        right_strides.as_ptr() as *const c_void,
    );

    encoder.set_buffer(5, Some(left_buffer), left_offset as u64);
    encoder.set_buffer(6, Some(right_buffer), right_offset as u64);
    encoder.set_buffer(7, Some(output_buffer), 0);

    encoder.use_resource(left_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(right_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output_buffer, metal::MTLResourceUsage::Write);

    // Reuse `length` rather than recomputing `shape.iter().product()`.
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);

    encoder.end_encoding();

    Ok(())
}
34 changes: 34 additions & 0 deletions native/candlex/src/metal_kernels/custom_binary.metal
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
using namespace metal;

// Maps a flat (contiguous) element index to the offset of the same logical
// element in a strided buffer, by peeling dimensions from the innermost
// (last) axis outward.
METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint offset = 0;
    uint remaining = idx;
    for (uint i = 0; i < num_dims; i++) {
        // Walk axes from last to first.
        uint axis = num_dims - 1 - i;
        uint coord = remaining % dims[axis];
        offset += coord * strides[axis];
        remaining /= dims[axis];
    }
    return offset;
}

#define CUSTOM_BINARY(IN_TYPE, OUT_TYPE, FN_NAME, FN) \
kernel void FN_NAME( \
constant size_t &dim, \
Expand All @@ -14,13 +29,32 @@ kernel void FN_NAME( \
IN_TYPE x = left[tid]; \
IN_TYPE y = right[tid]; \
output[tid] = OUT_TYPE(FN); \
}\
kernel void FN_NAME##_strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *left_strides, \
constant size_t *right_strides, \
device const IN_TYPE *left, \
device const IN_TYPE *right, \
device OUT_TYPE *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
IN_TYPE x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \
IN_TYPE y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \
output[tid] = OUT_TYPE(FN); \
}

/* Bitwise binary ops, i64 only; each expands to a contiguous and a
   _strided kernel via the CUSTOM_BINARY macro above. */
CUSTOM_BINARY(int64_t, int64_t, bit_and_i64, x & y)
CUSTOM_BINARY(int64_t, int64_t, bit_or_i64, x | y)
CUSTOM_BINARY(int64_t, int64_t, bit_xor_i64, x ^ y)

/* Float binary ops built on the Metal stdlib math functions. */
CUSTOM_BINARY(float, float, atan2_f32, atan2(x, y))
CUSTOM_BINARY(float, float, pow_f32, pow(x, y))

/* pow */
/* remainder */
Expand Down
29 changes: 29 additions & 0 deletions native/candlex/src/metal_kernels/custom_unary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@

using namespace metal;

// Translates a flat output index into the corresponding offset within a
// strided input buffer. Dimensions are consumed innermost-first: the index
// is repeatedly split into (coordinate, quotient) per axis and each
// coordinate is scaled by that axis's stride.
METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint offset = 0;
    for (uint i = 0; i < num_dims; i++) {
        uint axis = num_dims - 1 - i;
        offset += (idx % dims[axis]) * strides[axis];
        idx /= dims[axis];
    }
    return offset;
}

#define CUSTOM_UNARY(IN_TYPE, OUT_TYPE, FN_NAME, FN) \
kernel void FN_NAME( \
constant size_t &dim, \
Expand All @@ -13,6 +28,20 @@ kernel void FN_NAME( \
return; \
} \
output[tid] = OUT_TYPE(FN(IN_TYPE(input[tid]))); \
}\
kernel void FN_NAME##_strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const IN_TYPE *input, \
device OUT_TYPE *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[tid] = OUT_TYPE(FN(IN_TYPE(input[get_strided_index(tid, num_dims, dims, strides)]))); \
}

/* f32 unary ops; each expands to a contiguous and a _strided kernel
   via the CUSTOM_UNARY macro above. */
CUSTOM_UNARY(float, float, acos_f32, acos)
Expand Down
117 changes: 75 additions & 42 deletions native/candlex/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,30 +106,47 @@ macro_rules! custom_unary_op {
use crate::metal_kernels;
use candle_core::{backend::BackendStorage, DType};

if !(layout.is_contiguous() && layout.start_offset() == 0) {
candle_core::bail!("Non contiguous not supported");
}

let device = storage.device();
let command_buffer = device.command_buffer()?;
let elem_count = layout.shape().elem_count();
let dtype = storage.dtype();
let output_buffer = device.new_buffer(elem_count, dtype, stringify!($name))?;

let kernel_name = match storage.dtype() {
DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT,
dtype => {
candle_core::bail!("Metal contiguous custom unary $name {dtype:?} not implemented")
}
};
metal_kernels::call_custom_unary_contiguous(
&device.device(),
&command_buffer,
kernel_name,
elem_count,
storage.buffer(),
&output_buffer,
).unwrap();
if (layout.is_contiguous() && layout.start_offset() == 0) {
let kernel_name = match storage.dtype() {
DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT,
dtype => {
candle_core::bail!("Metal contiguous custom unary $name {dtype:?} not implemented")
}
};

metal_kernels::call_custom_unary_contiguous(
&device.device(),
&command_buffer,
kernel_name,
elem_count,
storage.buffer(),
&output_buffer,
).unwrap();
} else {
let kernel_name = match storage.dtype() {
DType::F32 => metal_kernels::custom_unary::strided::$name::FLOAT,
dtype => {
candle_core::bail!("Metal strided custom unary $name {dtype:?} not implemented")
}
};

metal_kernels::call_custom_unary_strided(
&device.device(),
&command_buffer,
kernel_name,
layout.dims(),
layout.stride(),
storage.buffer(),
layout.start_offset() * dtype.size_in_bytes(),
&output_buffer,
).unwrap();
}

Ok((MetalStorage::new(output_buffer, device.clone(), dtype), layout.shape().clone()))
}
Expand Down Expand Up @@ -368,37 +385,53 @@ macro_rules! custom_binary_op {
use crate::metal_kernels;
use candle_core::{backend::BackendStorage, DType};

if !(l1.is_contiguous() && l1.start_offset() == 0) {
candle_core::bail!("Non contiguous not supported - l1");
}
if !(l2.is_contiguous() && l2.start_offset() == 0) {
candle_core::bail!("Non contiguous not supported - l2");
}

let device = s1.device();
let dtype = s1.dtype();
let shape = l1.shape();
let elem_count = shape.elem_count();
let command_buffer = device.command_buffer()?;
let output_buffer = device.new_buffer(elem_count, dtype, stringify!($name))?;

let kernel_name = match dtype {
DType::F32 => metal_kernels::custom_binary::contiguous::$name::FLOAT,
DType::I64 => metal_kernels::custom_binary::contiguous::$name::I64,
dtype => {
candle_core::bail!("Metal contiguous custom binary $name {dtype:?} not implemented")
}
};

metal_kernels::call_custom_binary_contiguous(
&device.device(),
&command_buffer,
kernel_name,
elem_count,
&s1.buffer(),
&s2.buffer(),
&output_buffer,
).unwrap();
if (l1.is_contiguous() && l1.start_offset() == 0 && l2.is_contiguous() && l2.start_offset() == 0) {
let kernel_name = match dtype {
DType::F32 => metal_kernels::custom_binary::contiguous::$name::FLOAT,
DType::I64 => metal_kernels::custom_binary::contiguous::$name::I64,
dtype => {
candle_core::bail!("Metal contiguous custom binary $name {dtype:?} not implemented")
}
};

metal_kernels::call_custom_binary_contiguous(
&device.device(),
&command_buffer,
kernel_name,
elem_count,
&s1.buffer(),
&s2.buffer(),
&output_buffer,
).unwrap();
} else {
let kernel_name = match dtype {
DType::F32 => metal_kernels::custom_binary::strided::$name::FLOAT,
dtype => {
candle_core::bail!("Metal strided custom binary $name {dtype:?} not implemented")
}
};

metal_kernels::call_custom_binary_strided(
&device.device(),
&command_buffer,
kernel_name,
l1.dims(),
&s1.buffer(),
l1.stride(),
l1.start_offset() * s1.dtype().size_in_bytes(),
&s2.buffer(),
l2.stride(),
l2.start_offset() * s2.dtype().size_in_bytes(),
&output_buffer,
).unwrap();
}

Ok((MetalStorage::new(output_buffer, device.clone(), dtype), l1.shape().clone()))
}
Expand Down
Loading

0 comments on commit 4c2caef

Please sign in to comment.