Skip to content

Commit

Permalink
add multiversioning
Browse files Browse the repository at this point in the history
  • Loading branch information
LaihoE committed Jul 29, 2024
1 parent 256a096 commit 1c7520d
Show file tree
Hide file tree
Showing 6 changed files with 293 additions and 229 deletions.
48 changes: 28 additions & 20 deletions src/all_equal.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,37 @@
use crate::SIMD_LEN;
use multiversion::multiversion;
use std::simd::cmp::SimdPartialEq;
use std::simd::Mask;
use std::simd::Simd;
use std::simd::SimdElement;
use std::slice;

#[multiversion(targets = "simd")]
fn all_equal_simd_internal<T>(arr: &[T]) -> bool
where
T: SimdElement + std::cmp::PartialEq,
Simd<T, SIMD_LEN>: SimdPartialEq<Mask = Mask<T::Mask, SIMD_LEN>>,
{
if arr.is_empty() {
return true;
}
let first = arr[0];
let (prefix, simd_data, suffix) = arr.as_simd::<SIMD_LEN>();
// Prefix
if !prefix.iter().all(|x| *x == first) {
return false;
}
// SIMD
let simd_needle = Simd::splat(first);
for rest_slice in simd_data {
let mask = rest_slice.simd_ne(simd_needle).to_bitmask();
if mask != 0 {
return false;
}
}
// Suffix
suffix.iter().all(|x| *x == first)
}
pub trait AllEqualSimd<'a, T>
where
T: SimdElement + std::cmp::PartialEq,
Expand All @@ -19,26 +46,7 @@ where
Simd<T, SIMD_LEN>: SimdPartialEq<Mask = Mask<T::Mask, SIMD_LEN>>,
{
fn all_equal_simd(&self) -> bool {
let arr = self.as_slice();
if arr.is_empty() {
return true;
}
let first = arr[0];
let (prefix, simd_data, suffix) = arr.as_simd::<SIMD_LEN>();
// Prefix
if !prefix.iter().all(|x| *x == first) {
return false;
}
// SIMD
let simd_needle = Simd::splat(first);
for rest_slice in simd_data {
let mask = rest_slice.simd_ne(simd_needle).to_bitmask();
if mask != 0 {
return false;
}
}
// Suffix
suffix.iter().all(|x| *x == first)
all_equal_simd_internal(self.as_slice())
}
}

Expand Down
64 changes: 36 additions & 28 deletions src/contains.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,45 @@
use crate::SIMD_LEN;
use crate::UNROLL_FACTOR;
use multiversion::multiversion;
use std::simd::cmp::SimdPartialEq;
use std::simd::Mask;
use std::simd::{Simd, SimdElement};
use std::slice;

#[multiversion(targets = "simd")]
fn contains_simd_internal<T>(arr: &[T], needle: &T) -> bool
where
T: SimdElement + std::cmp::PartialEq,
Simd<T, SIMD_LEN>: SimdPartialEq<Mask = Mask<T::Mask, SIMD_LEN>>,
{
let (prefix, simd_data, suffix) = arr.as_simd::<SIMD_LEN>();
// Prefix
if prefix.contains(&needle) {
return true;
}
// SIMD
let simd_needle = Simd::splat(*needle);
// Unrolled loops
let mut chunks_iter = simd_data.chunks_exact(UNROLL_FACTOR);
for chunks in chunks_iter.by_ref() {
let mut mask = Mask::default();
for chunk in chunks {
mask |= chunk.simd_eq(simd_needle);
}
if mask.any() {
return true;
}
}
for chunk in chunks_iter.remainder() {
let mask = chunk.simd_eq(simd_needle);
if mask.any() {
return true;
}
}
// Suffix
suffix.contains(&needle)
}

pub trait ContainsSimd<'a, T>
where
T: SimdElement + std::cmp::PartialEq,
Expand All @@ -19,36 +54,9 @@ where
Simd<T, SIMD_LEN>: SimdPartialEq<Mask = Mask<T::Mask, SIMD_LEN>>,
{
fn contains_simd(&self, needle: &T) -> bool {
let arr = self.as_slice();
let (prefix, simd_data, suffix) = arr.as_simd::<SIMD_LEN>();
// Prefix
if prefix.contains(&needle) {
return true;
}
// SIMD
let simd_needle = Simd::splat(*needle);
// Unrolled loops
let mut chunks_iter = simd_data.chunks_exact(UNROLL_FACTOR);
for chunks in chunks_iter.by_ref() {
let mut mask = Mask::default();
for chunk in chunks {
mask |= chunk.simd_eq(simd_needle);
}
if mask.any() {
return true;
}
}
for chunk in chunks_iter.remainder() {
let mask = chunk.simd_eq(simd_needle);
if mask.any() {
return true;
}
}
// Suffix
suffix.contains(&needle)
contains_simd_internal(self.as_slice(), needle)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
46 changes: 27 additions & 19 deletions src/eq.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,37 @@
use crate::SIMD_LEN;
use crate::UNROLL_FACTOR;
use multiversion::multiversion;
use std::simd::cmp::SimdPartialEq;
use std::simd::Mask;
use std::simd::Simd;
use std::simd::SimdElement;
use std::slice;

#[multiversion(targets = "simd")]
fn eq_simd_internal<T>(a: &[T], b: &[T]) -> bool
where
T: SimdElement + std::cmp::PartialEq,
Simd<T, SIMD_LEN>: SimdPartialEq<Mask = Mask<T::Mask, SIMD_LEN>>,
{
if a.len() != b.len() {
return false;
}

let mut chunks_a = a.chunks_exact(SIMD_LEN * UNROLL_FACTOR);
let mut chunks_b = b.chunks_exact(SIMD_LEN * UNROLL_FACTOR);
let mut mask = Mask::default();

for (aa, bb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
for (aaa, bbb) in aa.chunks_exact(SIMD_LEN).zip(bb.chunks_exact(SIMD_LEN)) {
mask |= Simd::from_slice(aaa).simd_ne(Simd::from_slice(bbb));
}
if mask.any() {
return false;
}
}
return chunks_a.remainder().eq(chunks_b.remainder());
}

pub trait EqSimd<'a, T>
where
T: SimdElement + std::cmp::PartialEq,
Expand All @@ -20,25 +46,7 @@ where
Simd<T, SIMD_LEN>: SimdPartialEq<Mask = Mask<T::Mask, SIMD_LEN>>,
{
fn eq_simd(&self, other: &Self) -> bool {
let a = self.as_slice();
let b = other.as_slice();
if a.len() != b.len() {
return false;
}

let mut chunks_a = a.chunks_exact(SIMD_LEN * UNROLL_FACTOR);
let mut chunks_b = b.chunks_exact(SIMD_LEN * UNROLL_FACTOR);
let mut mask = Mask::default();

for (aa, bb) in chunks_a.by_ref().zip(chunks_b.by_ref()) {
for (aaa, bbb) in aa.chunks_exact(SIMD_LEN).zip(bb.chunks_exact(SIMD_LEN)) {
mask |= Simd::from_slice(aaa).simd_ne(Simd::from_slice(bbb));
}
if mask.any() {
return false;
}
}
return chunks_a.remainder().eq(chunks_b.remainder());
eq_simd_internal(self.as_slice(), other.as_slice())
}
}

Expand Down
Loading

0 comments on commit 1c7520d

Please sign in to comment.