Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve compile time of evm.modifier.compose #12

Merged
merged 2 commits into from
Mar 15, 2024
Merged

Conversation

pfackeldey
Copy link
Owner

@pfackeldey pfackeldey commented Mar 15, 2024

This PR replaces most python for-loops by jax.lax.scan loop constructs in evm.modifier.compose. It automatically transforms an array of evm.Modifier(s) into a single evm.Modifier of arrays. This rearrangement allows to use jax.lax.scan over the stacked axis.
Currently it is not clear how to do this for every type of evm.modifier.ModifierBase (e.g. evm.modifier.where, ..., and custom modifiers), however most modifiers in a typical analysis will be of type evm.Modifier and thus will use jax.lax.scan.
The compile time reduction is roughly a factor of 4-5, without any performance losses:

compile time: original
took 87.3687 s
eval time: original
took 0.1036 s
compile time: this PR
took 21.7520 s
eval time: this PR
took 0.1016 s

This benchmark has been performed with 103.000 nuisance parameters, of which:

  • 100.000 are constrained by Poisson
  • 1000 by log normals
  • 1000 by Gaussians
  • 1000 by shape (up/down templates) uncertainties

These number correspond to roughly 10x the amount of modifiers that are used in a typical HEP analysis.

@pfackeldey pfackeldey merged commit 3299378 into main Mar 15, 2024
3 checks passed
@pfackeldey pfackeldey deleted the reduce_compiletime branch March 15, 2024 15:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant