Slightly tweak the jax example
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jun 3, 2024
1 parent 752afd1 commit d5b400c
Showing 3 changed files with 79 additions and 50 deletions.
8 changes: 4 additions & 4 deletions pdm.lock


115 changes: 70 additions & 45 deletions project/algorithms/jax_algo.py
@@ -1,26 +1,33 @@
import dataclasses
from collections.abc import Callable
from typing import Concatenate, Literal, NamedTuple

import jax
import numpy as np
import torch
from lightning import Trainer
from torch_jax_interop import jax_to_torch, torch_to_jax

from project.algorithms.bases.algorithm import Algorithm
from project.datamodules.image_classification.base import ImageClassificationDataModule
from project.datamodules.image_classification.mnist import MNISTDataModule
from project.utils.types import PhaseStr
from project.utils.types.protocols import DataModule

type ParamsTuple = tuple[jax.Array, ...]
# type ParamsTuple = tuple[jax.Array, ...]


def fcnet(
input: jax.Array, w1: jax.Array, b1: jax.Array, w2: jax.Array, b2: jax.Array
) -> jax.Array:
class ParamsTuple[T: torch.Tensor | jax.Array](NamedTuple):
w1: T
b1: T
w2: T
b2: T


def fcnet(input: jax.Array, params: ParamsTuple) -> jax.Array:
"""Forward pass of a simple two-layer fully-connected neural network with relu activation."""
z1 = jax.numpy.matmul(input, w1) + b1
z1 = jax.numpy.matmul(input, params.w1) + params.b1
a1 = jax.nn.relu(z1)
logits = jax.numpy.matmul(a1, w2) + b2
logits = jax.numpy.matmul(a1, params.w2) + params.b2
return logits


@@ -29,60 +36,80 @@ def loss_fn(
labels: jax.Array,
) -> jax.Array:
probs = jax.nn.log_softmax(logits)
assert isinstance(probs, jax.Array)
one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
assert isinstance(one_hot_labels, jax.Array)
assert isinstance(probs, jax.Array)
loss = -(one_hot_labels * probs).sum(axis=-1).mean()
return loss


def forward_pass(params: ParamsTuple, x: jax.Array, y: jax.Array) -> tuple[jax.Array, jax.Array]:
logits = fcnet(x, *params)
return loss_fn(logits, y), logits
def forward_pass(
params: ParamsTuple[jax.Array], x: jax.Array, y: jax.Array
) -> tuple[jax.Array, jax.Array]:
logits = fcnet(x, params)
loss = loss_fn(logits, y)
return loss, logits


def jit[**P, Out](
fn: Callable[P, Out],
) -> Callable[P, Out]:
"""Small type hint fix for jax's `jit` (preserves the signature of the callable)."""
return jax.jit(fn) # type: ignore

backward_pass: Callable[
[ParamsTuple, jax.Array, jax.Array], tuple[tuple[jax.Array, jax.Array], ParamsTuple]
] = jax.value_and_grad(forward_pass, argnums=0, has_aux=True)

def value_and_grad[In, **P, Out, Aux](
fn: Callable[Concatenate[In, P], tuple[Out, Aux]],
argnums: Literal[0] = 0,
has_aux: Literal[True] = True,
) -> Callable[Concatenate[In, P], tuple[tuple[Out, Aux], In]]:
"""Small type hint fix for jax's `value_and_grad` (preserves the signature of the callable)."""
return jax.value_and_grad(fn, argnums=argnums, has_aux=has_aux) # type: ignore


# Register a handler for "converting" nn.Parameters to jax Arrays: they can be viewed as jax Arrays
# by just viewing their data as a jax array.
@torch_to_jax.register(torch.nn.Parameter)
def _parameter_to_jax_array(value: torch.nn.Parameter) -> jax.Array:
return torch_to_jax(value.data)


class JaxAlgorithm(Algorithm):
"""Example of an algorithm where the forward / backward passes are written in Jax."""

@dataclasses.dataclass()
@dataclasses.dataclass
class HParams(Algorithm.HParams):
lr: float = 1e-3
seed: int = 123
debug: bool = True
debug: bool = False

def __init__(
self,
*,
datamodule: DataModule | None = None,
datamodule: ImageClassificationDataModule,
hp: HParams | None = None,
):
super().__init__(datamodule=datamodule, hp=hp or self.HParams())
input_dims = int(np.prod(datamodule.dims))
output_dims = datamodule.num_classes
self.hp: JaxAlgorithm.HParams
key = jax.random.key(self.hp.seed)
self.w1 = torch.nn.Parameter(
jax_to_torch(jax.random.uniform(key=jax.random.fold_in(key, 1), shape=(784, 128))),
requires_grad=True,
# todo: Extract out the "network" portion, and probably use something like flax for it.
params = ParamsTuple(
w1=jax.random.uniform(key=jax.random.fold_in(key, 1), shape=(input_dims, 128)),
b1=jax.random.uniform(key=jax.random.fold_in(key, 2), shape=(128,)),
w2=jax.random.uniform(key=jax.random.fold_in(key, 3), shape=(128, output_dims)),
b2=jax.random.uniform(key=jax.random.fold_in(key, 4), shape=(output_dims,)),
)
self.b1 = torch.nn.Parameter(
jax_to_torch(jax.random.uniform(key=jax.random.fold_in(key, 2), shape=(128,))),
requires_grad=True,
)
self.w2 = torch.nn.Parameter(
jax_to_torch(jax.random.uniform(key=jax.random.fold_in(key, 3), shape=(128, 10))),
requires_grad=True,
)
self.b2 = torch.nn.Parameter(
jax_to_torch(jax.random.uniform(key=jax.random.fold_in(key, 4), shape=(10,))),
requires_grad=True,
self.params = torch.nn.ParameterList(
[torch.nn.Parameter(v, requires_grad=True) for v in map(jax_to_torch, params)]
)
self.forward_pass = forward_pass
self.backward_pass = value_and_grad(self.forward_pass)

self.forward_pass = jax.jit(forward_pass) if not self.hp.debug else forward_pass
self.backward_pass = jax.jit(backward_pass) if not self.hp.debug else backward_pass
if not self.hp.debug:
self.forward_pass = jit(self.forward_pass)
self.backward_pass = jit(self.backward_pass)

# We will do the backward pass ourselves, and PL will synchronize stuff between workers, etc.
self.automatic_optimization = False
@@ -91,14 +118,13 @@ def shared_step(
self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int, phase: PhaseStr
):
torch_x, torch_y = batch
# note: Also gets rid of the stride issues. in jax.from_dlpack.
# Note: flattening the input also gets rid of the stride issues in jax.from_dlpack.
torch_x = torch_x.flatten(start_dim=1)
jax_x, jax_y = jax.tree.map(torch_to_jax, [torch_x, torch_y])
assert isinstance(jax_x, jax.Array)
assert isinstance(jax_y, jax.Array)
# View/"convert" the torch inputs to jax Arrays.
jax_x, jax_y = torch_to_jax(torch_x), torch_to_jax(torch_y)

torch_params = tuple(p.data for p in self.parameters())
jax_params: ParamsTuple = jax.tree.map(torch_to_jax, torch_params)
# View the parameters as jax Arrays
jax_params = ParamsTuple(*map(torch_to_jax, self.parameters()))

if phase != "train":
# Only use the forward pass.
@@ -109,9 +135,8 @@

# Perform the backward pass
(loss, logits), jax_grads = self.backward_pass(jax_params, jax_x, jax_y)
torch_grads = jax.tree.map(jax_to_torch, jax_grads)

with torch.no_grad():
torch_grads = map(jax_to_torch, jax_grads)
for param, grad in zip(self.parameters(), torch_grads):
if param.grad is None:
param.grad = grad
@@ -120,9 +145,8 @@
optimizer.step()
optimizer.zero_grad()

torch_logits = jax_to_torch(logits)
torch_loss = jax_to_torch(loss)
accuracy = torch_logits.argmax(-1).eq(torch_y).float().mean()
accuracy = jax_to_torch(logits).argmax(-1).eq(torch_y).float().mean()
self.log(f"{phase}/accuracy", accuracy, prog_bar=True)
self.log(f"{phase}/loss", torch_loss, prog_bar=True)
return torch_loss
@@ -133,8 +157,9 @@ def configure_optimizers(self):

def main():
trainer = Trainer(devices=1, accelerator="auto")
model = JaxAlgorithm()
trainer.fit(model, datamodule=MNISTDataModule())
datamodule = MNISTDataModule()
model = JaxAlgorithm(datamodule=datamodule)
trainer.fit(model, datamodule=datamodule)

...

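The gist of the updated `jax_algo.py` above: view the torch parameters and inputs as jax Arrays, compute the loss and gradients with an optionally jitted jax function, then copy the gradients back onto the torch parameters so a regular torch optimizer can step. Below is a minimal standalone sketch of that round trip. It is not part of the commit; the toy linear model, shapes, and names are made up for illustration, and only the `torch_to_jax` / `jax_to_torch` helpers from `torch_jax_interop` (imported in the diff above) are assumed.

import jax
import torch
from torch_jax_interop import jax_to_torch, torch_to_jax


def loss_fn(w: jax.Array, x: jax.Array, y: jax.Array) -> jax.Array:
    # Mean squared error of a toy linear model (illustration only).
    return ((x @ w - y) ** 2).mean()


# value_and_grad returns (loss, dloss/dw); jit compiles the whole backward pass.
backward = jax.jit(jax.value_and_grad(loss_fn, argnums=0))

w = torch.nn.Parameter(torch.randn(4, 1))
x, y = torch.randn(8, 4), torch.randn(8, 1)

# View the torch tensors as jax Arrays, run the jax backward pass, then view the
# resulting jax gradient as a torch tensor and hand it to the parameter, much like
# `shared_step` does before calling `optimizer.step()`.
loss, grads = backward(torch_to_jax(w.data), torch_to_jax(x), torch_to_jax(y))
with torch.no_grad():
    w.grad = jax_to_torch(grads)
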
6 changes: 5 additions & 1 deletion project/datamodules/vision/base.py
@@ -21,7 +21,11 @@
P = ParamSpec("P")

SLURM_TMPDIR: Path | None = (
Path(os.environ["SLURM_TMPDIR"]) if "SLURM_TMPDIR" in os.environ else None
Path(os.environ["SLURM_TMPDIR"])
if "SLURM_TMPDIR" in os.environ
else tmp
if "SLURM_JOB_ID" in os.environ and (tmp := Path("/tmp")).exists()
else None
)
logger = get_logger(__name__)

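The new `SLURM_TMPDIR` expression above packs an if/elif/else into a single conditional expression with a walrus operator. For reference, an equivalent spelled-out version (illustration only, not part of the commit):

import os
from pathlib import Path

if "SLURM_TMPDIR" in os.environ:
    SLURM_TMPDIR = Path(os.environ["SLURM_TMPDIR"])
elif "SLURM_JOB_ID" in os.environ and Path("/tmp").exists():
    # Running inside a SLURM job but $SLURM_TMPDIR isn't set: fall back to /tmp.
    SLURM_TMPDIR = Path("/tmp")
else:
    SLURM_TMPDIR = None
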
