From 1c7520dd83ce53ca91d9d5245047d0bfb32ac4bc Mon Sep 17 00:00:00 2001 From: Laiho Date: Mon, 29 Jul 2024 19:50:17 +0300 Subject: [PATCH] add multiversioning --- src/all_equal.rs | 48 ++++++----- src/contains.rs | 64 ++++++++------- src/eq.rs | 46 ++++++----- src/filter.rs | 203 ++++++++++++++++++++++++++--------------------- src/is_sorted.rs | 55 +++++++------ src/position.rs | 106 ++++++++++++++----------- 6 files changed, 293 insertions(+), 229 deletions(-) diff --git a/src/all_equal.rs b/src/all_equal.rs index 455a2cb..c34f6fc 100644 --- a/src/all_equal.rs +++ b/src/all_equal.rs @@ -1,10 +1,37 @@ use crate::SIMD_LEN; +use multiversion::multiversion; use std::simd::cmp::SimdPartialEq; use std::simd::Mask; use std::simd::Simd; use std::simd::SimdElement; use std::slice; +#[multiversion(targets = "simd")] +fn all_equal_simd_internal(arr: &[T]) -> bool +where + T: SimdElement + std::cmp::PartialEq, + Simd: SimdPartialEq>, +{ + if arr.is_empty() { + return true; + } + let first = arr[0]; + let (prefix, simd_data, suffix) = arr.as_simd::(); + // Prefix + if !prefix.iter().all(|x| *x == first) { + return false; + } + // SIMD + let simd_needle = Simd::splat(first); + for rest_slice in simd_data { + let mask = rest_slice.simd_ne(simd_needle).to_bitmask(); + if mask != 0 { + return false; + } + } + // Suffix + suffix.iter().all(|x| *x == first) +} pub trait AllEqualSimd<'a, T> where T: SimdElement + std::cmp::PartialEq, @@ -19,26 +46,7 @@ where Simd: SimdPartialEq>, { fn all_equal_simd(&self) -> bool { - let arr = self.as_slice(); - if arr.is_empty() { - return true; - } - let first = arr[0]; - let (prefix, simd_data, suffix) = arr.as_simd::(); - // Prefix - if !prefix.iter().all(|x| *x == first) { - return false; - } - // SIMD - let simd_needle = Simd::splat(first); - for rest_slice in simd_data { - let mask = rest_slice.simd_ne(simd_needle).to_bitmask(); - if mask != 0 { - return false; - } - } - // Suffix - suffix.iter().all(|x| *x == first) + all_equal_simd_internal(self.as_slice()) } } diff --git a/src/contains.rs b/src/contains.rs index e1a7eff..43bfe14 100644 --- a/src/contains.rs +++ b/src/contains.rs @@ -1,10 +1,45 @@ use crate::SIMD_LEN; use crate::UNROLL_FACTOR; +use multiversion::multiversion; use std::simd::cmp::SimdPartialEq; use std::simd::Mask; use std::simd::{Simd, SimdElement}; use std::slice; +#[multiversion(targets = "simd")] +fn contains_simd_internal(arr: &[T], needle: &T) -> bool +where + T: SimdElement + std::cmp::PartialEq, + Simd: SimdPartialEq>, +{ + let (prefix, simd_data, suffix) = arr.as_simd::(); + // Prefix + if prefix.contains(&needle) { + return true; + } + // SIMD + let simd_needle = Simd::splat(*needle); + // Unrolled loops + let mut chunks_iter = simd_data.chunks_exact(UNROLL_FACTOR); + for chunks in chunks_iter.by_ref() { + let mut mask = Mask::default(); + for chunk in chunks { + mask |= chunk.simd_eq(simd_needle); + } + if mask.any() { + return true; + } + } + for chunk in chunks_iter.remainder() { + let mask = chunk.simd_eq(simd_needle); + if mask.any() { + return true; + } + } + // Suffix + suffix.contains(&needle) +} + pub trait ContainsSimd<'a, T> where T: SimdElement + std::cmp::PartialEq, @@ -19,36 +54,9 @@ where Simd: SimdPartialEq>, { fn contains_simd(&self, needle: &T) -> bool { - let arr = self.as_slice(); - let (prefix, simd_data, suffix) = arr.as_simd::(); - // Prefix - if prefix.contains(&needle) { - return true; - } - // SIMD - let simd_needle = Simd::splat(*needle); - // Unrolled loops - let mut chunks_iter = simd_data.chunks_exact(UNROLL_FACTOR); - for chunks in chunks_iter.by_ref() { - let mut mask = Mask::default(); - for chunk in chunks { - mask |= chunk.simd_eq(simd_needle); - } - if mask.any() { - return true; - } - } - for chunk in chunks_iter.remainder() { - let mask = chunk.simd_eq(simd_needle); - if mask.any() { - return true; - } - } - // Suffix - suffix.contains(&needle) + contains_simd_internal(self.as_slice(), needle) } } - #[cfg(test)] mod tests { use super::*; diff --git a/src/eq.rs b/src/eq.rs index 3996241..087aaa8 100644 --- a/src/eq.rs +++ b/src/eq.rs @@ -1,11 +1,37 @@ use crate::SIMD_LEN; use crate::UNROLL_FACTOR; +use multiversion::multiversion; use std::simd::cmp::SimdPartialEq; use std::simd::Mask; use std::simd::Simd; use std::simd::SimdElement; use std::slice; +#[multiversion(targets = "simd")] +fn eq_simd_internal(a: &[T], b: &[T]) -> bool +where + T: SimdElement + std::cmp::PartialEq, + Simd: SimdPartialEq>, +{ + if a.len() != b.len() { + return false; + } + + let mut chunks_a = a.chunks_exact(SIMD_LEN * UNROLL_FACTOR); + let mut chunks_b = b.chunks_exact(SIMD_LEN * UNROLL_FACTOR); + let mut mask = Mask::default(); + + for (aa, bb) in chunks_a.by_ref().zip(chunks_b.by_ref()) { + for (aaa, bbb) in aa.chunks_exact(SIMD_LEN).zip(bb.chunks_exact(SIMD_LEN)) { + mask |= Simd::from_slice(aaa).simd_ne(Simd::from_slice(bbb)); + } + if mask.any() { + return false; + } + } + return chunks_a.remainder().eq(chunks_b.remainder()); +} + pub trait EqSimd<'a, T> where T: SimdElement + std::cmp::PartialEq, @@ -20,25 +46,7 @@ where Simd: SimdPartialEq>, { fn eq_simd(&self, other: &Self) -> bool { - let a = self.as_slice(); - let b = other.as_slice(); - if a.len() != b.len() { - return false; - } - - let mut chunks_a = a.chunks_exact(SIMD_LEN * UNROLL_FACTOR); - let mut chunks_b = b.chunks_exact(SIMD_LEN * UNROLL_FACTOR); - let mut mask = Mask::default(); - - for (aa, bb) in chunks_a.by_ref().zip(chunks_b.by_ref()) { - for (aaa, bbb) in aa.chunks_exact(SIMD_LEN).zip(bb.chunks_exact(SIMD_LEN)) { - mask |= Simd::from_slice(aaa).simd_ne(Simd::from_slice(bbb)); - } - if mask.any() { - return false; - } - } - return chunks_a.remainder().eq(chunks_b.remainder()); + eq_simd_internal(self.as_slice(), other.as_slice()) } } diff --git a/src/filter.rs b/src/filter.rs index 604cc53..3d9f74e 100644 --- a/src/filter.rs +++ b/src/filter.rs @@ -1,3 +1,4 @@ +use multiversion::multiversion; use std::simd::cmp::SimdPartialOrd; use std::simd::prelude::SimdPartialEq; use std::simd::usizex8; @@ -6,116 +7,111 @@ use std::simd::Simd; use std::simd::SimdElement; use std::slice; -pub trait FilterSimd<'a, T> -where - T: SimdElement + std::cmp::PartialEq + std::cmp::PartialOrd, - Simd: SimdPartialEq>, -{ - fn filter_simd_lt(&self, needle: T) -> Vec; - fn filter_simd_gt(&self, needle: T) -> Vec; - fn filter_simd_eq(&self, needle: T) -> Vec; -} - -// TODO REMOVE DUPLICATION? -impl<'a, T> FilterSimd<'a, T> for slice::Iter<'a, T> +#[multiversion(targets = "simd")] +fn filter_simd_lt_internal(a: &[T], needle: T) -> Vec where T: SimdElement + std::cmp::PartialEq + std::cmp::PartialOrd + Default, Simd: SimdPartialOrd>, { - fn filter_simd_lt(&self, needle: T) -> Vec { - let a = self.as_slice(); - let mut indicies = vec![]; - let (prefix, simd_chunk, suffix) = a.as_simd::<8>(); - let prefix_filters = prefix.iter().filter(|x| **x < needle); + let mut indicies = vec![]; + let (prefix, simd_chunk, suffix) = a.as_simd::<8>(); + let prefix_filters = prefix.iter().filter(|x| **x < needle); - indicies.extend(prefix_filters); - let prefix_filters_len = indicies.len(); - indicies.resize(std::cmp::max(prefix_filters_len, 64), T::default()); - let simd_needle = Simd::splat(needle); - let mut simd_idx = prefix_filters_len; + indicies.extend(prefix_filters); + let prefix_filters_len = indicies.len(); + indicies.resize(std::cmp::max(prefix_filters_len, 64), T::default()); + let simd_needle = Simd::splat(needle); + let mut simd_idx = prefix_filters_len; - // SIMD - for chunk in simd_chunk { - let x = chunk.simd_lt(simd_needle); - let bitmask = x.to_bitmask(); - if bitmask != 0 { - let idxs = SET_BITS_TO_INDICIES[bitmask as usize]; - chunk.scatter(&mut indicies[simd_idx..], idxs); - simd_idx += bitmask.count_ones() as usize; - if simd_idx <= indicies.len() { - indicies.resize(indicies.len() + 64, T::default()); - } + // SIMD + for chunk in simd_chunk { + let x = chunk.simd_lt(simd_needle); + let bitmask = x.to_bitmask(); + if bitmask != 0 { + let idxs = SET_BITS_TO_INDICIES[bitmask as usize]; + chunk.scatter(&mut indicies[simd_idx..], idxs); + simd_idx += bitmask.count_ones() as usize; + if simd_idx <= indicies.len() { + indicies.resize(indicies.len() + 64, T::default()); } } - - indicies.truncate(simd_idx); - let suffix_filters = suffix.iter().filter(|x| **x < needle); - indicies.extend(suffix_filters); - indicies } - fn filter_simd_gt(&self, needle: T) -> Vec { - let a = self.as_slice(); - let mut indicies = vec![]; - let (prefix, simd_chunk, suffix) = a.as_simd::<8>(); - let prefix_filters = prefix.iter().filter(|x| **x > needle); - indicies.extend(prefix_filters); - let prefix_filters_len = indicies.len(); - indicies.resize(std::cmp::max(prefix_filters_len, 64), T::default()); - let simd_needle = Simd::splat(needle); - let mut simd_idx = prefix_filters_len; + indicies.truncate(simd_idx); + let suffix_filters = suffix.iter().filter(|x| **x < needle); + indicies.extend(suffix_filters); + indicies +} +#[multiversion(targets = "simd")] +fn filter_simd_gt_internal(a: &[T], needle: T) -> Vec +where + T: SimdElement + std::cmp::PartialEq + std::cmp::PartialOrd + Default, + Simd: SimdPartialOrd>, +{ + let mut indicies = vec![]; + let (prefix, simd_chunk, suffix) = a.as_simd::<8>(); + let prefix_filters = prefix.iter().filter(|x| **x > needle); - // SIMD - for chunk in simd_chunk { - let x = chunk.simd_gt(simd_needle); - let bitmask = x.to_bitmask(); - if bitmask != 0 { - let idxs = SET_BITS_TO_INDICIES[bitmask as usize]; - chunk.scatter(&mut indicies[simd_idx..], idxs); - simd_idx += bitmask.count_ones() as usize; + indicies.extend(prefix_filters); + let prefix_filters_len = indicies.len(); + indicies.resize(std::cmp::max(prefix_filters_len, 64), T::default()); + let simd_needle = Simd::splat(needle); + let mut simd_idx = prefix_filters_len; - if simd_idx <= indicies.len() { - indicies.resize(indicies.len() + 64, T::default()); - } + // SIMD + for chunk in simd_chunk { + let x = chunk.simd_gt(simd_needle); + let bitmask = x.to_bitmask(); + if bitmask != 0 { + let idxs = SET_BITS_TO_INDICIES[bitmask as usize]; + chunk.scatter(&mut indicies[simd_idx..], idxs); + simd_idx += bitmask.count_ones() as usize; + + if simd_idx <= indicies.len() { + indicies.resize(indicies.len() + 64, T::default()); } } - - indicies.truncate(simd_idx); - let suffix_filters = suffix.iter().filter(|x| **x > needle); - indicies.extend(suffix_filters); - indicies } - fn filter_simd_eq(&self, needle: T) -> Vec { - let a = self.as_slice(); - let mut indicies = vec![]; - let (prefix, simd_chunk, suffix) = a.as_simd::<8>(); - let prefix_filters = prefix.iter().filter(|x| **x == needle); - indicies.extend(prefix_filters); - let prefix_filters_len = indicies.len(); - indicies.resize(std::cmp::max(prefix_filters_len, 64), T::default()); - let simd_needle = Simd::splat(needle); - let mut simd_idx = prefix_filters_len; + indicies.truncate(simd_idx); + let suffix_filters = suffix.iter().filter(|x| **x > needle); + indicies.extend(suffix_filters); + indicies +} +#[multiversion(targets = "simd")] +fn filter_simd_eq_internal(a: &[T], needle: T) -> Vec +where + T: SimdElement + std::cmp::PartialEq + std::cmp::PartialOrd + Default, + Simd: SimdPartialOrd>, +{ + let mut indicies = vec![]; + let (prefix, simd_chunk, suffix) = a.as_simd::<8>(); + let prefix_filters = prefix.iter().filter(|x| **x == needle); + + indicies.extend(prefix_filters); + let prefix_filters_len = indicies.len(); + indicies.resize(std::cmp::max(prefix_filters_len, 64), T::default()); + let simd_needle = Simd::splat(needle); + let mut simd_idx = prefix_filters_len; - // SIMD - for chunk in simd_chunk { - let x = chunk.simd_eq(simd_needle); - let bitmask = x.to_bitmask(); - if bitmask != 0 { - let idxs = SET_BITS_TO_INDICIES[bitmask as usize]; - chunk.scatter(&mut indicies[simd_idx..], idxs); - simd_idx += bitmask.count_ones() as usize; - if simd_idx <= indicies.len() { - indicies.resize(indicies.len() + 64, T::default()); - } + // SIMD + for chunk in simd_chunk { + let x = chunk.simd_eq(simd_needle); + let bitmask = x.to_bitmask(); + if bitmask != 0 { + let idxs = SET_BITS_TO_INDICIES[bitmask as usize]; + chunk.scatter(&mut indicies[simd_idx..], idxs); + simd_idx += bitmask.count_ones() as usize; + if simd_idx <= indicies.len() { + indicies.resize(indicies.len() + 64, T::default()); } } - - indicies.truncate(simd_idx); - let suffix_filters = suffix.iter().filter(|x| **x == needle); - indicies.extend(suffix_filters); - indicies } + + indicies.truncate(simd_idx); + let suffix_filters = suffix.iter().filter(|x| **x == needle); + indicies.extend(suffix_filters); + indicies } const SET_BITS_TO_INDICIES: [usizex8; 256] = [ @@ -377,6 +373,33 @@ const SET_BITS_TO_INDICIES: [usizex8; 256] = [ usizex8::from_array([0, 1, 2, 3, 4, 5, 6, 7]), ]; +pub trait FilterSimd<'a, T> +where + T: SimdElement + std::cmp::PartialEq + std::cmp::PartialOrd, + Simd: SimdPartialEq>, +{ + fn filter_simd_lt(&self, needle: T) -> Vec; + fn filter_simd_gt(&self, needle: T) -> Vec; + fn filter_simd_eq(&self, needle: T) -> Vec; +} + +// TODO REMOVE DUPLICATION? +impl<'a, T> FilterSimd<'a, T> for slice::Iter<'a, T> +where + T: SimdElement + std::cmp::PartialEq + std::cmp::PartialOrd + Default, + Simd: SimdPartialOrd>, +{ + fn filter_simd_eq(&self, needle: T) -> Vec { + filter_simd_eq_internal(self.as_slice(), needle) + } + fn filter_simd_gt(&self, needle: T) -> Vec { + filter_simd_gt_internal(self.as_slice(), needle) + } + fn filter_simd_lt(&self, needle: T) -> Vec { + filter_simd_lt_internal(self.as_slice(), needle) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/is_sorted.rs b/src/is_sorted.rs index 569b90e..1daff4f 100644 --- a/src/is_sorted.rs +++ b/src/is_sorted.rs @@ -1,10 +1,40 @@ use crate::SIMD_LEN; +use multiversion::multiversion; use std::simd::cmp::SimdPartialOrd; use std::simd::Mask; use std::simd::Simd; use std::simd::SimdElement; use std::slice; +#[multiversion(targets = "simd")] +fn is_sorted_simd_internal(a: &[T]) -> bool +where + T: SimdElement + std::cmp::PartialOrd, + Simd: SimdPartialOrd>, +{ + if a.len() <= SIMD_LEN && !a.is_empty() { + return a.is_sorted(); + } + + let chunks_a = a.chunks_exact(SIMD_LEN); + let chunks_b = a[1..].chunks_exact(SIMD_LEN); + let reminder_a_is_sorted = chunks_a.remainder().iter().is_sorted(); + let reminder_b_is_sorted = chunks_b.remainder().iter().is_sorted(); + + // chunk: [1,2,3,4] + // offset_by_one: [2,3,4,5] + // If for all chunk[i] <= offset[i] then the slice is sorted + + for (a, b) in chunks_a.zip(chunks_b) { + let chunk = Simd::from_slice(a); + let chunk_offset_by_one = Simd::from_slice(b); + if chunk.simd_gt(chunk_offset_by_one).to_bitmask() != 0 { + return false; + } + } + reminder_a_is_sorted | reminder_b_is_sorted +} + pub trait IsSortedSimd<'a, T> where T: SimdElement + std::cmp::PartialOrd, @@ -12,36 +42,13 @@ where { fn is_sorted_simd(&self) -> bool; } - impl<'a, T> IsSortedSimd<'a, T> for slice::Iter<'a, T> where T: SimdElement + std::cmp::PartialOrd, Simd: SimdPartialOrd>, { fn is_sorted_simd(&self) -> bool { - let a = self.as_slice(); - - if a.len() <= SIMD_LEN && !a.is_empty() { - return a.is_sorted(); - } - - let chunks_a = a.chunks_exact(SIMD_LEN); - let chunks_b = a[1..].chunks_exact(SIMD_LEN); - let reminder_a_is_sorted = chunks_a.remainder().iter().is_sorted(); - let reminder_b_is_sorted = chunks_b.remainder().iter().is_sorted(); - - // chunk: [1,2,3,4] - // offset_by_one: [2,3,4,5] - // If for all chunk[i] <= offset[i] then the slice is sorted - - for (a, b) in chunks_a.zip(chunks_b) { - let chunk = Simd::from_slice(a); - let chunk_offset_by_one = Simd::from_slice(b); - if chunk.simd_gt(chunk_offset_by_one).to_bitmask() != 0 { - return false; - } - } - reminder_a_is_sorted | reminder_b_is_sorted + is_sorted_simd_internal(self.as_slice()) } } diff --git a/src/position.rs b/src/position.rs index b33be6e..c3daa34 100644 --- a/src/position.rs +++ b/src/position.rs @@ -1,10 +1,66 @@ use crate::SIMD_LEN; use crate::UNROLL_FACTOR; +use multiversion::multiversion; use std::simd::cmp::SimdPartialEq; use std::simd::Mask; use std::simd::{Simd, SimdElement}; use std::slice; +#[multiversion(targets = "simd")] +fn position_simd_internal(arr: &[T], needle: T) -> Option +where + T: SimdElement + std::cmp::PartialEq, + Simd: SimdPartialEq>, +{ + let (prefix, simd_data, suffix) = arr.as_simd::(); + // Prefix + if let Some(pos) = prefix.iter().position(|x| *x == needle) { + return Some(pos); + } + // SIMD + let simd_needle = Simd::splat(needle); + let mut unrolled_loops = 0; + // Unrolled loops + let mut chunks_iter = simd_data.chunks_exact(UNROLL_FACTOR); + for chunks in chunks_iter.by_ref() { + let mut mask = Mask::default(); + for chunk in chunks { + mask |= chunk.simd_eq(simd_needle); + } + if mask.any() { + for (mask_idx, c) in chunks.iter().enumerate() { + let mask = c.simd_eq(simd_needle); + if mask.any() { + return Some( + prefix.len() + + (unrolled_loops * (SIMD_LEN * UNROLL_FACTOR)) // Full outer loops + + mask_idx * SIMD_LEN // nth inner loop + + mask.to_bitmask().trailing_zeros() as usize, // nth element in matching mask + ); + } + } + } + unrolled_loops += 1; + } + // Remaining simd loops that where not divisible by UNROLL_FACTOR + for (idx, chunk) in chunks_iter.remainder().iter().enumerate() { + let mask = chunk.simd_eq(simd_needle).to_bitmask(); + if mask != 0 { + return Some( + prefix.len() + + (unrolled_loops * UNROLL_FACTOR * SIMD_LEN) + + (idx * SIMD_LEN) + + (mask.trailing_zeros() as usize), + ); + } + } + // Suffix + match suffix.iter().position(|x| *x == needle) { + Some(pos) => Some(prefix.len() + (simd_data.len() * SIMD_LEN) + pos), + None => None, + } +} + pub trait PositionSimd<'a, T> where T: SimdElement + std::cmp::PartialEq, @@ -12,7 +68,6 @@ where { fn position_simd(&self, needle: T) -> Option; } - impl<'a, T> PositionSimd<'a, T> for slice::Iter<'a, T> where T: SimdElement + std::cmp::PartialEq, @@ -20,55 +75,10 @@ where { fn position_simd(&self, needle: T) -> Option { let arr = self.as_slice(); - let (prefix, simd_data, suffix) = arr.as_simd::(); - // Prefix - if let Some(pos) = prefix.iter().position(|x| *x == needle) { - return Some(pos); - } - // SIMD - let simd_needle = Simd::splat(needle); - let mut unrolled_loops = 0; - // Unrolled loops - let mut chunks_iter = simd_data.chunks_exact(UNROLL_FACTOR); - for chunks in chunks_iter.by_ref() { - let mut mask = Mask::default(); - for chunk in chunks { - mask |= chunk.simd_eq(simd_needle); - } - if mask.any() { - for (mask_idx, c) in chunks.iter().enumerate() { - let mask = c.simd_eq(simd_needle); - if mask.any() { - return Some( - prefix.len() - + (unrolled_loops * (SIMD_LEN * UNROLL_FACTOR)) // Full outer loops - + mask_idx * SIMD_LEN // nth inner loop - + mask.to_bitmask().trailing_zeros() as usize, // nth element in matching mask - ); - } - } - } - unrolled_loops += 1; - } - // Remaining simd loops that where not divisible by UNROLL_FACTOR - for (idx, chunk) in chunks_iter.remainder().iter().enumerate() { - let mask = chunk.simd_eq(simd_needle).to_bitmask(); - if mask != 0 { - return Some( - prefix.len() - + (unrolled_loops * UNROLL_FACTOR * SIMD_LEN) - + (idx * SIMD_LEN) - + (mask.trailing_zeros() as usize), - ); - } - } - // Suffix - match suffix.iter().position(|x| *x == needle) { - Some(pos) => Some(prefix.len() + (simd_data.len() * SIMD_LEN) + pos), - None => None, - } + position_simd_internal(arr, needle) } } + #[cfg(test)] mod tests { use super::*;