-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
perf(rust, python): SIMD accelerated arg_min
/arg_max
(via argminmax
)
#8074
Conversation
Thanks @jvdd that are some impressive benchmarks. 🙌 I have left some comments. |
Thank you @ritchie46! It seems as though there are no comments over here 😅 |
Whoops. I forgot the hit the |
chunk_min_val = None; // because None < Some(_) | ||
chunk_min_idx = arr | ||
.into_iter() | ||
.enumerate() | ||
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc }) | ||
.map(|tpl| tpl.0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Basically we should return the index of the first null value
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, we can speed this up. I can do this later. 👍
If you can point me to where add some more extensive tests (either in Python or Rust) - I’ll gladly include some in this PR :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is almost there @jvdd. However we also need to compile for the stable rust compiler, that is the reason CI is failing here.
Polars has a simd
feature gate that is activated on the nightly
compiler. I think we must feature gate this functionality under the simd
feature and use the previous slower implementation on stable
.
chunk_min_val = None; // because None < Some(_) | ||
chunk_min_idx = arr | ||
.into_iter() | ||
.enumerate() | ||
.reduce(|acc, (idx, val)| if acc.1 > val { (idx, val) } else { acc }) | ||
.map(|tpl| tpl.0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, we can speed this up. I can do this later. 👍
I've added With this feature, the unstable |
That's great. And on |
Yes - if no SIMD CPU instructions are available on the hardware it defaults to a scalar implementation. |
Great! Thanks a lot @jvdd. Great PR. 🙌
This is indeed an operation that's common enough to add this. I think we could have an expression In any case, this one is going in! :) |
arg_min
/arg_max
(via argminmax
)
This PR uses
argminmax
to provide fasterarg_min
andarg_max
operations.Motivation
After playing around with writing a custom trait for
argmin
&argmax
I observed considerable speed ups when using the argminmax crate over polars. Eventually, I ended up diving in the polars codebase and integrated argminmax - which resulted in this PR.I conducted a very basic benchmark and found:
The benchmark code + results are included at the bottom.
Changes made
(I'll discuss the changes for
arg_min
, those forarg_max
are identical)Minor changes were made to the previous code, where the
.iter().reduce()
implementation is now only used forUtf8Chunked
. Since this implementation is only used for string data, I renamed thearg_min
toarg_min_str
.Updated code (using argminmax):
arg_min_slice
now callsvals.argmin()
arg_min_numeric_slice
(since it is only called on numeric dtypes).values().as_slice()
is called on the chunks)arg_min
->arg_min_numeric
(since it is only called on numeric dtypes)A big red flag in the current code is that
argminmax
assumes non-empty slices, resulting in adding.is_empty
checks for each chunk or slice. 🤔TODOs
Potential improvements / considerations
I think the following three points are worth considering
polars
ignore nans forarg_min
andarg_max
which complies with how theArgMinMax
trait handles this.float16
is not supported bypolars
?argminmax
can extract both values in approximately the same time required to extract either one.Benchmarks
Code
main.rs
⬇️cargo.toml
⬇️Results
polars = { git = "https://github.com/jvdd/polars", branch = "argminmax_integration" }
⬇️polars = { version = "0.28.0", features = ["simd"] }
⬇️