Export trained model for Tensorflow/PyTorch/C++? #453
Comments
I don't think there is currently any way to export models from Lux (or more broadly Julia) to ONNX. The other way is possible via […]
Just scanned the source code of […]
@wsmoses we should be able to do this by piggybacking on Reactant's tracing + https://openxla.org/stablehlo/tutorials/jax-export, right?
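For context on the linked tutorial, here is a minimal sketch of the JAX side of that workflow: lowering a jitted function to a StableHLO/MLIR module as text. This assumes only `jax` is installed; the function `f` is a stand-in for whatever computation Reactant would trace out of a Lux model.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Stand-in computation; a traced Lux model would take this place.
    return jnp.tanh(x) * 2.0

# Lower for a concrete input shape and dump the StableHLO/MLIR module.
lowered = jax.jit(f).lower(jnp.ones((100, 1), jnp.float32))
stablehlo_text = lowered.as_text()
print(stablehlo_text.splitlines()[0])  # module header line
```

The resulting text module is the kind of artifact the `ffi_call` below consumes as its `source` argument.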
I, in fact, did this alongside the Enzyme-JAX `ffi_call` (to call StableHLO generated from Lux) in a test last night.
So very much yes!! We could even do a Julia-native version of the `ffi_call` in the Enzyme-JAX docs, if there's interest.
I think we should; this kind of question used to show up on Slack quite often.
Relevant code to use the experiment from last night is:

```python
# Imports assumed by the snippet (exact import paths may vary by version).
import jax
import jax.numpy as jnp
import enzyme_ad.jax.primitives
from enzyme_ad.jax import JaXPipeline

def foo(arg0, arg1):
    return enzyme_ad.jax.primitives.ffi_call(
        arg0, arg1,
        out_shapes=[
            jax.core.ShapedArray([100, 1], jnp.float32),
            jax.core.ShapedArray([100, 1], jnp.float32),
            jax.core.ShapedArray([16897], jnp.float32),
        ],
        fn="main",
        source=code,
        lang=enzyme_ad.jax.primitives.LANG_MHLO,
        pipeline_options=JaXPipeline(""),
    )

a0 = jnp.ones((100, 1), jnp.float32)
a1 = jnp.ones((16897,), jnp.float32)
print(jax.jit(foo)(a0, a1))
```

where `code` is a string of HLO generated from Reactant. You'll also need this 3 am commit by me: EnzymeAD/Enzyme-JAX@874de8c

This can probably be cleaned up a bit (e.g. we make a native `hlo_call` that auto-parses the output shapes), and even a `julia_call` that uses Reactant to auto-trace variables (and ideally won't even need the input shapes specified either, retracing for the right ones).
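The "auto-parse the output shapes" cleanup could plausibly be sketched with plain string parsing: read the result types off the `@main` signature in the StableHLO text instead of asking the caller for `out_shapes`. The helper below is hypothetical (`parse_main_result_shapes` is not part of Enzyme-JAX) and only handles simple static tensor types.

```python
import re

# Matches tensor types like tensor<100x1xf32> in an MLIR signature.
_TENSOR = re.compile(r"tensor<([0-9x]*)(f32|f64|i32|i64)>")

def parse_main_result_shapes(hlo_text):
    """Return [(shape_tuple, dtype_str), ...] for @main's results."""
    # Grab the result-type list after '->' in the @main signature.
    m = re.search(r"@main\s*\([^)]*\)\s*->\s*([^\{]+)", hlo_text)
    if m is None:
        raise ValueError("no @main function found")
    results = []
    for dims, dtype in _TENSOR.findall(m.group(1)):
        shape = tuple(int(d) for d in dims.split("x") if d)
        results.append((shape, dtype))
    return results

hlo = ("func.func public @main(%arg0: tensor<100x1xf32>) "
       "-> (tensor<100x1xf32>, tensor<16897xf32>) {")
print(parse_main_result_shapes(hlo))
# [((100, 1), 'f32'), ((16897,), 'f32')]
```

A real implementation would walk the parsed MLIR module rather than regex over text, but this shows why the `out_shapes` argument is redundant information.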
I will turn this into docs for now. We can make the interface nicer as we move along |
Oops, didn't mean to close it. It works currently, but the API is not great.
Is there any documented way to export a trained Lux model to be loaded by TensorFlow/PyTorch or even C++? Is ONNX the correct way to go? If so, how can I export a Lux model to ONNX?