diff --git a/polars/Cargo.toml b/polars/Cargo.toml index 99d4b8d8929c..a60afad9d290 100644 --- a/polars/Cargo.toml +++ b/polars/Cargo.toml @@ -12,7 +12,7 @@ description = "DataFrame Library based on Apache Arrow" [features] sql = ["polars-sql"] rows = ["polars-core/rows"] -simd = ["polars-core/simd", "polars-io/simd"] +simd = ["polars-core/simd", "polars-io/simd", "polars-ops/simd"] avx512 = ["polars-core/avx512"] nightly = ["polars-core/nightly", "polars-ops/nightly", "simd"] docs = ["polars-core/docs"] diff --git a/polars/polars-ops/Cargo.toml b/polars/polars-ops/Cargo.toml index 57c2800998e4..f50569f11086 100644 --- a/polars/polars-ops/Cargo.toml +++ b/polars/polars-ops/Cargo.toml @@ -10,6 +10,7 @@ description = "More operations on polars data structures" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +argminmax = { version = "0.6.1", default-features = false, features = ["float"] } arrow.workspace = true base64 = { version = "0.21", optional = true } either.workspace = true @@ -24,6 +25,7 @@ serde_json = { version = "1", optional = true } smartstring.workspace = true [features] +simd = ["argminmax/nightly_simd"] nightly = ["polars-utils/nightly"] dtype-categorical = ["polars-core/dtype-categorical"] dtype-date = ["polars-core/dtype-date", "polars-core/temporal"] diff --git a/polars/polars-ops/src/series/ops/arg_min_max.rs b/polars/polars-ops/src/series/ops/arg_min_max.rs index 26f2e7e72cae..172a63533cd1 100644 --- a/polars/polars-ops/src/series/ops/arg_min_max.rs +++ b/polars/polars-ops/src/series/ops/arg_min_max.rs @@ -1,3 +1,5 @@ +use argminmax::ArgMinMax; +use arrow::array::Array; use arrow::bitmap::utils::{BitChunkIterExact, BitChunksExact}; use arrow::bitmap::Bitmap; use polars_core::series::IsSorted; @@ -20,7 +22,7 @@ impl ArgAgg for Series { match s.dtype() { Utf8 => { let ca = s.utf8().unwrap(); - arg_min(ca) + arg_min_str(ca) } Boolean => { let ca = s.bool().unwrap(); @@ -29,10 +31,12 @@ impl ArgAgg for Series { dt if dt.is_numeric() => { with_match_physical_numeric_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); - if let Ok(vals) = ca.cont_slice() { - arg_min_slice(vals, ca.is_sorted_flag2()) + if ca.is_empty() { // because argminmax assumes not empty + None + } else if let Ok(vals) = ca.cont_slice() { + arg_min_numeric_slice(vals, ca.is_sorted_flag2()) } else { - arg_min(ca) + arg_min_numeric(ca) } }) } @@ -46,7 +50,7 @@ impl ArgAgg for Series { match s.dtype() { Utf8 => { let ca = s.utf8().unwrap(); - arg_max(ca) + arg_max_str(ca) } Boolean => { let ca = s.bool().unwrap(); @@ -55,10 +59,12 @@ impl ArgAgg for Series { dt if dt.is_numeric() => { with_match_physical_numeric_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); - if let Ok(vals) = ca.cont_slice() { - arg_max_slice(vals, ca.is_sorted_flag2()) + if ca.is_empty() { // because argminmax assumes not empty + None + } else if let Ok(vals) = ca.cont_slice() { + arg_max_numeric_slice(vals, ca.is_sorted_flag2()) } else { - arg_max(ca) + arg_max_numeric(ca) } }) } @@ -67,7 +73,7 @@ impl ArgAgg for Series { } } -fn arg_max_bool(ca: &BooleanChunked) -> Option { +pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option { if ca.is_empty() { None } else if ca.null_count() == ca.len() { @@ -188,12 +194,7 @@ fn first_unset_bit(mask: &Bitmap) -> usize { } } -fn arg_min<'a, T>(ca: &'a ChunkedArray) -> Option -where - T: PolarsDataType, - &'a ChunkedArray: IntoIterator, - <&'a ChunkedArray as IntoIterator>::Item: PartialOrd, -{ +fn arg_min_str(ca: &Utf8Chunked) -> Option { match ca.is_sorted_flag2() { IsSorted::Ascending => Some(0), IsSorted::Descending => Some(ca.len() - 1), @@ -205,12 +206,7 @@ where } } -pub(crate) fn arg_max<'a, T>(ca: &'a ChunkedArray) -> Option -where - T: PolarsDataType, - &'a ChunkedArray: IntoIterator, - <&'a ChunkedArray as IntoIterator>::Item: PartialOrd, -{ +fn arg_max_str(ca: &Utf8Chunked) -> Option { match ca.is_sorted_flag2() { IsSorted::Ascending => Some(ca.len() - 1), IsSorted::Descending => Some(0), @@ -222,32 +218,134 @@ where } } -fn arg_min_slice(vals: &[T], is_sorted: IsSorted) -> Option +fn arg_min_numeric<'a, T>(ca: &'a ChunkedArray) -> Option +where + T: PolarsNumericType, + for<'b> &'b [T::Native]: ArgMinMax, +{ + match ca.is_sorted_flag2() { + IsSorted::Ascending => Some(0), + IsSorted::Descending => Some(ca.len() - 1), + IsSorted::Not => { + ca.downcast_iter() + .fold((None, None, 0), |acc, arr| { + if arr.len() == 0 { + return acc; + } + let chunk_min_idx: Option; + let chunk_min_val: Option; + if arr.null_count() > 0 { + // When there are nulls, we should compare Option + chunk_min_val = None; // because None < Some(_) + chunk_min_idx = arr + .into_iter() + .enumerate() + .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc }) + .map(|tpl| tpl.0); + } else { + // When no nulls & array not empty => we can use fast argminmax + let min_idx: usize = arr.values().as_slice().argmin(); + chunk_min_idx = Some(min_idx); + chunk_min_val = Some(arr.value(min_idx)); + } + + let new_offset: usize = acc.2 + arr.len(); + match acc { + (Some(_), Some(_), offset) => { + if chunk_min_val < acc.1 { + match chunk_min_idx { + Some(idx) => (Some(idx + offset), chunk_min_val, new_offset), + None => (acc.0, acc.1, new_offset), + } + } else { + (acc.0, acc.1, new_offset) + } + } + (None, None, offset) => match chunk_min_idx { + Some(idx) => (Some(idx + offset), chunk_min_val, new_offset), + None => (None, None, new_offset), + }, + _ => unreachable!(), + } + }) + .0 + } + } +} + +fn arg_max_numeric<'a, T>(ca: &'a ChunkedArray) -> Option +where + T: PolarsNumericType, + for<'b> &'b [T::Native]: ArgMinMax, +{ + match ca.is_sorted_flag2() { + IsSorted::Ascending => Some(ca.len() - 1), + IsSorted::Descending => Some(0), + IsSorted::Not => { + ca.downcast_iter() + .fold((None, None, 0), |acc, arr| { + if arr.len() == 0 { + return acc; + } + let chunk_max_idx: Option; + let chunk_max_val: Option; + if arr.null_count() > 0 { + // When there are nulls, we should compare Option + chunk_max_idx = arr + .into_iter() + .enumerate() + .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc }) + .map(|tpl| tpl.0); + chunk_max_val = chunk_max_idx.map(|idx| arr.value(idx)); + } else { + // When no nulls & array not empty => we can use fast argminmax + let max_idx: usize = arr.values().as_slice().argmax(); + chunk_max_idx = Some(max_idx); + chunk_max_val = Some(arr.value(max_idx)); + } + + let new_offset: usize = acc.2 + arr.len(); + match acc { + (Some(_), Some(_), offset) => { + if chunk_max_val > acc.1 { + match chunk_max_idx { + Some(idx) => (Some(idx + offset), chunk_max_val, new_offset), + _ => unreachable!(), // because None < Some(_) + } + } else { + (acc.0, acc.1, new_offset) + } + } + (None, None, offset) => match chunk_max_idx { + Some(idx) => (Some(idx + offset), chunk_max_val, new_offset), + None => (None, None, new_offset), + }, + _ => unreachable!(), + } + }) + .0 + } + } +} + +fn arg_min_numeric_slice(vals: &[T], is_sorted: IsSorted) -> Option where - T: PartialOrd, + for<'a> &'a [T]: ArgMinMax, { match is_sorted { IsSorted::Ascending => Some(0), IsSorted::Descending => Some(vals.len() - 1), - IsSorted::Not => vals - .iter() - .enumerate() - .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc }) - .map(|tpl| tpl.0), + IsSorted::Not => Some(vals.argmin()), // assumes not empty } } -fn arg_max_slice(vals: &[T], is_sorted: IsSorted) -> Option +fn arg_max_numeric_slice(vals: &[T], is_sorted: IsSorted) -> Option where - T: PartialOrd, + for<'a> &'a [T]: ArgMinMax, { match is_sorted { IsSorted::Ascending => Some(vals.len() - 1), IsSorted::Descending => Some(0), - IsSorted::Not => vals - .iter() - .enumerate() - .reduce(|acc, (idx, val)| if acc.1 < val { (idx, val) } else { acc }) - .map(|tpl| tpl.0), + IsSorted::Not => Some(vals.argmax()), // assumes not empty } } diff --git a/polars/polars-ops/src/series/ops/is_first.rs b/polars/polars-ops/src/series/ops/is_first.rs index 39defb741710..f0c69d2ba80b 100644 --- a/polars/polars-ops/src/series/ops/is_first.rs +++ b/polars/polars-ops/src/series/ops/is_first.rs @@ -6,7 +6,7 @@ use polars_arrow::utils::CustomIterTools; use polars_core::prelude::*; use polars_core::with_match_physical_integer_polars_type; -use crate::series::ops::arg_min_max::arg_max; +use crate::series::ops::arg_min_max::arg_max_bool; fn is_first_numeric(ca: &ChunkedArray) -> BooleanChunked where @@ -47,7 +47,7 @@ fn is_first_bin(ca: &BinaryChunked) -> BooleanChunked { fn is_first_boolean(ca: &BooleanChunked) -> BooleanChunked { let mut out = MutableBitmap::with_capacity(ca.len()); out.extend_constant(ca.len(), false); - if let Some(index) = arg_max(ca) { + if let Some(index) = arg_max_bool(ca) { out.set(index, true) } if let Some(index) = ca.first_non_null() { diff --git a/py-polars/Cargo.lock b/py-polars/Cargo.lock index 99d813750deb..bc300466bd24 100644 --- a/py-polars/Cargo.lock +++ b/py-polars/Cargo.lock @@ -59,6 +59,15 @@ dependencies = [ "libc", ] +[[package]] +name = "argminmax" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "202108b46429b765ef483f8a24d5c46f48c14acfdacc086dd4ab6dddf6bcdbd2" +dependencies = [ + "num-traits", +] + [[package]] name = "array-init-cursor" version = "0.2.0" @@ -1614,6 +1623,7 @@ dependencies = [ name = "polars-ops" version = "0.28.0" dependencies = [ + "argminmax", "arrow2", "base64", "either", diff --git a/py-polars/tests/unit/test_series.py b/py-polars/tests/unit/test_series.py index 2044a346427e..96260022d898 100644 --- a/py-polars/tests/unit/test_series.py +++ b/py-polars/tests/unit/test_series.py @@ -1803,6 +1803,11 @@ def test_arg_min_and_arg_max() -> None: assert s.arg_min() == 4 assert s.arg_max() == 0 + # test empty series + s = pl.Series("a", []) + assert s.arg_min() is None + assert s.arg_max() is None + def test_is_null_is_not_null() -> None: s = pl.Series("a", [1.0, 2.0, 3.0, None])