-
Notifications
You must be signed in to change notification settings - Fork 69
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for Python array API standard #197
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is going to take a bit to go through and I'll be out next week. Looks a truly fantastic PR!
Looks like plenty of errors for Py3.6 where |
…(since earlier versions don't have importlib.metadata)
Thanks Daniel! The errors for Python <= 3.7 should be fixed. I've set it to bypass array API backend discovery altogether in such cases. Numpy doesn't support <= 3.7 as-is, so I don't think this will be missed. I've fixed up the formatting too. |
|
||
def build_expression(_, expr): # pragma: no cover | ||
"""Build an array API function based on ``arrays`` and ``expr``.""" | ||
|
||
def array_api_contract(*arrays): | ||
return expr._contract([make_to_array_function(array_api)(x) for x in arrays], backend=array_api.__name__) | ||
return expr._contract([_to_array_api(x) for x in arrays], backend=array_api.__name__) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jcmgray How about this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies for the delay here, this appears quite polished. From your view, what needs to happen to get this in?
from ..sharing import to_backend_cache_wrap | ||
|
||
|
||
def discover_array_apis(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Possible to type this function?
@@ -122,6 +123,7 @@ def has_tensordot(backend: str) -> bool: | |||
"cupy": _cupy.build_expression, | |||
"torch": _torch.build_expression, | |||
"jax": _jax.build_expression, | |||
**_array_api.build_expression, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any way to check if Jax
adds an array interface so that we can catch it and test before hand?
return [ep.load() for ep in eps] | ||
else: | ||
# importlib.metadata was introduced in Python 3.8, so it isn't available here. Unable to discover any array APIs. | ||
return [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Numpy does not officially support below 3.8: https://numpy.org/neps/nep-0029-deprecation_policy.html
It would be worth considering dropping Python 3.7 as well. @jcmgray what do you think?
@IsaacBreen I'm looking to push out another release in the next month or so. Is this a PR which you would like to support? |
Description
This PR adds support for libraries that conform to the Python array API standard. It discovers complying modules upon import of opt_einsum.
Eventually (and hopefully soon), once the array API standard becomes ubiquitous, this should largely eliminate the need for backend-specific code.
Note that to test this you'll need a version of an array library which implements the API. Right now AFAIK closest thing to this is latest (dev) Numpy which you can install using
Todos
Status