diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c651a39b..f5e6ff41 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: - stable - beta - nightly - - 1.49.0 # MSRV + - 1.60.0 # MSRV steps: - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 @@ -32,6 +32,8 @@ jobs: run: cargo build --verbose - name: Run tests run: cargo test --verbose + - name: Run tests with Rayon + run: cargo test --verbose --features rayon cross_test: runs-on: ubuntu-latest @@ -61,6 +63,8 @@ jobs: run: cross build --verbose --target=${{ matrix.target }} - name: Run tests run: cross test --verbose --target=${{ matrix.target }} + - name: Run tests with Rayon + run: cross test --verbose --features rayon --target=${{ matrix.target }} format: runs-on: ubuntu-latest diff --git a/Cargo.toml b/Cargo.toml index 892533e5..136bb5c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,9 @@ [package] name = "ndarray-stats" -version = "0.5.1" +version = "0.6.0" +rust-version = "1.60" +edition = "2021" authors = ["Jim Turner ", "LukeMathWalker "] -edition = "2018" license = "MIT/Apache-2.0" @@ -17,12 +18,13 @@ categories = ["data-structures", "science"] [dependencies] ndarray = "0.15.0" +ndarray-slice = "0.2.2" noisy_float = "0.2.0" num-integer = "0.1" num-traits = "0.2" rand = "0.8.3" itertools = { version = "0.10.0", default-features = false } -indexmap = "1.6.2" +rayon = { version = "1.7.0", optional = true } [dev-dependencies] ndarray = { version = "0.15.0", features = ["approx"] } @@ -33,6 +35,9 @@ approx = "0.4" quickcheck_macros = "1.0.0" num-bigint = "0.4.0" +[features] +rayon = ["dep:rayon", "ndarray-slice/rayon", "ndarray/rayon"] + [[bench]] name = "sort" harness = false @@ -44,3 +49,6 @@ harness = false [[bench]] name = "deviation" harness = false + +[profile.test] +opt-level = 2 diff --git a/benches/sort.rs b/benches/sort.rs index 1a2f4429..01da9c15 100644 --- a/benches/sort.rs +++ b/benches/sort.rs @@ -2,7 +2,7 @@ use criterion::{ black_box, criterion_group, criterion_main, AxisScale, BatchSize, Criterion, PlotConfiguration, }; use ndarray::prelude::*; -use ndarray_stats::Sort1dExt; +use ndarray_slice::Slice1Ext; use rand::prelude::*; fn get_from_sorted_mut(c: &mut Criterion) { @@ -19,7 +19,7 @@ fn get_from_sorted_mut(c: &mut Criterion) { || Array1::from(data.clone()), |mut arr| { for &i in &indices { - black_box(arr.get_from_sorted_mut(i)); + black_box(arr.select_nth_unstable(i)); } }, BatchSize::SmallInput, @@ -42,7 +42,8 @@ fn get_many_from_sorted_mut(c: &mut Criterion) { b.iter_batched( || Array1::from(data.clone()), |mut arr| { - black_box(arr.get_many_from_sorted_mut(&indices)); + let mut values = Vec::with_capacity(indices.len()); + black_box(arr.select_many_nth_unstable(&indices, &mut values)); }, BatchSize::SmallInput, ) diff --git a/src/histogram/bins.rs b/src/histogram/bins.rs index f6ff818e..6b8c7924 100644 --- a/src/histogram/bins.rs +++ b/src/histogram/bins.rs @@ -3,6 +3,9 @@ use ndarray::prelude::*; use std::ops::{Index, Range}; +#[cfg(feature = "rayon")] +use rayon::slice::ParallelSliceMut; + /// A sorted collection of type `A` elements used to represent the boundaries of intervals, i.e. /// [`Bins`] on a 1-dimensional axis. /// @@ -30,11 +33,11 @@ use std::ops::{Index, Range}; /// /// [`Bins`]: struct.Bins.html #[derive(Clone, Debug, Eq, PartialEq)] -pub struct Edges { +pub struct Edges { edges: Vec, } -impl From> for Edges { +impl From> for Edges { /// Converts a `Vec` into an `Edges`, consuming the edges. /// The vector will be sorted in increasing order using an unstable sorting algorithm, with /// duplicates removed. @@ -65,6 +68,9 @@ impl From> for Edges { /// [pdqsort]: https://github.com/orlp/pdqsort fn from(mut edges: Vec) -> Self { // sort the array in-place + #[cfg(feature = "rayon")] + edges.par_sort_unstable(); + #[cfg(not(feature = "rayon"))] edges.sort_unstable(); // remove duplicates edges.dedup(); @@ -72,7 +78,7 @@ impl From> for Edges { } } -impl From> for Edges { +impl From> for Edges { /// Converts an `Array1` into an `Edges`, consuming the 1-dimensional array. /// The array will be sorted in increasing order using an unstable sorting algorithm, with /// duplicates removed. @@ -106,7 +112,7 @@ impl From> for Edges { } } -impl Index for Edges { +impl Index for Edges { type Output = A; /// Returns a reference to the `i`-th edge in `self`. @@ -131,7 +137,7 @@ impl Index for Edges { } } -impl Edges { +impl Edges { /// Returns the number of edges in `self`. /// /// # Examples @@ -258,11 +264,11 @@ impl Edges { /// ); /// ``` #[derive(Clone, Debug, Eq, PartialEq)] -pub struct Bins { +pub struct Bins { edges: Edges, } -impl Bins { +impl Bins { /// Returns a `Bins` instance where each bin corresponds to two consecutive members of the given /// [`Edges`], consuming the edges. /// diff --git a/src/histogram/grid.rs b/src/histogram/grid.rs index 57e85061..d941eefd 100644 --- a/src/histogram/grid.rs +++ b/src/histogram/grid.rs @@ -88,11 +88,11 @@ use std::ops::Range; /// [`GridBuilder`]: struct.GridBuilder.html /// [`strategy`]: strategies/index.html #[derive(Clone, Debug, Eq, PartialEq)] -pub struct Grid { +pub struct Grid { projections: Vec>, } -impl From>> for Grid { +impl From>> for Grid { /// Converts a `Vec>` into a `Grid`, consuming the vector of bins. /// /// The `i`-th element in `Vec>` represents the projection of the bin grid onto the @@ -106,7 +106,7 @@ impl From>> for Grid { } } -impl Grid { +impl Grid { /// Returns the number of dimensions of the region partitioned by the grid. /// /// # Examples @@ -220,7 +220,7 @@ impl Grid { } } -impl Grid { +impl Grid { /// Given an `n`-dimensional index, `i = (i_0, ..., i_{n-1})`, returns an `n`-dimensional bin, /// `I_{i_0} x ... x I_{i_{n-1}}`, where `I_{i_j}` is the `i_j`-th interval on the `j`-th /// projection of the grid on the coordinate axes. @@ -318,7 +318,7 @@ pub struct GridBuilder { impl GridBuilder where - A: Ord, + A: Ord + Send, B: BinsBuildingStrategy, { /// Returns a `GridBuilder` for building a [`Grid`] with a given [`strategy`] and some diff --git a/src/histogram/histograms.rs b/src/histogram/histograms.rs index 603a5019..65e25cd3 100644 --- a/src/histogram/histograms.rs +++ b/src/histogram/histograms.rs @@ -4,12 +4,12 @@ use ndarray::prelude::*; use ndarray::Data; /// Histogram data structure. -pub struct Histogram { +pub struct Histogram { counts: ArrayD, grid: Grid, } -impl Histogram { +impl Histogram { /// Returns a new instance of Histogram given a [`Grid`]. /// /// [`Grid`]: struct.Grid.html @@ -43,7 +43,7 @@ impl Histogram { /// [0, 1], /// ]; /// assert_eq!(histogram_matrix, expected.into_dyn()); - /// # Ok::<(), Box>(()) + /// # Ok::<(), Box>(()) /// ``` pub fn add_observation(&mut self, observation: &ArrayBase) -> Result<(), BinNotFound> where @@ -136,7 +136,7 @@ where /// ``` fn histogram(&self, grid: Grid) -> Histogram where - A: Ord; + A: Ord + Send; private_decl! {} } @@ -144,7 +144,7 @@ where impl HistogramExt for ArrayBase where S: Data, - A: Ord, + A: Ord + Send, { fn histogram(&self, grid: Grid) -> Histogram { let mut histogram = Histogram::new(grid); diff --git a/src/histogram/strategies.rs b/src/histogram/strategies.rs index a1522109..99334f02 100644 --- a/src/histogram/strategies.rs +++ b/src/histogram/strategies.rs @@ -65,7 +65,7 @@ use num_traits::{FromPrimitive, NumOps, Zero}; /// [`Grid`]: ../struct.Grid.html pub trait BinsBuildingStrategy { #[allow(missing_docs)] - type Elem: Ord; + type Elem: Ord + Send; /// Returns a strategy that has learnt the required parameter fo building [`Bins`] for given /// 1-dimensional array, or an `Err` if it is not possible to infer the required parameter /// with the given data and specified strategy. @@ -213,7 +213,7 @@ pub struct Auto { impl EquiSpaced where - T: Ord + Clone + FromPrimitive + NumOps + Zero, + T: Ord + Send + Clone + FromPrimitive + NumOps + Zero, { /// Returns `Err(BinsBuildError::Strategy)` if `bin_width<=0` or `min` >= `max`. /// Returns `Ok(Self)` otherwise. @@ -256,7 +256,7 @@ where impl BinsBuildingStrategy for Sqrt where - T: Ord + Clone + FromPrimitive + NumOps + Zero, + T: Ord + Send + Clone + FromPrimitive + NumOps + Zero, { type Elem = T; @@ -292,7 +292,7 @@ where impl Sqrt where - T: Ord + Clone + FromPrimitive + NumOps + Zero, + T: Ord + Send + Clone + FromPrimitive + NumOps + Zero, { /// The bin width (or bin length) according to the fitted strategy. pub fn bin_width(&self) -> T { @@ -302,7 +302,7 @@ where impl BinsBuildingStrategy for Rice where - T: Ord + Clone + FromPrimitive + NumOps + Zero, + T: Ord + Send + Clone + FromPrimitive + NumOps + Zero, { type Elem = T; @@ -338,7 +338,7 @@ where impl Rice where - T: Ord + Clone + FromPrimitive + NumOps + Zero, + T: Ord + Send + Clone + FromPrimitive + NumOps + Zero, { /// The bin width (or bin length) according to the fitted strategy. pub fn bin_width(&self) -> T { @@ -348,7 +348,7 @@ where impl BinsBuildingStrategy for Sturges where - T: Ord + Clone + FromPrimitive + NumOps + Zero, + T: Ord + Send + Clone + FromPrimitive + NumOps + Zero, { type Elem = T; @@ -384,7 +384,7 @@ where impl Sturges where - T: Ord + Clone + FromPrimitive + NumOps + Zero, + T: Ord + Send + Clone + FromPrimitive + NumOps + Zero, { /// The bin width (or bin length) according to the fitted strategy. pub fn bin_width(&self) -> T { @@ -394,7 +394,7 @@ where impl BinsBuildingStrategy for FreedmanDiaconis where - T: Ord + Clone + FromPrimitive + NumOps + Zero, + T: Ord + Send + Clone + FromPrimitive + NumOps + Zero, { type Elem = T; @@ -433,7 +433,7 @@ where impl FreedmanDiaconis where - T: Ord + Clone + FromPrimitive + NumOps + Zero, + T: Ord + Send + Clone + FromPrimitive + NumOps + Zero, { fn compute_bin_width(n_bins: usize, iqr: T) -> T { // casting `n_bins: usize` to `f64` may casus off-by-one error here if `n_bins` > 2 ^ 53, @@ -451,7 +451,7 @@ where impl BinsBuildingStrategy for Auto where - T: Ord + Clone + FromPrimitive + NumOps + Zero, + T: Ord + Send + Clone + FromPrimitive + NumOps + Zero, { type Elem = T; @@ -504,7 +504,7 @@ where impl Auto where - T: Ord + Clone + FromPrimitive + NumOps + Zero, + T: Ord + Send + Clone + FromPrimitive + NumOps + Zero, { /// The bin width (or bin length) according to the fitted strategy. pub fn bin_width(&self) -> T { @@ -524,7 +524,7 @@ where /// **Panics** if `n_bins == 0` and division by 0 panics for `T`. fn compute_bin_width(min: T, max: T, n_bins: usize) -> T where - T: Ord + Clone + FromPrimitive + NumOps + Zero, + T: Ord + Send + Clone + FromPrimitive + NumOps + Zero, { let range = max - min; range / T::from_usize(n_bins).unwrap() diff --git a/src/lib.rs b/src/lib.rs index 4ae11004..45d02e64 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,7 +35,6 @@ pub use crate::entropy::EntropyExt; pub use crate::histogram::HistogramExt; pub use crate::maybe_nan::{MaybeNan, MaybeNanExt}; pub use crate::quantile::{interpolate, Quantile1dExt, QuantileExt}; -pub use crate::sort::Sort1dExt; pub use crate::summary_statistics::SummaryStatisticsExt; #[cfg(test)] @@ -105,5 +104,4 @@ pub mod errors; pub mod histogram; mod maybe_nan; mod quantile; -mod sort; mod summary_statistics; diff --git a/src/quantile/mod.rs b/src/quantile/mod.rs index 3fea4a65..4cae6aa2 100644 --- a/src/quantile/mod.rs +++ b/src/quantile/mod.rs @@ -1,16 +1,17 @@ use self::interpolate::{higher_index, lower_index, Interpolate}; -use super::sort::get_many_from_sorted_mut_unchecked; use crate::errors::QuantileError; use crate::errors::{EmptyInput, MinMaxError, MinMaxError::UndefinedOrder}; use crate::{MaybeNan, MaybeNanExt}; use ndarray::prelude::*; use ndarray::{Data, DataMut, RemoveAxis, Zip}; +use ndarray_slice::Slice1Ext; use noisy_float::types::N64; -use std::cmp; +use std::{cmp, collections::HashMap}; /// Quantile methods for `ArrayBase`. pub trait QuantileExt where + A: Send, S: Data, D: Dimension, { @@ -38,7 +39,7 @@ where /// ``` fn argmin(&self) -> Result where - A: PartialOrd; + A: PartialOrd + Send; /// Finds the index of the minimum value of the array skipping NaN values. /// @@ -62,7 +63,7 @@ where fn argmin_skipnan(&self) -> Result where A: MaybeNan, - A::NotNan: Ord; + A::NotNan: Ord + Send; /// Finds the elementwise minimum of the array. /// @@ -77,7 +78,7 @@ where /// the memory layout of the array.) fn min(&self) -> Result<&A, MinMaxError> where - A: PartialOrd; + A: PartialOrd + Send; /// Finds the elementwise minimum of the array, skipping NaN values. /// @@ -91,7 +92,7 @@ where fn min_skipnan(&self) -> &A where A: MaybeNan, - A::NotNan: Ord; + A::NotNan: Ord + Send; /// Finds the index of the maximum value of the array. /// @@ -117,7 +118,7 @@ where /// ``` fn argmax(&self) -> Result where - A: PartialOrd; + A: PartialOrd + Send; /// Finds the index of the maximum value of the array skipping NaN values. /// @@ -141,7 +142,7 @@ where fn argmax_skipnan(&self) -> Result where A: MaybeNan, - A::NotNan: Ord; + A::NotNan: Ord + Send; /// Finds the elementwise maximum of the array. /// @@ -156,7 +157,7 @@ where /// the memory layout of the array.) fn max(&self) -> Result<&A, MinMaxError> where - A: PartialOrd; + A: PartialOrd + Send; /// Finds the elementwise maximum of the array, skipping NaN values. /// @@ -170,7 +171,7 @@ where fn max_skipnan(&self) -> &A where A: MaybeNan, - A::NotNan: Ord; + A::NotNan: Ord + Send; /// Return the qth quantile of the data along the specified axis. /// @@ -213,7 +214,7 @@ where ) -> Result, QuantileError> where D: RemoveAxis, - A: Ord + Clone, + A: Ord + Send + Clone, S: DataMut, I: Interpolate; @@ -257,7 +258,7 @@ where ) -> Result, QuantileError> where D: RemoveAxis, - A: Ord + Clone, + A: Ord + Send + Clone, S: DataMut, S2: Data, I: Interpolate; @@ -274,7 +275,7 @@ where where D: RemoveAxis, A: MaybeNan, - A::NotNan: Clone + Ord, + A::NotNan: Clone + Ord + Send, S: DataMut, I: Interpolate; @@ -283,12 +284,13 @@ where impl QuantileExt for ArrayBase where + A: Send, S: Data, D: Dimension, { fn argmin(&self) -> Result where - A: PartialOrd, + A: PartialOrd + Send, { let mut current_min = self.first().ok_or(EmptyInput)?; let mut current_pattern_min = D::zeros(self.ndim()).into_pattern(); @@ -306,7 +308,7 @@ where fn argmin_skipnan(&self) -> Result where A: MaybeNan, - A::NotNan: Ord, + A::NotNan: Ord + Send, { let mut pattern_min = D::zeros(self.ndim()).into_pattern(); let min = self.indexed_fold_skipnan(None, |current_min, (pattern, elem)| { @@ -327,7 +329,7 @@ where fn min(&self) -> Result<&A, MinMaxError> where - A: PartialOrd, + A: PartialOrd + Send, { let first = self.first().ok_or(EmptyInput)?; self.fold(Ok(first), |acc, elem| { @@ -342,7 +344,7 @@ where fn min_skipnan(&self) -> &A where A: MaybeNan, - A::NotNan: Ord, + A::NotNan: Ord + Send, { let first = self.first().and_then(|v| v.try_as_not_nan()); A::from_not_nan_ref_opt(self.fold_skipnan(first, |acc, elem| { @@ -355,7 +357,7 @@ where fn argmax(&self) -> Result where - A: PartialOrd, + A: PartialOrd + Send, { let mut current_max = self.first().ok_or(EmptyInput)?; let mut current_pattern_max = D::zeros(self.ndim()).into_pattern(); @@ -373,7 +375,7 @@ where fn argmax_skipnan(&self) -> Result where A: MaybeNan, - A::NotNan: Ord, + A::NotNan: Ord + Send, { let mut pattern_max = D::zeros(self.ndim()).into_pattern(); let max = self.indexed_fold_skipnan(None, |current_max, (pattern, elem)| { @@ -394,7 +396,7 @@ where fn max(&self) -> Result<&A, MinMaxError> where - A: PartialOrd, + A: PartialOrd + Send, { let first = self.first().ok_or(EmptyInput)?; self.fold(Ok(first), |acc, elem| { @@ -409,7 +411,7 @@ where fn max_skipnan(&self) -> &A where A: MaybeNan, - A::NotNan: Ord, + A::NotNan: Ord + Send, { let first = self.first().and_then(|v| v.try_as_not_nan()); A::from_not_nan_ref_opt(self.fold_skipnan(first, |acc, elem| { @@ -428,7 +430,7 @@ where ) -> Result, QuantileError> where D: RemoveAxis, - A: Ord + Clone, + A: Ord + Send + Clone, S: DataMut, S2: Data, I: Interpolate, @@ -442,7 +444,7 @@ where ) -> Result, QuantileError> where D: RemoveAxis, - A: Ord + Clone, + A: Ord + Send + Clone, I: Interpolate, { for &q in qs { @@ -471,23 +473,36 @@ where searched_indexes.push(higher_index(q, axis_len)); } } - searched_indexes.sort(); - searched_indexes.dedup(); + let mut indexes = Array1::from_vec(searched_indexes); + indexes.sort_unstable(); + let (indexes, _duplicates) = indexes.partition_dedup(); let mut results = Array::from_elem(results_shape, data.first().unwrap().clone()); Zip::from(results.lanes_mut(axis)) .and(data.lanes_mut(axis)) .for_each(|mut results, mut data| { - let index_map = - get_many_from_sorted_mut_unchecked(&mut data, &searched_indexes); + #[cfg(feature = "rayon")] + let values = { + let mut values = Vec::new(); + data.par_select_many_nth_unstable(&indexes, &mut values); + HashMap::::from_iter( + indexes.iter().copied().zip(values.into_iter()), + ) + }; + #[cfg(not(feature = "rayon"))] + let values = { + let mut values = HashMap::new(); + data.select_many_nth_unstable(&indexes, &mut values); + values + }; for (result, &q) in results.iter_mut().zip(qs) { let lower = if I::needs_lower(q, axis_len) { - Some(index_map[&lower_index(q, axis_len)].clone()) + Some(values[&lower_index(q, axis_len)].clone()) } else { None }; let higher = if I::needs_higher(q, axis_len) { - Some(index_map[&higher_index(q, axis_len)].clone()) + Some(values[&higher_index(q, axis_len)].clone()) } else { None }; @@ -508,7 +523,7 @@ where ) -> Result, QuantileError> where D: RemoveAxis, - A: Ord + Clone, + A: Ord + Send + Clone, S: DataMut, I: Interpolate, { @@ -525,7 +540,7 @@ where where D: RemoveAxis, A: MaybeNan, - A::NotNan: Clone + Ord, + A::NotNan: Clone + Ord + Send, S: DataMut, I: Interpolate, { @@ -559,6 +574,7 @@ where /// Quantile methods for 1-D arrays. pub trait Quantile1dExt where + A: Send, S: Data, { /// Return the qth quantile of the data. @@ -592,7 +608,7 @@ where /// Returns `Err(InvalidQuantile(q))` if `q` is not between `0.` and `1.` (inclusive). fn quantile_mut(&mut self, q: N64, interpolate: &I) -> Result where - A: Ord + Clone, + A: Ord + Send + Clone, S: DataMut, I: Interpolate; @@ -617,7 +633,7 @@ where interpolate: &I, ) -> Result, QuantileError> where - A: Ord + Clone, + A: Ord + Send + Clone, S: DataMut, S2: Data, I: Interpolate; @@ -627,11 +643,12 @@ where impl Quantile1dExt for ArrayBase where + A: Send, S: Data, { fn quantile_mut(&mut self, q: N64, interpolate: &I) -> Result where - A: Ord + Clone, + A: Ord + Send + Clone, S: DataMut, I: Interpolate, { @@ -646,7 +663,7 @@ where interpolate: &I, ) -> Result, QuantileError> where - A: Ord + Clone, + A: Ord + Send + Clone, S: DataMut, S2: Data, I: Interpolate, diff --git a/src/sort.rs b/src/sort.rs deleted file mode 100644 index f43a95b1..00000000 --- a/src/sort.rs +++ /dev/null @@ -1,298 +0,0 @@ -use indexmap::IndexMap; -use ndarray::prelude::*; -use ndarray::{Data, DataMut, Slice}; -use rand::prelude::*; -use rand::thread_rng; - -/// Methods for sorting and partitioning 1-D arrays. -pub trait Sort1dExt -where - S: Data, -{ - /// Return the element that would occupy the `i`-th position if - /// the array were sorted in increasing order. - /// - /// The array is shuffled **in place** to retrieve the desired element: - /// no copy of the array is allocated. - /// After the shuffling, all elements with an index smaller than `i` - /// are smaller than the desired element, while all elements with - /// an index greater or equal than `i` are greater than or equal - /// to the desired element. - /// - /// No other assumptions should be made on the ordering of the - /// elements after this computation. - /// - /// Complexity ([quickselect](https://en.wikipedia.org/wiki/Quickselect)): - /// - average case: O(`n`); - /// - worst case: O(`n`^2); - /// where n is the number of elements in the array. - /// - /// **Panics** if `i` is greater than or equal to `n`. - fn get_from_sorted_mut(&mut self, i: usize) -> A - where - A: Ord + Clone, - S: DataMut; - - /// A bulk version of [`get_from_sorted_mut`], optimized to retrieve multiple - /// indexes at once. - /// It returns an `IndexMap`, with indexes as keys and retrieved elements as - /// values. - /// The `IndexMap` is sorted with respect to indexes in increasing order: - /// this ordering is preserved when you iterate over it (using `iter`/`into_iter`). - /// - /// **Panics** if any element in `indexes` is greater than or equal to `n`, - /// where `n` is the length of the array.. - /// - /// [`get_from_sorted_mut`]: #tymethod.get_from_sorted_mut - fn get_many_from_sorted_mut(&mut self, indexes: &ArrayBase) -> IndexMap - where - A: Ord + Clone, - S: DataMut, - S2: Data; - - /// Partitions the array in increasing order based on the value initially - /// located at `pivot_index` and returns the new index of the value. - /// - /// The elements are rearranged in such a way that the value initially - /// located at `pivot_index` is moved to the position it would be in an - /// array sorted in increasing order. The return value is the new index of - /// the value after rearrangement. All elements smaller than the value are - /// moved to its left and all elements equal or greater than the value are - /// moved to its right. The ordering of the elements in the two partitions - /// is undefined. - /// - /// `self` is shuffled **in place** to operate the desired partition: - /// no copy of the array is allocated. - /// - /// The method uses Hoare's partition algorithm. - /// Complexity: O(`n`), where `n` is the number of elements in the array. - /// Average number of element swaps: n/6 - 1/3 (see - /// [link](https://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto/11550)) - /// - /// **Panics** if `pivot_index` is greater than or equal to `n`. - /// - /// # Example - /// - /// ``` - /// use ndarray::array; - /// use ndarray_stats::Sort1dExt; - /// - /// let mut data = array![3, 1, 4, 5, 2]; - /// let pivot_index = 2; - /// let pivot_value = data[pivot_index]; - /// - /// // Partition by the value located at `pivot_index`. - /// let new_index = data.partition_mut(pivot_index); - /// // The pivot value is now located at `new_index`. - /// assert_eq!(data[new_index], pivot_value); - /// // Elements less than that value are moved to the left. - /// for i in 0..new_index { - /// assert!(data[i] < pivot_value); - /// } - /// // Elements greater than or equal to that value are moved to the right. - /// for i in (new_index + 1)..data.len() { - /// assert!(data[i] >= pivot_value); - /// } - /// ``` - fn partition_mut(&mut self, pivot_index: usize) -> usize - where - A: Ord + Clone, - S: DataMut; - - private_decl! {} -} - -impl Sort1dExt for ArrayBase -where - S: Data, -{ - fn get_from_sorted_mut(&mut self, i: usize) -> A - where - A: Ord + Clone, - S: DataMut, - { - let n = self.len(); - if n == 1 { - self[0].clone() - } else { - let mut rng = thread_rng(); - let pivot_index = rng.gen_range(0..n); - let partition_index = self.partition_mut(pivot_index); - if i < partition_index { - self.slice_axis_mut(Axis(0), Slice::from(..partition_index)) - .get_from_sorted_mut(i) - } else if i == partition_index { - self[i].clone() - } else { - self.slice_axis_mut(Axis(0), Slice::from(partition_index + 1..)) - .get_from_sorted_mut(i - (partition_index + 1)) - } - } - } - - fn get_many_from_sorted_mut(&mut self, indexes: &ArrayBase) -> IndexMap - where - A: Ord + Clone, - S: DataMut, - S2: Data, - { - let mut deduped_indexes: Vec = indexes.to_vec(); - deduped_indexes.sort_unstable(); - deduped_indexes.dedup(); - - get_many_from_sorted_mut_unchecked(self, &deduped_indexes) - } - - fn partition_mut(&mut self, pivot_index: usize) -> usize - where - A: Ord + Clone, - S: DataMut, - { - let pivot_value = self[pivot_index].clone(); - self.swap(pivot_index, 0); - let n = self.len(); - let mut i = 1; - let mut j = n - 1; - loop { - loop { - if i > j { - break; - } - if self[i] >= pivot_value { - break; - } - i += 1; - } - while pivot_value <= self[j] { - if j == 1 { - break; - } - j -= 1; - } - if i >= j { - break; - } else { - self.swap(i, j); - i += 1; - j -= 1; - } - } - self.swap(0, i - 1); - i - 1 - } - - private_impl! {} -} - -/// To retrieve multiple indexes from the sorted array in an optimized fashion, -/// [get_many_from_sorted_mut] first of all sorts and deduplicates the -/// `indexes` vector. -/// -/// `get_many_from_sorted_mut_unchecked` does not perform this sorting and -/// deduplication, assuming that the user has already taken care of it. -/// -/// Useful when you have to call [get_many_from_sorted_mut] multiple times -/// using the same indexes. -/// -/// [get_many_from_sorted_mut]: ../trait.Sort1dExt.html#tymethod.get_many_from_sorted_mut -pub(crate) fn get_many_from_sorted_mut_unchecked( - array: &mut ArrayBase, - indexes: &[usize], -) -> IndexMap -where - A: Ord + Clone, - S: DataMut, -{ - if indexes.is_empty() { - return IndexMap::new(); - } - - // Since `!indexes.is_empty()` and indexes must be in-bounds, `array` must - // be non-empty. - let mut values = vec![array[0].clone(); indexes.len()]; - _get_many_from_sorted_mut_unchecked(array.view_mut(), &mut indexes.to_owned(), &mut values); - - // We convert the vector to a more search-friendly `IndexMap`. - indexes.iter().cloned().zip(values.into_iter()).collect() -} - -/// This is the recursive portion of `get_many_from_sorted_mut_unchecked`. -/// -/// `indexes` is the list of indexes to get. `indexes` is mutable so that it -/// can be used as scratch space for this routine; the value of `indexes` after -/// calling this routine should be ignored. -/// -/// `values` is a pre-allocated slice to use for writing the output. Its -/// initial element values are ignored. -fn _get_many_from_sorted_mut_unchecked( - mut array: ArrayViewMut1<'_, A>, - indexes: &mut [usize], - values: &mut [A], -) where - A: Ord + Clone, -{ - let n = array.len(); - debug_assert!(n >= indexes.len()); // because indexes must be unique and in-bounds - debug_assert_eq!(indexes.len(), values.len()); - - if indexes.is_empty() { - // Nothing to do in this case. - return; - } - - // At this point, `n >= 1` since `indexes.len() >= 1`. - if n == 1 { - // We can only reach this point if `indexes.len() == 1`, so we only - // need to assign the single value, and then we're done. - debug_assert_eq!(indexes.len(), 1); - values[0] = array[0].clone(); - return; - } - - // We pick a random pivot index: the corresponding element is the pivot value - let mut rng = thread_rng(); - let pivot_index = rng.gen_range(0..n); - - // We partition the array with respect to the pivot value. - // The pivot value moves to `array_partition_index`. - // Elements strictly smaller than the pivot value have indexes < `array_partition_index`. - // Elements greater or equal to the pivot value have indexes > `array_partition_index`. - let array_partition_index = array.partition_mut(pivot_index); - - // We use a divide-and-conquer strategy, splitting the indexes we are - // searching for (`indexes`) and the corresponding portions of the output - // slice (`values`) into pieces with respect to `array_partition_index`. - let (found_exact, index_split) = match indexes.binary_search(&array_partition_index) { - Ok(index) => (true, index), - Err(index) => (false, index), - }; - let (smaller_indexes, other_indexes) = indexes.split_at_mut(index_split); - let (smaller_values, other_values) = values.split_at_mut(index_split); - let (bigger_indexes, bigger_values) = if found_exact { - other_values[0] = array[array_partition_index].clone(); // Write exactly found value. - (&mut other_indexes[1..], &mut other_values[1..]) - } else { - (other_indexes, other_values) - }; - - // We search recursively for the values corresponding to strictly smaller - // indexes to the left of `partition_index`. - _get_many_from_sorted_mut_unchecked( - array.slice_axis_mut(Axis(0), Slice::from(..array_partition_index)), - smaller_indexes, - smaller_values, - ); - - // We search recursively for the values corresponding to strictly bigger - // indexes to the right of `partition_index`. Since only the right portion - // of the array is passed in, the indexes need to be shifted by length of - // the removed portion. - bigger_indexes - .iter_mut() - .for_each(|x| *x -= array_partition_index + 1); - _get_many_from_sorted_mut_unchecked( - array.slice_axis_mut(Axis(0), Slice::from(array_partition_index + 1..)), - bigger_indexes, - bigger_values, - ); -} diff --git a/src/summary_statistics/mod.rs b/src/summary_statistics/mod.rs index 1f8fe000..6cac75df 100644 --- a/src/summary_statistics/mod.rs +++ b/src/summary_statistics/mod.rs @@ -48,7 +48,7 @@ where /// * `MultiInputError::EmptyInput` if `self` is empty /// * `MultiInputError::ShapeMismatch` if `self` and `weights` don't have the same shape /// - /// [`arithmetic weighted mean`] https://en.wikipedia.org/wiki/Weighted_arithmetic_mean + /// [`arithmetic weighted mean`]: https://en.wikipedia.org/wiki/Weighted_arithmetic_mean fn weighted_mean(&self, weights: &Self) -> Result where A: Copy + Div + Mul + Zero; @@ -89,7 +89,7 @@ where /// * `MultiInputError::EmptyInput` if `self` is empty /// * `MultiInputError::ShapeMismatch` if `self` length along axis is not equal to `weights` length /// - /// [`arithmetic weighted mean`] https://en.wikipedia.org/wiki/Weighted_arithmetic_mean + /// [`arithmetic weighted mean`]: https://en.wikipedia.org/wiki/Weighted_arithmetic_mean fn weighted_mean_axis( &self, axis: Axis, diff --git a/tests/quantile.rs b/tests/quantile.rs index 9d58071f..bd418a3c 100644 --- a/tests/quantile.rs +++ b/tests/quantile.rs @@ -168,6 +168,45 @@ fn test_max_skipnan_all_nan() { assert!(a.max_skipnan().is_nan()); } +#[test] +fn test_quantile_mut_with_large_array_of_equal_floats() { + let mut array: Array1 = Array1::ones(10_000_000); + array.quantile_mut(n64(0.5), &Linear).unwrap(); +} + +#[test] +fn test_quantile_mut_with_large_array_of_sorted_floats() { + let mut array: Array1 = Array1::range(n64(0.0), n64(1e7), n64(1.0)); + array.quantile_mut(n64(0.5), &Linear).unwrap(); +} + +#[test] +fn test_quantile_mut_with_large_array_of_rev_sorted_floats() { + let mut array: Array1 = Array1::range(n64(1e7), n64(0.0), n64(-1.0)); + array.quantile_mut(n64(0.5), &Linear).unwrap(); +} + +#[test] +fn test_quantiles_mut_with_large_array_of_equal_floats() { + let mut array: Array1 = Array1::ones(10_000_000); + let quantiles: Array1 = Array1::range(n64(0.0), n64(1.0), n64(1e-5)); + array.quantiles_mut(&quantiles, &Linear).unwrap(); +} + +#[test] +fn test_quantiles_mut_with_large_array_of_sorted_floats() { + let mut array: Array1 = Array1::range(n64(0.0), n64(1e7), n64(1.0)); + let quantiles: Array1 = Array1::range(n64(0.0), n64(1.0), n64(1e-5)); + array.quantiles_mut(&quantiles, &Linear).unwrap(); +} + +#[test] +fn test_quantiles_mut_with_large_array_of_rev_sorted_floats() { + let mut array: Array1 = Array1::range(n64(1e7), n64(0.0), n64(-1.0)); + let quantiles: Array1 = Array1::range(n64(0.0), n64(1.0), n64(1e-5)); + array.quantiles_mut(&quantiles, &Linear).unwrap(); +} + #[test] fn test_quantile_axis_mut_with_odd_axis_length() { let mut a = arr2(&[[1, 3, 2, 10], [2, 4, 3, 11], [3, 5, 6, 12]]); diff --git a/tests/sort.rs b/tests/sort.rs index b2bd12f1..d46475e1 100644 --- a/tests/sort.rs +++ b/tests/sort.rs @@ -1,45 +1,16 @@ use ndarray::prelude::*; -use ndarray_stats::Sort1dExt; +use ndarray_slice::Slice1Ext; use quickcheck_macros::quickcheck; - -#[test] -fn test_partition_mut() { - let mut l = vec![ - arr1(&[1, 1, 1, 1, 1]), - arr1(&[1, 3, 2, 10, 10]), - arr1(&[2, 3, 4, 1]), - arr1(&[ - 355, 453, 452, 391, 289, 343, 44, 154, 271, 44, 314, 276, 160, 469, 191, 138, 163, 308, - 395, 3, 416, 391, 210, 354, 200, - ]), - arr1(&[ - 84, 192, 216, 159, 89, 296, 35, 213, 456, 278, 98, 52, 308, 418, 329, 173, 286, 106, - 366, 129, 125, 450, 23, 463, 151, - ]), - ]; - for a in l.iter_mut() { - let n = a.len(); - let pivot_index = n - 1; - let pivot_value = a[pivot_index].clone(); - let partition_index = a.partition_mut(pivot_index); - for i in 0..partition_index { - assert!(a[i] < pivot_value); - } - assert_eq!(a[partition_index], pivot_value); - for j in (partition_index + 1)..n { - assert!(pivot_value <= a[j]); - } - } -} +use std::collections::HashMap; #[test] fn test_sorted_get_mut() { let a = arr1(&[1, 3, 2, 10]); - let j = a.clone().view_mut().get_from_sorted_mut(2); + let j = *a.clone().view_mut().select_nth_unstable(2).1; assert_eq!(j, 3); - let j = a.clone().view_mut().get_from_sorted_mut(1); + let j = *a.clone().view_mut().select_nth_unstable(1).1; assert_eq!(j, 2); - let j = a.clone().view_mut().get_from_sorted_mut(3); + let j = *a.clone().view_mut().select_nth_unstable(3).1; assert_eq!(j, 10); } @@ -54,21 +25,16 @@ fn test_sorted_get_many_mut(mut xs: Vec) -> bool { // Insert each index twice, to get a set of indexes with duplicates, not sorted let mut indexes: Vec = (0..n).into_iter().collect(); indexes.append(&mut (0..n).collect()); + let mut indexes = Array::from(indexes); + indexes.sort_unstable(); + let (indexes, _duplicates) = indexes.partition_dedup(); - let mut sorted_v = Vec::with_capacity(n); - for (i, (key, value)) in v - .get_many_from_sorted_mut(&Array::from(indexes)) - .into_iter() - .enumerate() - { - if i != key { - return false; - } - sorted_v.push(value); - } + let mut map = HashMap::new(); + v.select_many_nth_unstable(&indexes, &mut map); + let sorted_v = indexes.map(|index| *map[index]); xs.sort(); println!("Sorted: {:?}. Truth: {:?}", sorted_v, xs); - xs == sorted_v + Array::from_vec(xs) == sorted_v } } @@ -79,7 +45,7 @@ fn test_sorted_get_mut_as_sorting_algorithm(mut xs: Vec) -> bool { true } else { let mut v = Array::from(xs.clone()); - let sorted_v: Vec<_> = (0..n).map(|i| v.get_from_sorted_mut(i)).collect(); + let sorted_v: Vec<_> = (0..n).map(|i| *v.select_nth_unstable(i).1).collect(); xs.sort(); xs == sorted_v }