-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simplifying lowering and abstract evaluation rules #89
Conversation
else: | ||
return source.update( | ||
shape=source.shape[:2] + (points[0].shape[-1],), dtype=source_dtype | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is an unrelated change that I noticed. We're no longer supposed to manually construct a ShapedArray
here because it should return the same type as the input (tracer vs concrete value).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like a nice simplification! If this requires a recent jax, do we need to set a minimum version in the dependencies?
@lgarrison — Good point! It looks like the oldest version of JAX that this supports is 0.4.20 from Nov 2023. This would make us somewhat bleeding edge, but I'm not too concerned about that. What do you think? |
I think that sounds fine, we already have a release that works with older JAX on PyPI. LGTM! |
With relatively recent versions of JAX we can remove some of the boilerplate in the MLIR lowering rules.