docs: initial prototype of exporting Lux models to Jax #1088

Merged
avik-pal merged 4 commits into main from ap/export_lux_to_jax
Nov 17, 2024

avik-pal
Member

fixes #453

avik-pal force-pushed the ap/export_lux_to_jax branch 2 times, most recently from 46fb13d to 190970c (November 16, 2024 19:29)
avik-pal force-pushed the ap/export_lux_to_jax branch from 190970c to 7798aec (November 16, 2024 20:30)
@avik-pal
Member Author

avik-pal commented Nov 16, 2024

Tested locally and this works (it needs a custom build of EnzymeJAX).

@avik-pal
Member Author

cc @wsmoses if you have any suggestions

)


# Note that all the inputs must be transposed, i.e. if the Lux model has an input of shape
Contributor

Incidentally, we can probably add a flag to compile/jit/code_hlo so that the inputs are not transposed.

Contributor

Regardless, it's probably worth explaining why: Julia arrays are column-major by default, whereas JAX arrays are row-major by default.
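
A minimal sketch of the layout argument in this comment, written in plain Python so it runs without JAX. The helper names (`col_major_offset`, `row_major_offset`) are made up for illustration and are not part of Lux, Reactant, or JAX; the (28, 28, 1, 4) shape is the one used in the docs under review. It checks that a column-major array of shape (28, 28, 1, 4) and a row-major array of shape (4, 1, 28, 28) place every element at the same memory offset, which is why the exported model expects the reversed shape.

```python
from itertools import product

def col_major_offset(index, shape):
    """Linear offset of `index` when the FIRST axis varies fastest (Julia default)."""
    offset, stride = 0, 1
    for i, n in zip(index, shape):
        offset += i * stride
        stride *= n
    return offset

def row_major_offset(index, shape):
    """Linear offset of `index` when the LAST axis varies fastest (JAX/NumPy default)."""
    offset, stride = 0, 1
    for i, n in zip(reversed(index), reversed(shape)):
        offset += i * stride
        stride *= n
    return offset

julia_shape = (28, 28, 1, 4)              # shape as seen by the Lux model
jax_shape = tuple(reversed(julia_shape))  # (4, 1, 28, 28), shape as seen by JAX

# Every element sits at the same memory offset in both views, provided the
# index tuple is reversed along with the shape.
same_layout = all(
    col_major_offset(idx, julia_shape)
    == row_major_offset(tuple(reversed(idx)), jax_shape)
    for idx in product(*(range(n) for n in julia_shape))
)
print(jax_shape, same_layout)  # (4, 1, 28, 28) True
```

In other words, no data has to move: the same buffer is simply described with its dimensions in the opposite order.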

Member Author

> Incidentally, we can probably add a flag to compile/jit/code_hlo so that the inputs are not transposed.

I feel this will be very confusing for people who use Python exclusively. For people who use both (Julia and Python), switching is not that hard when converting the model.

I think we can add a function on the Julia side that serializes all the inputs together with the MLIR code, plus a Python function that deserializes them. That would also be handy for pre-trained models and the like.

> Regardless, it's probably worth explaining why: Julia arrays are column-major by default, whereas JAX arrays are row-major by default.

Agreed
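
A rough sketch of what the serialize/deserialize idea above could look like on the Python side. Everything here is an assumption made for illustration: the JSON layout, the field names (`mlir`, `inputs`, `julia_shape`, `data`), and both function names are hypothetical, and `serialize_bundle` is only a Python stand-in for a writer that would really live in Julia. No such API exists in Lux or Reactant as of this PR.

```python
import json

def serialize_bundle(mlir_code, inputs):
    """Stand-in for the Julia-side writer (hypothetical format).

    `inputs` maps a name to (julia_shape, flat_values), where flat_values lists
    the elements in column-major order, i.e. exactly as Julia stores them.
    """
    payload = {
        "mlir": mlir_code,
        "inputs": {
            name: {"julia_shape": list(shape), "data": list(values)}
            for name, (shape, values) in inputs.items()
        },
    }
    return json.dumps(payload)

def deserialize_bundle(text):
    """Python-side reader: returns the MLIR text and row-major (shape, data) pairs."""
    payload = json.loads(text)
    arrays = {}
    for name, entry in payload["inputs"].items():
        # Reversing the shape reinterprets the same buffer as row-major,
        # so the data itself does not need to be reordered.
        shape = tuple(reversed(entry["julia_shape"]))
        arrays[name] = (shape, entry["data"])
    return payload["mlir"], arrays

# "module { }" is a placeholder, not real exported code.
text = serialize_bundle("module { }", {"x": ((2, 3), [1, 2, 3, 4, 5, 6])})
mlir_code, arrays = deserialize_bundle(text)
print(arrays["x"])  # ((3, 2), [1, 2, 3, 4, 5, 6])
```

With a reader like this, a Python-only user would never have to think about the transposition: the flat data could be reshaped to the returned shape and passed straight to the exported function.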


# Note that all the inputs must be transposed, i.e. if the Lux model has an input of shape
# (28, 28, 1, 4), then the input to the exported Lux model must be of shape (4, 1, 28, 28)
# Input as defined in our exported Lux model
Contributor

Extra fun fact: hypothetically, if the model weights themselves weren't traced, the emitted MLIR should contain the weights as constants, so one could export inference for a pre-trained model too!
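
A toy illustration of the point in this comment, in plain Python. It does not use Reactant and does not produce real MLIR; `export_linear` and its textual output format are invented for this sketch. It only shows the general idea that traced values become function arguments, while untraced values get baked into the exported program as constants.

```python
def export_linear(weight, bias, trace_weights):
    """Emit a toy textual program for y = weight * x + bias (scalars only).

    This is NOT real MLIR or Reactant output, just an illustration of how
    traced values become arguments while untraced values become constants.
    """
    if trace_weights:
        args = ["%x", "%weight", "%bias"]
        body = ["%0 = mul %weight, %x", "%1 = add %0, %bias"]
    else:
        args = ["%x"]
        body = [
            f"%w = constant {weight}",
            f"%b = constant {bias}",
            "%0 = mul %w, %x",
            "%1 = add %0, %b",
        ]
    lines = [f"func @main({', '.join(args)}) {{"]
    lines += [f"  {op}" for op in body]
    lines += ["  return %1", "}"]
    return "\n".join(lines)

traced = export_linear(2.0, 0.5, trace_weights=True)   # weights are arguments
baked = export_linear(2.0, 0.5, trace_weights=False)   # weights are constants
print(baked)
```

The `baked` variant takes only the input `%x`, so whoever loads it needs nothing but the program text to run inference.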

Member Author

This is actually a neat way to do it, but it will need some additional plumbing on the Lux end. Currently, if LuxLib sees a device mismatch (CPUDevice vs ReactantDevice), it throws an error. That check is very handy when users forget to move something to the GPU; previously they would hit a cryptic error deep in the stack.

With Reactant, it might be worth reworking some of those things.

codecov bot commented Nov 17, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 37.18%. Comparing base (3986545) to head (b52f96c).
Report is 2 commits behind head on main.

❗ The number of uploaded reports differs between BASE (3986545) and HEAD (b52f96c).

HEAD has 56 fewer uploads than BASE.
Flag    BASE (3986545)    HEAD (b52f96c)
        57                1
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1088       +/-   ##
===========================================
- Coverage   83.79%   37.18%   -46.62%     
===========================================
  Files         147       61       -86     
  Lines        6036     2883     -3153     
===========================================
- Hits         5058     1072     -3986     
- Misses        978     1811      +833     


@avik-pal avik-pal merged commit cf8fd61 into main Nov 17, 2024
16 of 17 checks passed
@avik-pal avik-pal deleted the ap/export_lux_to_jax branch November 17, 2024 01:23
Successfully merging this pull request may close these issues.

Export trained model for Tensorflow/PyTorch/C++?
2 participants