Simplify the jax example
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed May 31, 2024
1 parent 61c0662 commit 752afd1
Showing 1 changed file with 34 additions and 28 deletions.
project/algorithms/jax_algo.py (34 additions, 28 deletions)
@@ -1,4 +1,5 @@
import dataclasses
from collections.abc import Callable

import jax
import torch
@@ -7,18 +8,20 @@

from project.algorithms.bases.algorithm import Algorithm
from project.datamodules.image_classification.mnist import MNISTDataModule
from project.utils.types import PhaseStr, is_sequence_of
from project.utils.types import PhaseStr
from project.utils.types.protocols import DataModule

type ParamsTuple = tuple[jax.Array, ...]

def forward_pass(

def fcnet(
input: jax.Array, w1: jax.Array, b1: jax.Array, w2: jax.Array, b2: jax.Array
) -> jax.Array:
"""Forward pass of a simple two-layer fully-connected neural network with relu activation."""
z1 = jax.numpy.matmul(input, w1) + b1
a1 = jax.nn.relu(z1)
z2 = jax.numpy.matmul(a1, w2) + b2
return z2
logits = jax.numpy.matmul(a1, w2) + b2
return logits


def loss_fn(
@@ -33,6 +36,16 @@ def loss_fn(
return loss


def forward_pass(params: ParamsTuple, x: jax.Array, y: jax.Array) -> tuple[jax.Array, jax.Array]:
logits = fcnet(x, *params)
return loss_fn(logits, y), logits


backward_pass: Callable[
[ParamsTuple, jax.Array, jax.Array], tuple[tuple[jax.Array, jax.Array], ParamsTuple]
] = jax.value_and_grad(forward_pass, argnums=0, has_aux=True)


class JaxAlgorithm(Algorithm):
"""Example of an algorithm where the forward / backward passes are written in Jax."""

@@ -67,11 +80,11 @@ def __init__(
jax_to_torch(jax.random.uniform(key=jax.random.fold_in(key, 4), shape=(10,))),
requires_grad=True,
)
self._forward = jax.jit(forward_pass) if not self.hp.debug else forward_pass

# Get the gradients with respect to all parameters.
# self._backward_pass = jax.value_and_grad(loss_fn, argnums=range(2, 6))
self.forward_pass = jax.jit(forward_pass) if not self.hp.debug else forward_pass
self.backward_pass = jax.jit(backward_pass) if not self.hp.debug else backward_pass

# We do the backward pass ourselves; PL (Lightning) still handles synchronization between workers, etc.
self.automatic_optimization = False

def shared_step(
@@ -81,29 +94,21 @@ def shared_step(
# note: Also gets rid of the stride issues in jax.from_dlpack.
torch_x = torch_x.flatten(start_dim=1)
jax_x, jax_y = jax.tree.map(torch_to_jax, [torch_x, torch_y])

optimizer = self.optimizers()
assert not isinstance(optimizer, list)
if phase == "train":
optimizer.zero_grad()

torch_params = tuple(p.data for p in self.parameters())
jax_params = jax.tree.map(torch_to_jax, torch_params)
assert isinstance(jax_x, jax.Array)
assert isinstance(jax_y, jax.Array)
assert is_sequence_of(jax_params, jax.Array)

jax_logits = self._forward(jax_x, *jax_params)
jax_loss = loss_fn(jax_logits, jax_y)

torch_loss = jax_to_torch(jax_loss)
if phase == "train":
torch_params = tuple(p.data for p in self.parameters())
jax_params: ParamsTuple = jax.tree.map(torch_to_jax, torch_params)

def _loss_fn(params, x, y):
logits = self._forward(x, *params)
return loss_fn(logits, y)
if phase != "train":
# Only use the forward pass.
loss, logits = self.forward_pass(jax_params, jax_x, jax_y)
else:
optimizer = self.optimizers()
assert isinstance(optimizer, torch.optim.Optimizer)

jax_grads = jax.grad(_loss_fn, 0)(jax_params, jax_x, jax_y)
# Perform the backward pass
(loss, logits), jax_grads = self.backward_pass(jax_params, jax_x, jax_y)
torch_grads = jax.tree.map(jax_to_torch, jax_grads)

with torch.no_grad():
@@ -112,10 +117,11 @@ def _loss_fn(params, x, y):
param.grad = grad
else:
param.grad += grad

if phase == "train":
optimizer.step()
torch_logits = jax_to_torch(jax_logits)
optimizer.zero_grad()

torch_logits = jax_to_torch(logits)
torch_loss = jax_to_torch(loss)
accuracy = torch_logits.argmax(-1).eq(torch_y).float().mean()
self.log(f"{phase}/accuracy", accuracy, prog_bar=True)
self.log(f"{phase}/loss", torch_loss, prog_bar=True)
