Allow Numba to perform polymorphic dispatching #230
Comments
Thanks for pointing this out. I initially started with numba's polymorphic dispatching but moved to strict types later on. Strict typing has advantages, such as ease of debugging. For some users (like myself), the JIT compilation overhead also adds up over multiple Python sessions and can be frustrating when experimenting in an interactive context. For performance, my inclination would be to export multiple versions of each numba function, as in the standalone ahead-of-time compilation example linked here, and then infer types from the source data: https://numba.readthedocs.io/en/stable/user/pycc.html#standalone-example
@mdbartos I'm not sure if you noticed the note in the documentation about the pending deprecation of AOT compilation. I can understand the ease of debugging. However, the inputs to most of these functions are already strictly typed, homogeneous arrays. The main request here is to allow using the smaller native dtypes of the source data. The compiled functions can still be cached, so the comment about compilation overhead adding up doesn't really make sense to me.
Hi @groutr, in my experience numba still recompiles the function each time unless types are specified, even when caching is enabled. Try running the following code sample, then resetting the kernel and trying again. The untyped one incurs compilation overhead on subsequent runs while the typed one does not:

from numba import njit
import numpy as np
from numba.types import float64

# Explicit signature: numba compiles this version at definition time.
@njit(float64(float64[:]), cache=True)
def norm_squared_typed(vec):
    n = len(vec)
    result = 0.
    for i in range(n):
        result += vec[i]**2
    return result

# No signature: numba compiles lazily, on the first call for each argument type.
@njit(cache=True)
def norm_squared_untyped(vec):
    n = len(vec)
    result = 0.
    for i in range(n):
        result += vec[i]**2
    return result

vec = np.arange(10, dtype=np.float64)

# %time is an IPython/Jupyter magic; run these two lines in a notebook cell.
%time norm_squared_typed(vec)
%time norm_squared_untyped(vec)
@mdbartos I think there is a bit of a misunderstanding about what numba caching is actually doing. Your benchmark isn't measuring what you think it's measuring. After playing around with your example, what I think is happening is this: when you reset the kernel, both functions get recompiled. One function just gets compiled much earlier, when you define it, which makes it appear faster on the first call than the lazily compiled function. In reality, they take the same time to complete.
Thanks @groutr for the explanation. I re-tested with different breakpoints for timing and that appears to be correct. I would be open to removing strict typing, but it would have to be done carefully, maybe as part of a longer-term release.
Strictly typing each function makes numba unable to compile specialized functions for smaller types. This negatively impacts both performance and memory usage of the numba code.
An example of this is _par_get_candidates_numba, called in the resolve_flats function. I have a fairly large test-case DEM of shape (11431, 11292). Reading from the raster produces an array of float32 values (the native dtype of the DEM). In the resolve_flats function, we construct an array called insides, which is an array of indexes that naturally fit in the int32 datatype. Because of the typing of _par_get_candidates_numba, numba only accepts arrays of float64 and int64, doubling the amount of memory for each array.
Removing the type annotations allows numba to be smarter about how the function is compiled. It can specialize a (float32, int32) version that runs in about half the time of the (float64, int64) version and uses half the memory. This translates to being able to process larger DEMs more quickly.
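To put rough numbers on the memory claim, here is a small sketch that computes the footprint of a single DEM-sized array at each dtype without allocating it. The shape is the one quoted above; the helper `array_bytes` is a hypothetical name introduced only for this illustration and is not part of the library.

```python
import numpy as np

# DEM shape quoted in the issue text.
rows, cols = 11431, 11292


def array_bytes(shape, dtype):
    """Bytes needed for a dense array of this shape and dtype (hypothetical helper)."""
    n_elements = 1
    for dim in shape:
        n_elements *= dim
    return n_elements * np.dtype(dtype).itemsize


dem_f32 = array_bytes((rows, cols), np.float32)
dem_f64 = array_bytes((rows, cols), np.float64)

for name, nbytes in (("float32", dem_f32), ("float64", dem_f64)):
    print(f"{name}: {nbytes / 2**30:.2f} GiB")

# The same 2x ratio holds for index arrays stored as int32 versus int64.
index_ratio = np.dtype(np.int64).itemsize // np.dtype(np.int32).itemsize
```

This only counts one array per dtype; any intermediate arrays upcast to float64 or int64 inside the typed functions would add to the total.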
You can read more about numba's dispatching here: https://numba.readthedocs.io/en/stable/developer/dispatching.html