Skip to content

Commit

Permalink
half ops calls
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy committed Jan 19, 2024
1 parent e64cc9e commit 86e7a65
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
2 changes: 2 additions & 0 deletions native/candlex/src/metal_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ macro_rules! ops {
use super::Kernel;

pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64"));
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32"));
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8"));
Expand All @@ -58,6 +59,7 @@ macro_rules! ops {
use super::Kernel;

pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided"));
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided"));
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided"));
Expand Down
8 changes: 8 additions & 0 deletions native/candlex/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ macro_rules! custom_unary_op {
if (layout.is_contiguous() && layout.start_offset() == 0) {
let kernel_name = match storage.dtype() {
DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT,
DType::F16 => metal_kernels::custom_unary::contiguous::$name::HALF,
DType::I64 => metal_kernels::custom_unary::contiguous::$name::I64,
DType::U32 => metal_kernels::custom_unary::contiguous::$name::U32,
DType::U8 => metal_kernels::custom_unary::contiguous::$name::U8,
Expand All @@ -134,6 +135,7 @@ macro_rules! custom_unary_op {
} else {
let kernel_name = match storage.dtype() {
DType::F32 => metal_kernels::custom_unary::strided::$name::FLOAT,
DType::F16 => metal_kernels::custom_unary::strided::$name::HALF,
DType::I64 => metal_kernels::custom_unary::strided::$name::I64,
DType::U32 => metal_kernels::custom_unary::strided::$name::U32,
DType::U8 => metal_kernels::custom_unary::strided::$name::U8,
Expand Down Expand Up @@ -264,6 +266,7 @@ macro_rules! custom_unary_bool_op {
if (layout.is_contiguous() && layout.start_offset() == 0) {
let kernel_name = match dtype {
DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT,
DType::F16 => metal_kernels::custom_unary::contiguous::$name::HALF,
DType::I64 => metal_kernels::custom_unary::contiguous::$name::I64,
DType::U32 => metal_kernels::custom_unary::contiguous::$name::U32,
DType::U8 => metal_kernels::custom_unary::contiguous::$name::U8,
Expand All @@ -283,6 +286,7 @@ macro_rules! custom_unary_bool_op {
} else {
let kernel_name = match dtype {
DType::F32 => metal_kernels::custom_unary::strided::$name::FLOAT,
DType::F16 => metal_kernels::custom_unary::strided::$name::HALF,
DType::I64 => metal_kernels::custom_unary::strided::$name::I64,
DType::U32 => metal_kernels::custom_unary::strided::$name::U32,
DType::U8 => metal_kernels::custom_unary::strided::$name::U8,
Expand Down Expand Up @@ -431,6 +435,7 @@ macro_rules! custom_binary_op {
if (l1.is_contiguous() && l1.start_offset() == 0 && l2.is_contiguous() && l2.start_offset() == 0) {
let kernel_name = match dtype {
DType::F32 => metal_kernels::custom_binary::contiguous::$name::FLOAT,
DType::F16 => metal_kernels::custom_binary::contiguous::$name::HALF,
DType::I64 => metal_kernels::custom_binary::contiguous::$name::I64,
DType::U32 => metal_kernels::custom_binary::contiguous::$name::U32,
DType::U8 => metal_kernels::custom_binary::contiguous::$name::U8,
Expand All @@ -451,6 +456,7 @@ macro_rules! custom_binary_op {
} else {
let kernel_name = match dtype {
DType::F32 => metal_kernels::custom_binary::strided::$name::FLOAT,
DType::F16 => metal_kernels::custom_binary::strided::$name::HALF,
DType::I64 => metal_kernels::custom_binary::strided::$name::I64,
DType::U32 => metal_kernels::custom_binary::strided::$name::U32,
DType::U8 => metal_kernels::custom_binary::strided::$name::U8,
Expand Down Expand Up @@ -607,6 +613,7 @@ macro_rules! custom_binary_bool_op {
if (l1.is_contiguous() && l1.start_offset() == 0 && l2.is_contiguous() && l2.start_offset() == 0) {
let kernel_name = match s1.dtype() {
DType::F32 => metal_kernels::custom_binary::contiguous::$name::FLOAT,
DType::F16 => metal_kernels::custom_binary::contiguous::$name::HALF,
DType::I64 => metal_kernels::custom_binary::contiguous::$name::I64,
DType::U32 => metal_kernels::custom_binary::contiguous::$name::U32,
DType::U8 => metal_kernels::custom_binary::contiguous::$name::U8,
Expand All @@ -627,6 +634,7 @@ macro_rules! custom_binary_bool_op {
} else {
let kernel_name = match s1.dtype() {
DType::F32 => metal_kernels::custom_binary::strided::$name::FLOAT,
DType::F16 => metal_kernels::custom_binary::strided::$name::HALF,
DType::I64 => metal_kernels::custom_binary::strided::$name::I64,
DType::U32 => metal_kernels::custom_binary::strided::$name::U32,
DType::U8 => metal_kernels::custom_binary::strided::$name::U8,
Expand Down

0 comments on commit 86e7a65

Please sign in to comment.