Skip to content

Commit

Permalink
more metal ops
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy committed Jan 19, 2024
1 parent 6d52ef1 commit bbec4cf
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 38 deletions.
3 changes: 3 additions & 0 deletions native/candlex/src/metal_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ macro_rules! ops {

pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64"));
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32"));
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8"));
}
)+
Expand All @@ -58,6 +59,8 @@ macro_rules! ops {

pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided"));
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided"));
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided"));
}
)+
}
Expand Down
133 changes: 95 additions & 38 deletions native/candlex/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ macro_rules! custom_unary_op {
let kernel_name = match storage.dtype() {
DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT,
DType::I64 => metal_kernels::custom_unary::contiguous::$name::I64,
DType::U32 => metal_kernels::custom_unary::contiguous::$name::U32,
DType::U8 => metal_kernels::custom_unary::contiguous::$name::U8,
dtype => {
candle_core::bail!("Metal contiguous custom unary {} {dtype:?} not implemented", stringify!($name))
Expand All @@ -133,6 +134,9 @@ macro_rules! custom_unary_op {
} else {
let kernel_name = match storage.dtype() {
DType::F32 => metal_kernels::custom_unary::strided::$name::FLOAT,
DType::I64 => metal_kernels::custom_unary::strided::$name::I64,
DType::U32 => metal_kernels::custom_unary::strided::$name::U32,
DType::U8 => metal_kernels::custom_unary::strided::$name::U8,
dtype => {
candle_core::bail!("Metal strided custom unary {} {dtype:?} not implemented", stringify!($name))
}
Expand Down Expand Up @@ -251,23 +255,52 @@ macro_rules! custom_unary_bool_op {
use crate::metal_kernels;
use candle_core::{backend::BackendStorage, DType};

if !(layout.is_contiguous() && layout.start_offset() == 0) {
candle_core::bail!("Non contiguous not supported");
}

let device = storage.device();
let command_buffer = device.command_buffer()?;
let elem_count = layout.shape().elem_count();
let output_buffer = device.new_buffer(elem_count, DType::U8, stringify!($name))?;

metal_kernels::call_custom_unary_contiguous(
&device.device(),
&command_buffer,
metal_kernels::custom_unary::contiguous::$name::FLOAT,
elem_count,
storage.buffer(),
&output_buffer,
).unwrap();
// Dispatch on the input layout: a contiguous input starting at offset zero
// goes through the contiguous kernel, any other layout through the strided
// kernel. In both cases the kernel variant is picked from the input dtype.
if layout.is_contiguous() && layout.start_offset() == 0 {
    let kernel_name = match storage.dtype() {
        DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT,
        DType::I64 => metal_kernels::custom_unary::contiguous::$name::I64,
        DType::U32 => metal_kernels::custom_unary::contiguous::$name::U32,
        DType::U8 => metal_kernels::custom_unary::contiguous::$name::U8,
        dtype => {
            candle_core::bail!("Metal contiguous custom unary {} {dtype:?} not implemented", stringify!($name))
        }
    };

    // NOTE(review): `.unwrap()` panics if the kernel call returns an error;
    // consider propagating it once the error type conversion is confirmed.
    metal_kernels::call_custom_unary_contiguous(
        &device.device(),
        &command_buffer,
        kernel_name,
        elem_count,
        storage.buffer(),
        &output_buffer,
    ).unwrap();
} else {
    let kernel_name = match storage.dtype() {
        DType::F32 => metal_kernels::custom_unary::strided::$name::FLOAT,
        DType::I64 => metal_kernels::custom_unary::strided::$name::I64,
        DType::U32 => metal_kernels::custom_unary::strided::$name::U32,
        DType::U8 => metal_kernels::custom_unary::strided::$name::U8,
        dtype => {
            candle_core::bail!("Metal strided custom unary {} {dtype:?} not implemented", stringify!($name))
        }
    };

    metal_kernels::call_custom_unary_strided(
        &device.device(),
        &command_buffer,
        kernel_name,
        layout.dims(),
        layout.stride(),
        storage.buffer(),
        // The `dtype` binding above only exists inside the fallback match
        // arm, so the element size has to be read from the storage itself.
        layout.start_offset() * storage.dtype().size_in_bytes(),
        &output_buffer,
    ).unwrap();
}

Ok((MetalStorage::new(output_buffer, device.clone(), DType::U8), layout.shape().clone()))
}
Expand Down Expand Up @@ -398,6 +431,8 @@ macro_rules! custom_binary_op {
let kernel_name = match dtype {
DType::F32 => metal_kernels::custom_binary::contiguous::$name::FLOAT,
DType::I64 => metal_kernels::custom_binary::contiguous::$name::I64,
DType::U32 => metal_kernels::custom_binary::contiguous::$name::U32,
DType::U8 => metal_kernels::custom_binary::contiguous::$name::U8,
dtype => {
candle_core::bail!("Metal contiguous custom binary {} {dtype:?} not implemented", stringify!($name))
}
Expand All @@ -416,6 +451,8 @@ macro_rules! custom_binary_op {
let kernel_name = match dtype {
DType::F32 => metal_kernels::custom_binary::strided::$name::FLOAT,
DType::I64 => metal_kernels::custom_binary::strided::$name::I64,
DType::U32 => metal_kernels::custom_binary::strided::$name::U32,
DType::U8 => metal_kernels::custom_binary::strided::$name::U8,
dtype => {
candle_core::bail!("Metal strided custom binary {} {dtype:?} not implemented", stringify!($name))
}
Expand Down Expand Up @@ -560,38 +597,58 @@ macro_rules! custom_binary_bool_op {
use crate::metal_kernels;
use candle_core::{backend::BackendStorage, DType};

if !(l1.is_contiguous() && l1.start_offset() == 0) {
candle_core::bail!("Non contiguous not supported - l1");
}
if !(l2.is_contiguous() && l2.start_offset() == 0) {
candle_core::bail!("Non contiguous not supported - l2");
}

let device = s1.device();
let shape = l1.shape();
let elem_count = shape.elem_count();
let command_buffer = device.command_buffer()?;
let output_buffer = device.new_buffer(elem_count, DType::U8, stringify!($name))?;

let kernel_name = match s1.dtype() {
DType::I64 => metal_kernels::custom_binary::contiguous::$name::I64,
DType::U8 => metal_kernels::custom_binary::contiguous::$name::U8,
dtype => {
candle_core::bail!("Metal contiguous custom binary {} {dtype:?} not implemented", stringify!($name))
}
};

metal_kernels::call_custom_binary_contiguous(
&device.device(),
&command_buffer,
kernel_name,
elem_count,
&s1.buffer(),
&s2.buffer(),
&output_buffer,
).unwrap();

// command_buffer.set_label("binary");
if (l1.is_contiguous() && l1.start_offset() == 0 && l2.is_contiguous() && l2.start_offset() == 0) {
let kernel_name = match s1.dtype() {
DType::F32 => metal_kernels::custom_binary::contiguous::$name::FLOAT,
DType::I64 => metal_kernels::custom_binary::contiguous::$name::I64,
DType::U32 => metal_kernels::custom_binary::contiguous::$name::U32,
DType::U8 => metal_kernels::custom_binary::contiguous::$name::U8,
dtype => {
candle_core::bail!("Metal contiguous custom binary {} {dtype:?} not implemented", stringify!($name))
}
};

metal_kernels::call_custom_binary_contiguous(
&device.device(),
&command_buffer,
kernel_name,
elem_count,
&s1.buffer(),
&s2.buffer(),
&output_buffer,
).unwrap();
} else {
let kernel_name = match s1.dtype() {
DType::F32 => metal_kernels::custom_binary::strided::$name::FLOAT,
DType::I64 => metal_kernels::custom_binary::strided::$name::I64,
DType::U32 => metal_kernels::custom_binary::strided::$name::U32,
DType::U8 => metal_kernels::custom_binary::strided::$name::U8,
dtype => {
candle_core::bail!("Metal strided custom binary {} {dtype:?} not implemented", stringify!($name))
}
};

metal_kernels::call_custom_binary_strided(
&device.device(),
&command_buffer,
kernel_name,
l1.dims(),
&s1.buffer(),
l1.stride(),
l1.start_offset() * s1.dtype().size_in_bytes(),
&s2.buffer(),
l2.stride(),
l2.start_offset() * s2.dtype().size_in_bytes(),
&output_buffer,
).unwrap();
}

Ok((MetalStorage::new(output_buffer, device.clone(), DType::U8), l1.shape().clone()))
}
}
Expand Down

0 comments on commit bbec4cf

Please sign in to comment.