
perf(rust, python): SIMD accelerated arg_min/arg_max (via argminmax) #8074

Merged
merged 9 commits into pola-rs:main on Apr 14, 2023

Conversation

@jvdd (Contributor) commented Apr 8, 2023

This PR uses argminmax to provide faster arg_min and arg_max operations.

Motivation

After playing around with writing a custom trait for argmin & argmax, I observed considerable speed-ups when using the argminmax crate over polars. Eventually I ended up diving into the polars codebase and integrating argminmax, which resulted in this PR.

I conducted a very basic benchmark and found:

  • a ~4.5x improvement for a single chunk (internally handled as a slice)
  • a ~25x improvement for multiple chunks (6 chunks of 10k i32 values)

The benchmark code + results are included at the bottom.

Changes made

(I'll discuss the changes for arg_min, those for arg_max are identical)

Minor changes were made to the previous code: the .iter().reduce() implementation is now only used for Utf8Chunked. Since that implementation now only handles string data, I renamed arg_min to arg_min_str.
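The retained string path can be sketched roughly as follows (a simplified stand-in for the Utf8Chunked code; the arg_min_str name is from this PR, but the plain-slice signature here is illustrative, not the exact polars internals):

```rust
// Simplified sketch of the iterator-reduce string path. Rust's derived Ord
// for Option treats None as smaller than any Some value, so a null (None)
// entry wins the minimum.
fn arg_min_str(vals: &[Option<&str>]) -> Option<usize> {
    vals.iter()
        .enumerate()
        // keep the index of the smallest value seen so far; `>` keeps the
        // first occurrence on ties
        .reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
        .map(|(idx, _)| idx)
}

fn main() {
    assert_eq!(arg_min_str(&[Some("b"), Some("a"), Some("c")]), Some(1));
    assert_eq!(arg_min_str(&[Some("b"), None, Some("a")]), Some(1)); // first null wins
    assert_eq!(arg_min_str(&[]), None);
}
```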

Updated code (using argminmax):

  • slice: arg_min_slice now calls vals.argmin()
    • renamed this to arg_min_numeric_slice (since it is only called on numeric dtypes)
  • chunked array:
    • calls argminmax on chunks (does not even require the "arrow2" feature as .values().as_slice() is called on the chunks)
    • renamed arg_min -> arg_min_numeric (since it is only called on numeric dtypes)

A big red flag in the current code is that argminmax assumes non-empty slices, which forces us to add .is_empty checks for each chunk or slice. 🤔
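A minimal sketch of the resulting guard pattern (scalar_argmin is a scalar stand-in for argminmax's argmin, which assumes a non-empty input; the names are illustrative):

```rust
// Stand-in for argminmax's `argmin`: assumes a non-empty slice.
fn scalar_argmin(vals: &[i32]) -> usize {
    vals.iter()
        .enumerate()
        // strict `<` keeps the first occurrence of the minimum
        .reduce(|acc, (idx, v)| if v < acc.1 { (idx, v) } else { acc })
        .map(|(idx, _)| idx)
        .expect("non-empty slice")
}

// Every call site must add an explicit emptiness check per chunk/slice.
fn arg_min_checked(vals: &[i32]) -> Option<usize> {
    if vals.is_empty() {
        None
    } else {
        Some(scalar_argmin(vals))
    }
}

fn main() {
    assert_eq!(arg_min_checked(&[]), None);
    assert_eq!(arg_min_checked(&[3, 1, 2, 1]), Some(1));
}
```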

TODOs

  • create a new release of argminmax (that contains the now optimized argmin & argmax functions)
  • add more extensive testing for various edge cases (null chunks, multiple chunks with nulls, nan-handling, ...)

Potential improvements / considerations

I think the following three points are worth considering

  • I could adapt argminmax to return Option instead of usize (as such making it return None for empty arrays)
  • We should make sure that floats are handled identically - I believe polars currently ignores NaNs for arg_min and arg_max, which complies with how the ArgMinMax trait handles this.
    • is it correct that float16 is not supported by polars?
  • It might be worth adding .arg_min_max to get both results in one pass over the data. As this task is mainly bound by CPU memory bandwidth, it is unfortunate that the data needs to be streamed through the CPU twice when both the argmin and argmax values are needed. The benchmark results below demonstrate that argminmax can extract both values in approximately the time required to extract either one.
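For that third point, a scalar (non-SIMD) sketch of what a combined pass could look like, ignoring nulls and chunking for brevity (arg_min_max here is hypothetical, not an existing polars API):

```rust
// One pass over the data yields both extrema indices; argminmax's
// `argminmax()` does the same with SIMD.
fn arg_min_max(vals: &[i32]) -> Option<(usize, usize)> {
    let first = *vals.first()?; // empty input -> None
    let (mut min, mut max) = ((0usize, first), (0usize, first));
    for (idx, &v) in vals.iter().enumerate().skip(1) {
        if v < min.1 {
            min = (idx, v);
        }
        if v > max.1 {
            max = (idx, v);
        }
    }
    Some((min.0, max.0))
}

fn main() {
    assert_eq!(arg_min_max(&[5, -1, 9, 3]), Some((1, 2)));
    assert_eq!(arg_min_max(&[]), None);
}
```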

Benchmarks

Code

main.rs ⬇️

use argminmax::ArgMinMax;
use polars::prelude::*;
use rand::prelude::*;

// import black_box to prevent the compiler from optimizing away the loop
use std::hint::black_box;

// ----------------------------------------------------------------
// basic benchmark
// ----------------------------------------------------------------

fn random_i32_vec(n: usize) -> Vec<i32> {
    let mut rng = rand::thread_rng();
    (0..n).map(|_| rng.gen_range(-100..100)).collect()
}

fn main() {
    // Create an empty polars series
    let s = Series::new("a", Vec::<i32>::new());
    assert_eq!(s.chunks().len(), 1);
    assert_eq!(s.arg_min(), None);
    assert_eq!(s.arg_max(), None);
    // Multiple empty chunks
    let mut s = Series::new("a", Vec::<i32>::new());
    _ = s.append(&Series::new("a", vec![1i32, 2i32]));
    _ = s.append(&Series::new("a", Vec::<i32>::new()));
    _ = s.append(&Series::new("a", vec![1i32, 2i32]));
    _ = s.append(&Series::new("a", Vec::<i32>::new()));
    assert!(s.chunks().len() > 2);
    assert_eq!(s.arg_min(), Some(0));
    assert_eq!(s.arg_max(), Some(1));

    println!("Benchmarking single chunk of 10k elements");

    // Create a polars i32 series of 10k elements (1 chunk)
    let mut v = random_i32_vec(10_000);
    v[376] = 1000;
    v[7853] = -1000;
    let s = Series::new("a", v.clone());
    assert_eq!(s.chunks().len(), 1);

    // polars
    let start = std::time::Instant::now();
    for _ in 0..10 {
        let max_idx = black_box(s.arg_max().unwrap());
        assert_eq!(max_idx, 376);
        let min_idx = black_box(s.arg_min().unwrap());
        assert_eq!(min_idx, 7853);
    }
    let end = std::time::Instant::now();
    println!("polars: {:?}", end - start);

    // argminmax
    let start = std::time::Instant::now();
    for _ in 0..10 {
        let argminmax_idx = black_box(v.argmax());
        assert_eq!(argminmax_idx, 376);
        let argminmax_idx = black_box(v.argmin());
        assert_eq!(argminmax_idx, 7853);
    }
    let end = std::time::Instant::now();
    println!("argminmax: {:?}", end - start);

    let start = std::time::Instant::now();
    for _ in 0..10 {
        let (min_idx, max_idx) = black_box(v.argminmax());
        assert_eq!(min_idx, 7853);
        assert_eq!(max_idx, 376);
    }
    let end = std::time::Instant::now();
    println!("argminmax both in 1 pass: {:?}", end - start);

    println!();
    println!("Benchmarking chunked series of 6 chunks of 10k elements each");

    // Create a polars i32 series that consists of 6 chunks
    let mut v = random_i32_vec(10_000);
    v[376] = 1000;
    v[7853] = -1000;

    let mut s = Series::new("a", v.clone());
    // append 5 chunks
    for _ in 0..5 {
        _ = s.append(&Series::new("a", v.clone()));
    }
    assert_eq!(s.chunks().len(), 6);

    // polars
    let start = std::time::Instant::now();
    for _ in 0..10 {
        let max_idx = black_box(s.arg_max().unwrap());
        assert_eq!(max_idx, 376);
        let min_idx = black_box(s.arg_min().unwrap());
        assert_eq!(min_idx, 7853);
    }
    let end = std::time::Instant::now();
    println!("polars: {:?}", end - start);
}

cargo.toml ⬇️

[package]
name = "polars_argmm"
version = "0.1.0"
edition = "2021"

[dependencies]
polars = { version = "0.28.0", features = ["simd"] }
# comment the above line and uncomment the below line to bench this PR
# polars = { git = "https://github.com/jvdd/polars", branch = "argminmax_integration" }
argminmax = { git = "https://github.com/jvdd/argminmax" }
rand = "0.8.5"
rand_distr = "0.4.3"

Results

polars = { git = "https://github.com/jvdd/polars", branch = "argminmax_integration" } ⬇️

Benchmarking single chunk of 10k elements
polars: 63.706µs
argminmax: 61.45µs
argminmax both in 1 pass: 32.693µs

Benchmarking chunked series of 6 chunks of 10k elements each
polars: 375.618µs

polars = { version = "0.28.0", features = ["simd"] } ⬇️

Benchmarking single chunk of 10k elements
polars: 295.682µs
argminmax: 85.043µs
argminmax both in 1 pass: 41.74µs

Benchmarking chunked series of 6 chunks of 10k elements each
polars: 9.817482ms

@github-actions github-actions bot added the performance Performance issues or improvements label Apr 8, 2023
@ritchie46 (Member) commented:

Thanks @jvdd, those are some impressive benchmarks. 🙌

I have left some comments.

@jvdd (Contributor, author) commented Apr 9, 2023

Thank you @ritchie46! It seems as though there are no comments over here 😅

Resolved review threads: polars/polars-ops/src/series/ops/arg_min_max.rs, polars/polars-ops/Cargo.toml
@ritchie46 (Member) commented:

Thank you @ritchie46! It seems as though there are no comments over here 😅

Whoops. I forgot to hit the finish review button! 🙈

@jvdd jvdd requested a review from ritchie46 April 10, 2023 10:04
Comment on lines +239 to +244
chunk_min_val = None; // because None < Some(_)
chunk_min_idx = arr
.into_iter()
.enumerate()
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc })
.map(|tpl| tpl.0);
@jvdd (Contributor, author):
Basically we should return the index of the first null value
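This works because of Rust's derived Ord for Option, where None is ordered before every Some, so the reduce naturally surfaces the first null:

```rust
// None compares less than any Some value under Option's derived Ord,
// and equal values (including None vs. None) do not replace the
// accumulator, so the first null's index is kept.
fn main() {
    assert!(None::<i32> < Some(i32::MIN));
    assert!(Some(1) < Some(2));
}
```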

@ritchie46 (Member):

Good point, we can speed this up. I can do this later. 👍

@jvdd jvdd requested a review from ritchie46 April 11, 2023 06:32
@jvdd (Contributor, author) commented Apr 11, 2023

If you can point me to where to add some more extensive tests (either in Python or Rust), I'll gladly include some in this PR :)

@ritchie46 (Member) left a review:

I think it is almost there, @jvdd. However, we also need to compile on the stable Rust compiler; that is why CI is failing here.

Polars has a simd feature gate that is activated on the nightly compiler. I think we must feature-gate this functionality under the simd feature and use the previous, slower implementation on stable.
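A hedged sketch of the gating described here (the function name and fallback body are illustrative, mirroring the pre-PR iterator code, not the exact polars source):

```rust
// With the nightly-only `simd` feature: delegate to argminmax.
#[cfg(feature = "simd")]
fn arg_min_numeric_slice(vals: &[i32]) -> Option<usize> {
    use argminmax::ArgMinMax;
    // argminmax assumes a non-empty slice, so guard first
    if vals.is_empty() { None } else { Some(vals.argmin()) }
}

// On stable: keep the previous, slower iterator implementation.
#[cfg(not(feature = "simd"))]
fn arg_min_numeric_slice(vals: &[i32]) -> Option<usize> {
    vals.iter()
        .enumerate()
        .reduce(|acc, (idx, v)| if v < acc.1 { (idx, v) } else { acc })
        .map(|(idx, _)| idx)
}

fn main() {
    assert_eq!(arg_min_numeric_slice(&[2, 1, 3]), Some(1));
    assert_eq!(arg_min_numeric_slice(&[]), None);
}
```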


@jvdd (Contributor, author) commented Apr 14, 2023

I've added "nightly_simd" as a (default) feature to argminmax, since most x86/x86_64 SIMD intrinsics have been stabilized in Rust (all except AVX512 and most of NEON). Leaving that much (very common) compute power (SSE and AVX(2)) unused would have been a pity 😉

This feature gates the unstable AVX512 and NEON SIMD code paths. I've created a new release that includes the feature and updated the code here so it is disabled by default; enabling the polars "simd" feature enables "nightly_simd" as well.
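For a consumer on stable, the dependency would then look something like this (a sketch: the git source is the one used in the benchmark cargo.toml above, and the feature name is as described in this comment):

```toml
[dependencies]
# disable argminmax's default "nightly_simd" feature to build on stable Rust
argminmax = { git = "https://github.com/jvdd/argminmax", default-features = false }
```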

@jvdd jvdd requested a review from ritchie46 April 14, 2023 11:24
@ritchie46 (Member):
With this feature, the unstable AVX512 and NEON SIMD are toggled. I've created a new release that includes this feature and updated the code to disable it by default. When the user enables the polars "simd" feature, this "nightly_simd" feature will be enabled.

That's great. And on stable? It will default to a fallback?

@jvdd (Contributor, author) commented Apr 14, 2023

With this feature, the unstable AVX512 and NEON SIMD are toggled. I've created a new release that includes this feature and updated the code to disable it by default. When the user enables the polars "simd" feature, this "nightly_simd" feature will be enabled.

That's great. And on stable? It will default to a fallback?

Yes - if no SIMD CPU instructions are available on the hardware it defaults to a scalar implementation.

@ritchie46 (Member):
Great!

Thanks a lot @jvdd. Great PR. 🙌

To address some of your comments:

It might be worth considering adding .arg_min_max to get both in one pass over the data. As this task is mainly bound by CPU memory bandwidth, it is unfortunate that the data needs to be passed twice to the CPU when interested in both argmin and argmax values. The benchmark results below demonstrate that argminmax can extract both values in approximately the same time required to extract either one.

This operation is indeed common enough to add. I think we could have an extrema expression that finds both in a single pass.

In any case, this one is going in! :)

@ritchie46 ritchie46 changed the title perf: use argminmax for faster arg_min and arg_max perf(rust, python): SIMD accelerated arg_min/arg_max (via argminmax) Apr 14, 2023
@ritchie46 ritchie46 merged commit 4f90f47 into pola-rs:main Apr 14, 2023
@github-actions github-actions bot added python Related to Python Polars rust Related to Rust Polars labels Apr 14, 2023