docs: initial prototype of exporting Lux models to Jax #1088
Conversation
Force-pushed from 46fb13d to 190970c
Force-pushed from 190970c to 7798aec
Tested locally that this works (needs a custom build of EnzymeJAX).
cc @wsmoses if you have any suggestions.
docs/src/manual/exporting_to_jax.md
Outdated
# Note that all the inputs must be transposed, i.e. if the Lux model has an input of shape
Incidentally, we can probably add a flag to compile/jit/code_hlo to not transpose the inputs.
But maybe regardless it's worth explaining why: Julia defaults to column-major while JAX defaults to row-major.
> Incidentally, we can probably add a flag to compile/jit/code_hlo to not transpose the inputs.
I feel this will be very confusing for people who use Python exclusively. For people who use both Julia and Python, switching is not that hard when converting the model.
I think we can add a function that serializes all the inputs from Julia together with the MLIR code, and a Python function to deserialize them (see the sketch below). That would also be handy for pre-trained models and such.
> But maybe regardless it's worth explaining why: Julia defaults to column-major while JAX defaults to row-major.
Agreed
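For readers skimming the thread, here is a minimal sketch of what that layout difference means in practice, seen from the NumPy/JAX side (the array and shapes are illustrative, taken from the note in the diff):

```python
import numpy as np
import jax.numpy as jnp

# The Lux model's input shape, in Julia's (width, height, channels, batch)
# convention:
x = np.random.randn(28, 28, 1, 4).astype(np.float32)

# Julia is column-major and NumPy/JAX are row-major, so the exported function
# expects the same data with the axes in reverse order:
x_for_exported_model = jnp.transpose(x, (3, 2, 1, 0))
assert x_for_exported_model.shape == (4, 1, 28, 28)
```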
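As for the serialize/deserialize idea above, here is a minimal sketch of the Python half, assuming a deliberately simple, hypothetical on-disk format (none of this is an existing Lux API): the dimension count as int64, each dimension size as int64, then raw float32 data in Julia's native column-major order.

```python
import numpy as np

def load_julia_array(path):
    # Hypothetical deserializer for a format a Julia-side helper could write:
    # ndims as int64, each dimension size as int64, then float32 data in
    # Julia's native column-major order.
    with open(path, "rb") as f:
        ndim = int(np.fromfile(f, dtype=np.int64, count=1)[0])
        dims = np.fromfile(f, dtype=np.int64, count=ndim)
        data = np.fromfile(f, dtype=np.float32)
    # Reading column-major bytes row-major reverses the axes, which is
    # exactly the transposed layout the exported model expects.
    return data.reshape(tuple(int(d) for d in reversed(dims)))
```

On the Julia side, the matching writer would simply `write` the dimension count, the sizes, and the raw buffer to the same file.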
# Note that all the inputs must be transposed, i.e. if the Lux model has an input of shape
# (28, 28, 1, 4), then the input to the exported Lux model must be of shape (4, 1, 28, 28)
# Input as defined in our exported Lux model
Extra fun fact: hypothetically, if the model weights themselves weren't traced, it should output MLIR containing the weights, and thus one can export inference of a pre-trained model too!
This is actually a neat way to do it, but it will need some additional plumbing on the Lux end. Currently, if LuxLib sees a mismatch in devices (CPUDevice vs ReactantDevice) it throws an error. This is very handy when users forget to move something to the GPU; previously they would hit a cryptic error deep in the stack.
With Reactant it might be worth reworking some of those things.
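The untraced-weights point is easy to see in JAX itself, as an analogy (this is plain JAX AOT lowering, not the Reactant mechanism): weights that are closed over, rather than passed as traced arguments, get folded into the lowered StableHLO module as constants.

```python
import jax
import jax.numpy as jnp

W = jnp.arange(6.0).reshape(2, 3)  # stand-in for pre-trained weights

def apply(x):
    # W is captured by closure rather than traced, so lowering bakes it
    # into the module as a constant.
    return x @ W

lowered = jax.jit(apply).lower(jnp.ones((4, 2)))
print(lowered.as_text())  # the printed module embeds W as a stablehlo constant
```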
Codecov Report
All modified and coverable lines are covered by tests ✅

```
@@            Coverage Diff            @@
##             main    #1088       +/-   ##
===========================================
- Coverage   83.79%   37.18%   -46.62%
===========================================
  Files         147       61       -86
  Lines        6036     2883     -3153
===========================================
- Hits         5058     1072     -3986
- Misses        978     1811      +833
```

☔ View full report in Codecov by Sentry.
fixes #453