Export trained model for Tensorflow/PyTorch/C++? #453

Open
liuyxpp opened this issue Nov 1, 2023 · 9 comments · Fixed by #1088
Labels: enhancement (New feature or request), reactant

Comments

@liuyxpp

liuyxpp commented Nov 1, 2023

Is there a documented way to export a trained Lux model so that it can be loaded by TensorFlow, PyTorch, or even C++? Is ONNX the right way to go? If so, how can I export a Lux model to ONNX?

@avik-pal
Member

avik-pal commented Nov 1, 2023

I don't think there is currently any way to export models from Lux (or, more broadly, Julia) to ONNX. The other direction is possible via https://github.com/DrChainsaw/ONNXNaiveNASflux.jl -> Flux model -> Lux.transform

@liuyxpp
Author

liuyxpp commented Nov 3, 2023

I just scanned the source code of ONNXNaiveNASflux.jl, and it seems to support exporting Flux models to ONNX in src/serialize/serialize.jl. Or am I missing something?

avik-pal added the enhancement (New feature or request) label May 1, 2024
@avik-pal
Member

@wsmoses we should be able to do this by piggybacking on Reactant's tracing + https://openxla.org/stablehlo/tutorials/jax-export, right?

@wsmoses
Contributor

wsmoses commented Nov 15, 2024

In fact, I did this alongside the Enzyme-JAX ffi_call (to call StableHLO generated from Lux) in a test last night.

@wsmoses
Contributor

wsmoses commented Nov 15, 2024

So very much yes!!

We could even do a Julia-native version of the ffi_call in the Enzyme-JAX docs, if there’s interest:

https://github.com/EnzymeAD/Enzyme-JAX

@avik-pal
Member

I think we should. This line of questioning used to show up on Slack quite often.

@wsmoses
Contributor

wsmoses commented Nov 15, 2024

Relevant code to use the experiment from last night is:

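import jax
import jax.numpy as jnp
import enzyme_ad.jax.primitives
from enzyme_ad.jax.primitives import JaXPipeline

# `code` must already hold the HLO text generated by Reactant.
def foo(arg0, arg1):
    return enzyme_ad.jax.primitives.ffi_call(
        arg0, arg1,
        out_shapes=[
            jax.core.ShapedArray([100, 1], jnp.float32),
            jax.core.ShapedArray([100, 1], jnp.float32),
            jax.core.ShapedArray([16897], jnp.float32),
        ],
        fn="main",  # entry function inside the HLO module
        source=code,
        lang=enzyme_ad.jax.primitives.LANG_MHLO,
        pipeline_options=JaXPipeline(""),
    )

a0 = jnp.ones((100, 1), jnp.float32)
a1 = jnp.ones((16897,), jnp.float32)

print(jax.jit(foo)(a0, a1))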

Here, code is a string of HLO generated by Reactant.

And you’ll need this 3am commit by me: EnzymeAD/Enzyme-JAX@874de8c

This can probably be cleaned up a bit, e.g. with a native hlo_call that automatically parses the output shapes, or even a julia_call that uses Reactant to automatically trace variables (and ideally would not need the input shapes specified either, retracing for the right ones).
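
To illustrate the "auto parses the output shapes" idea, here is a minimal sketch in plain Python that reads the argument and result tensor types off the entry function's signature in the HLO text. The helper name parse_signature and the example module are hypothetical and hand-written for illustration; they are not part of Enzyme-JAX or Reactant, and real Reactant output may look different. The sketch assumes the function header sits on one line and that all shapes are static.

```python
import re

# Ranked tensor types with static dimensions, e.g. tensor<100x1xf32> or tensor<f32>.
# Assumption: no dynamic dimensions ("?") and no nested element types.
_TENSOR_RE = re.compile(r"tensor<((?:\d+x)*)([A-Za-z]\w*)>")


def _tensor_types(text):
    """Return [(shape_tuple, dtype_string), ...] for every tensor type in text."""
    return [
        (tuple(int(d) for d in dims.split("x") if d), dtype)
        for dims, dtype in _TENSOR_RE.findall(text)
    ]


def parse_signature(hlo_text, fn="main"):
    """Hypothetical helper: extract (inputs, outputs) from `func.func @fn(...)`.

    Assumes the whole function header is on a single line.
    """
    header_re = re.compile(
        r"func\.func\s+(?:\w+\s+)?@" + re.escape(fn) + r"\((.*)$", re.M
    )
    match = header_re.search(hlo_text)
    if match is None:
        raise ValueError(f"no function named @{fn} found")
    # Everything before "->" describes the arguments, everything after the results.
    args_text, _, results_text = match.group(1).partition("->")
    return _tensor_types(args_text), _tensor_types(results_text)


# Hand-written example module using the shapes from the snippet above.
example = """
module {
  func.func @main(%arg0: tensor<100x1xf32>, %arg1: tensor<16897xf32>) -> (tensor<100x1xf32>, tensor<100x1xf32>, tensor<16897xf32>) {
    return %arg0, %arg0, %arg1 : tensor<100x1xf32>, tensor<100x1xf32>, tensor<16897xf32>
  }
}
"""

inputs, outputs = parse_signature(example)
print("inputs: ", inputs)
print("outputs:", outputs)
```

Each (shape, dtype) pair in outputs could then be mapped to a jax.core.ShapedArray to build out_shapes instead of writing the list by hand. The mapping from dtype strings such as f32 to JAX dtypes is left out here.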

@avik-pal
Member

I will turn this into docs for now. We can make the interface nicer as we move along.

@avik-pal
Member

Oops, I didn't mean to close it. It currently works, but the API is not great.

avik-pal reopened this Nov 17, 2024