
Facilitate gradient interoperability between other ML frameworks (TensorFlow, Torch, etc.) and jax #3

Open
phinate opened this issue Mar 1, 2020 · 2 comments
Labels
enhancement New feature or request help wanted Extra attention is needed

Comments

@phinate
Member

phinate commented Mar 1, 2020

To use the setup as it is now, one would need to define their neural network (or whatever parametric observable you like that yields histograms) using only jax, which allows the tracing of gradients. This is generally seen as a limitation for two reasons:

  • New users will not want to learn yet another way to build neural networks
  • Existing model implementations that are hard to port between frameworks have no way to hook into the library's functionality

This issue outlines the kind of syntax and operations needed to enable this, but jax is still in its infancy, so there is no good user-facing API for it yet. It does seem like it could be possible, though! (:
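For illustration, here is a minimal sketch of one possible pattern (not something this project provides): wrap the jax part of the pipeline as a custom `torch.autograd.Function` whose backward pass calls the vector-Jacobian product returned by `jax.vjp`, converting arrays through NumPy. `jax_loss` is a hypothetical stand-in for the jax-based downstream computation (e.g. histogramming and a likelihood), and `JaxToTorch` is a hypothetical name. This assumes CPU arrays and float32 on both sides, and the NumPy round-trip copies data through host memory on every call.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def jax_loss(params):
    # Hypothetical stand-in for a jax-only downstream computation.
    return jnp.sum(jnp.sin(params) ** 2)


class JaxToTorch(torch.autograd.Function):
    """Expose jax_loss to PyTorch autograd (sketch, CPU/float32 only)."""

    @staticmethod
    def forward(ctx, x):
        # Move the torch tensor into jax via NumPy (copies through host memory).
        x_jax = jnp.asarray(x.detach().cpu().numpy())
        # jax.vjp returns the primal output and a function computing VJPs.
        y, vjp_fn = jax.vjp(jax_loss, x_jax)
        ctx.vjp_fn = vjp_fn  # keep the jax backward closure for later
        return torch.as_tensor(np.array(y), dtype=x.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        # Pull the incoming torch cotangent back through the jax computation.
        cotangent = jnp.asarray(grad_out.detach().cpu().numpy())
        (grad_x,) = ctx.vjp_fn(cotangent)
        return torch.as_tensor(np.array(grad_x), dtype=grad_out.dtype)


# Usage: gradients flow from the jax loss back into a torch leaf tensor.
x = torch.tensor([0.1, 0.5, 1.0], requires_grad=True)
loss = JaxToTorch.apply(x)
loss.backward()
```

Since d/dx sin²(x) = sin(2x), `x.grad` should match `torch.sin(2 * x)` up to float32 precision. In a real setup `x` would be the output of a torch network, so gradients would continue into its parameters.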

@phinate phinate added enhancement New feature or request help wanted Extra attention is needed labels Mar 1, 2020
@lukasheinrich
Collaborator

We have an example here for PyTorch using metadifff

Uploading image.png…

@phinate
Member Author

phinate commented Nov 10, 2020

Awesome, but I think that you've just posted a blank markdown link there :P
