Here are some examples:
python project/main.py experiment=cluster_sweep_example
You can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from the nice stuff that comes from using PyTorch-Lightning.

How does this work? We use torch-jax-interop, another package developed here at Mila, which allows easy interop between torch and jax code. See the readme on that repo for more details.

You can use Jax for your training step, but not the entire training loop, since that is handled by Lightning. There are a few good reasons to let Lightning handle the training loop, most notably that it takes care of all the logging, checkpointing, and other machinery that you would lose if you swapped out the entire training framework for something based on Jax.

In this example Jax algorithm, a neural network written in Jax (using flax) is wrapped using torch_jax_interop.JaxFunction so that its parameters are learnable. The parameters are saved on the LightningModule as nn.Parameters (which use the same underlying memory as the jax arrays). In this example, the loss function is written in PyTorch, while the network forward and backward passes are written in Jax.
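To make the flow concrete, here is a minimal, self-contained sketch of the same idea. It is not the template's actual code: the names (`BridgedJaxFunction`, `t2j`, `j2t`, `JaxMLPClassifier`) are hypothetical, the network is a hand-written Jax MLP rather than a flax module, and the conversions copy through numpy instead of sharing memory the way torch-jax-interop does. It shows the forward and backward passes running in Jax, the loss and optimizer step running in PyTorch, and Lightning driving the loop.

```python
# Hypothetical sketch, not the template's implementation.
import jax
import jax.numpy as jnp
import numpy as np
import torch
import lightning as L


def t2j(t: torch.Tensor) -> jax.Array:
    """torch -> jax (CPU copy here; torch-jax-interop shares memory instead)."""
    return jnp.asarray(t.detach().cpu().numpy())


def j2t(a: jax.Array) -> torch.Tensor:
    """jax -> torch (CPU copy)."""
    return torch.from_numpy(np.array(a))


@jax.jit
def mlp_forward(params, x):
    """A tiny MLP written directly in Jax (the real example uses flax)."""
    w1, b1, w2, b2 = params
    h = jax.nn.relu(x @ w1 + b1)
    return h @ w2 + b2


class BridgedJaxFunction(torch.autograd.Function):
    """Run a Jax function under torch autograd: forward *and* backward in Jax."""

    @staticmethod
    def forward(ctx, x, *params):
        out, vjp = jax.vjp(mlp_forward, [t2j(p) for p in params], t2j(x))
        ctx.vjp = vjp  # keep the Jax VJP closure for the backward pass
        return j2t(out)

    @staticmethod
    def backward(ctx, grad_out):
        grad_params, grad_x = ctx.vjp(t2j(grad_out))
        # Gradients are returned in the same order as forward's inputs.
        return (j2t(grad_x), *(j2t(g) for g in grad_params))


class JaxMLPClassifier(L.LightningModule):
    """Lightning handles the loop, logging and checkpointing; Jax does the math."""

    def __init__(self, in_dim: int = 32, hidden: int = 64, n_classes: int = 10):
        super().__init__()
        # Parameters live on the LightningModule as nn.Parameters, so the torch
        # optimizer and checkpointing treat them like any other weights.
        self.params = torch.nn.ParameterList(
            [
                torch.nn.Parameter(0.01 * torch.randn(in_dim, hidden)),
                torch.nn.Parameter(torch.zeros(hidden)),
                torch.nn.Parameter(0.01 * torch.randn(hidden, n_classes)),
                torch.nn.Parameter(torch.zeros(n_classes)),
            ]
        )

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = BridgedJaxFunction.apply(x, *self.params)   # forward pass in Jax
        loss = torch.nn.functional.cross_entropy(logits, y)  # loss in PyTorch
        self.log("train/loss", loss)
        return loss  # Lightning calls loss.backward(), which triggers the Jax VJP

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```

In the actual example, torch_jax_interop.JaxFunction plays the role of the hand-rolled bridge above (with zero-copy sharing of the underlying memory), and the flax network's parameters are the nn.Parameters registered on the LightningModule.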
(todo)