Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy committed Jan 19, 2024
1 parent bbec4cf commit e64cc9e
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions native/candlex/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,13 @@ macro_rules! custom_unary_bool_op {
use candle_core::{backend::BackendStorage, DType};

let device = storage.device();
let dtype = storage.dtype();
let command_buffer = device.command_buffer()?;
let elem_count = layout.shape().elem_count();
let output_buffer = device.new_buffer(elem_count, DType::U8, stringify!($name))?;

if (layout.is_contiguous() && layout.start_offset() == 0) {
let kernel_name = match storage.dtype() {
let kernel_name = match dtype {
DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT,
DType::I64 => metal_kernels::custom_unary::contiguous::$name::I64,
DType::U32 => metal_kernels::custom_unary::contiguous::$name::U32,
Expand All @@ -280,7 +281,7 @@ macro_rules! custom_unary_bool_op {
&output_buffer,
).unwrap();
} else {
let kernel_name = match storage.dtype() {
let kernel_name = match dtype {
DType::F32 => metal_kernels::custom_unary::strided::$name::FLOAT,
DType::I64 => metal_kernels::custom_unary::strided::$name::I64,
DType::U32 => metal_kernels::custom_unary::strided::$name::U32,
Expand Down

0 comments on commit e64cc9e

Please sign in to comment.