-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a proof-of-concept for an Algorithm that uses Jax for its forward…
…/backward passes [RT-71] (#4) * Add an example algo that uses jax! Signed-off-by: Fabrice Normandin <[email protected]> * Simplify the jax example Signed-off-by: Fabrice Normandin <[email protected]> * Slightly tweak the jax example Signed-off-by: Fabrice Normandin <[email protected]> * Tweak the jax example Signed-off-by: Fabrice Normandin <[email protected]> * Tweak algo a bit (again) Signed-off-by: Fabrice Normandin <[email protected]> * Use flax nn.Module Signed-off-by: Fabrice Normandin <[email protected]> * Hacky: Wrap jax fn into a torch.autograd.Function Signed-off-by: Fabrice Normandin <[email protected]> * Make it work with automatic optimization and jit! Signed-off-by: Fabrice Normandin <[email protected]> * Able to use jax in intermediate node in graph! Signed-off-by: Fabrice Normandin <[email protected]> * Update to use git packages Signed-off-by: Fabrice Normandin <[email protected]> * Rename `batch_idx`->`batch_index` everywhere Signed-off-by: Fabrice Normandin <[email protected]> * Fix broken callback due to `batch_idx` rename Signed-off-by: Fabrice Normandin <[email protected]> * Use a callback to log classification metrics Signed-off-by: Fabrice Normandin <[email protected]> * Update the jax algo Signed-off-by: Fabrice Normandin <[email protected]> * Make the callback compatible with more recent PL Signed-off-by: Fabrice Normandin <[email protected]> * Make the Jax algo usable from CLI, tweak configs Signed-off-by: Fabrice Normandin <[email protected]> * Fix tests to use the tensor_regression package Signed-off-by: Fabrice Normandin <[email protected]> * Fix some issues with config registration in tests Signed-off-by: Fabrice Normandin <[email protected]> * Fix other tiny issues in test code Signed-off-by: Fabrice Normandin <[email protected]> * Fix issue with resnet50 config Signed-off-by: Fabrice Normandin <[email protected]> * Add some generated tests for the Jax algo example Signed-off-by: Fabrice Normandin <[email protected]> * Fix tests for algo that doesnt support jax Signed-off-by: Fabrice Normandin <[email protected]> * 'fix' issue with doctest of some configs Signed-off-by: Fabrice Normandin <[email protected]> * Set JAX_PLATFORMS=cpu in GitHub CI Signed-off-by: Fabrice Normandin <[email protected]> * Tweak build.yml again Signed-off-by: Fabrice Normandin <[email protected]> * Fix build.yml Signed-off-by: Fabrice Normandin <[email protected]> * Set rounding precision for regression tests Signed-off-by: Fabrice Normandin <[email protected]> --------- Signed-off-by: Fabrice Normandin <[email protected]>
- Loading branch information
Showing
37 changed files
with
1,246 additions
and
1,059 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.