diff --git a/src/eq.rs b/src/eq.rs index bfd321e..3996241 100644 --- a/src/eq.rs +++ b/src/eq.rs @@ -1,4 +1,5 @@ use crate::SIMD_LEN; +use crate::UNROLL_FACTOR; use std::simd::cmp::SimdPartialEq; use std::simd::Mask; use std::simd::Simd; @@ -21,25 +22,23 @@ where fn eq_simd(&self, other: &Self) -> bool { let a = self.as_slice(); let b = other.as_slice(); - if a.len() != b.len() { return false; } - if a.len() <= SIMD_LEN || b.len() <= SIMD_LEN { - return a.iter().eq(b); - } - let chunks_a = a.chunks_exact(SIMD_LEN); - let chunks_b = b.chunks_exact(SIMD_LEN); - let remainder_is_sorted = chunks_a.remainder().iter().eq(chunks_b.remainder().iter()); - for (a, b) in chunks_a.zip(chunks_b) { - let chunk_a = Simd::from_slice(a); - let chunk_b = Simd::from_slice(b); - if chunk_a.simd_ne(chunk_b).to_bitmask() != 0 { + let mut chunks_a = a.chunks_exact(SIMD_LEN * UNROLL_FACTOR); + let mut chunks_b = b.chunks_exact(SIMD_LEN * UNROLL_FACTOR); + let mut mask = Mask::default(); + + for (aa, bb) in chunks_a.by_ref().zip(chunks_b.by_ref()) { + for (aaa, bbb) in aa.chunks_exact(SIMD_LEN).zip(bb.chunks_exact(SIMD_LEN)) { + mask |= Simd::from_slice(aaa).simd_ne(Simd::from_slice(bbb)); + } + if mask.any() { return false; } } - return remainder_is_sorted; + return chunks_a.remainder().eq(chunks_b.remainder()); } } @@ -61,7 +60,7 @@ mod tests { Simd: SimdPartialEq>, Standard: Distribution, { - for len in 0..100 { + for len in 0..1000 { for _ in 0..5 { let mut v: Vec = vec![T::default(); len]; let mut rng = rand::thread_rng(); @@ -98,7 +97,7 @@ mod tests { Simd: SimdPartialEq>, Standard: Distribution, { - for len in 0..100 { + for len in 0..1000 { for _ in 0..5 { let mut v: Vec = vec![T::default(); len]; let mut rng = rand::thread_rng();