Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GemmType trait for dispatching gemm fn calls #22

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 54 additions & 141 deletions gemm/src/gemm.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::Parallelism;
use core::any::TypeId;

#[allow(non_camel_case_types)]
pub type c32 = num_complex::Complex32;
Expand All @@ -9,151 +8,65 @@ pub type c64 = num_complex::Complex64;
#[allow(non_camel_case_types)]
pub type f16 = gemm_f16::f16;

unsafe fn gemm_dispatch<T: 'static>(
m: usize,
n: usize,
k: usize,
dst: *mut T,
dst_cs: isize,
dst_rs: isize,
read_dst: bool,
lhs: *const T,
lhs_cs: isize,
lhs_rs: isize,
rhs: *const T,
rhs_cs: isize,
rhs_rs: isize,
alpha: T,
beta: T,
conj_dst: bool,
conj_lhs: bool,
conj_rhs: bool,
parallelism: Parallelism,
) {
#[cfg(feature = "f16")]
if TypeId::of::<T>() == TypeId::of::<f16>() {
return gemm_f16::gemm::f16::get_gemm_fn()(
m,
n,
k,
dst as *mut f16,
dst_cs,
dst_rs,
read_dst,
lhs as *mut f16,
lhs_cs,
lhs_rs,
rhs as *mut f16,
rhs_cs,
rhs_rs,
*(&alpha as *const T as *const f16),
*(&beta as *const T as *const f16),
false,
false,
false,
parallelism,
);
type GemmFn<T> = unsafe fn(
usize,
usize,
usize,
*mut T,
isize,
isize,
bool,
*const T,
isize,
isize,
*const T,
isize,
isize,
T,
T,
bool,
bool,
bool,
Parallelism,
);

pub trait GemmType {
fn get_gemm_fn() -> GemmFn<Self>;
}

impl GemmType for f32 {
fn get_gemm_fn() -> GemmFn<Self> {
gemm_f32::gemm::f32::get_gemm_fn()
}
}

if TypeId::of::<T>() == TypeId::of::<f64>() {
gemm_f64::gemm::f64::get_gemm_fn()(
m,
n,
k,
dst as *mut f64,
dst_cs,
dst_rs,
read_dst,
lhs as *mut f64,
lhs_cs,
lhs_rs,
rhs as *mut f64,
rhs_cs,
rhs_rs,
*(&alpha as *const T as *const f64),
*(&beta as *const T as *const f64),
false,
false,
false,
parallelism,
)
} else if TypeId::of::<T>() == TypeId::of::<f32>() {
gemm_f32::gemm::f32::get_gemm_fn()(
m,
n,
k,
dst as *mut f32,
dst_cs,
dst_rs,
read_dst,
lhs as *mut f32,
lhs_cs,
lhs_rs,
rhs as *mut f32,
rhs_cs,
rhs_rs,
*(&alpha as *const T as *const f32),
*(&beta as *const T as *const f32),
false,
false,
false,
parallelism,
)
} else if TypeId::of::<T>() == TypeId::of::<c64>() {
gemm_c64::gemm::f64::get_gemm_fn()(
m,
n,
k,
dst as *mut c64,
dst_cs,
dst_rs,
read_dst,
lhs as *mut c64,
lhs_cs,
lhs_rs,
rhs as *mut c64,
rhs_cs,
rhs_rs,
*(&alpha as *const T as *const c64),
*(&beta as *const T as *const c64),
conj_dst,
conj_lhs,
conj_rhs,
parallelism,
)
} else if TypeId::of::<T>() == TypeId::of::<c32>() {
gemm_c32::gemm::f32::get_gemm_fn()(
m,
n,
k,
dst as *mut c32,
dst_cs,
dst_rs,
read_dst,
lhs as *mut c32,
lhs_cs,
lhs_rs,
rhs as *mut c32,
rhs_cs,
rhs_rs,
*(&alpha as *const T as *const c32),
*(&beta as *const T as *const c32),
conj_dst,
conj_lhs,
conj_rhs,
parallelism,
)
} else {
panic!();
impl GemmType for f64 {
fn get_gemm_fn() -> GemmFn<Self> {
gemm_f64::gemm::f64::get_gemm_fn()
}
}

#[cfg(feature = "f16")]
impl GemmType for f16 {
fn get_gemm_fn() -> GemmFn<Self> {
gemm_f16::gemm::f16::get_gemm_fn()
}
}

impl GemmType for c32 {
fn get_gemm_fn() -> GemmFn<Self> {
gemm_c32::gemm::f32::get_gemm_fn()
}
}

impl GemmType for c64 {
fn get_gemm_fn() -> GemmFn<Self> {
gemm_c64::gemm::f64::get_gemm_fn()
}
}

/// dst := alpha×dst + beta×lhs×rhs
///
/// # Panics
///
/// Panics if `T` is not `f32`, `f64`, `gemm::f16`, `gemm::c32`, or `gemm::c64`.
pub unsafe fn gemm<T: 'static>(
pub unsafe fn gemm<T: GemmType + 'static>(
m: usize,
n: usize,
k: usize,
Expand Down Expand Up @@ -215,7 +128,7 @@ pub unsafe fn gemm<T: 'static>(
rhs_cs = -rhs_cs;
}

gemm_dispatch(
T::get_gemm_fn()(
m,
n,
k,
Expand Down