Use flax nn.Module
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jun 4, 2024
1 parent 4deb156 commit 6306ccd
Showing 1 changed file with 90 additions and 50 deletions.
project/algorithms/jax_algo.py: 140 changes (90 additions, 50 deletions)
@@ -1,12 +1,14 @@
import dataclasses
import operator
from collections.abc import Callable
-from typing import Concatenate, Literal, NamedTuple
+from typing import Concatenate, Literal

+import flax.linen
import jax
import numpy as np
+import optax
import torch
import torch.distributed
+from flax.typing import VariableDict
from lightning import Trainer
from torch_jax_interop import jax_to_torch, torch_to_jax

@@ -15,42 +17,40 @@
from project.datamodules.image_classification.mnist import MNISTDataModule
from project.utils.types import PhaseStr

-# type ParamsTuple = tuple[jax.Array, ...]
-
-
-class ParamsTuple[T: torch.Tensor | jax.Array](NamedTuple):
-    w1: T
-    b1: T
-    w2: T
-    b2: T
-
-
-def fcnet(input: jax.Array, params: ParamsTuple) -> jax.Array:
-    """Forward pass of a simple two-layer fully-connected neural network with relu activation."""
-    z1 = jax.numpy.matmul(input, params.w1) + params.b1
-    a1 = jax.nn.relu(z1)
-    logits = jax.numpy.matmul(a1, params.w2) + params.b2
-    return logits
-
-
-def loss_fn(
-    logits: jax.Array,
-    labels: jax.Array,
-) -> jax.Array:
-    log_probs = jax.nn.log_softmax(logits, axis=-1)
-    assert isinstance(log_probs, jax.Array)
-    one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
-    assert isinstance(one_hot_labels, jax.Array)
-    loss = -(one_hot_labels * log_probs).sum(axis=-1).mean()
-    return loss
-
-
-def forward_pass(
-    params: ParamsTuple[jax.Array], x: jax.Array, y: jax.Array
-) -> tuple[jax.Array, jax.Array]:
-    logits = fcnet(x, params)
-    loss = loss_fn(logits, y)
-    return loss, logits
+
+class CNN(flax.linen.Module):
+    """A simple CNN model.
+
+    Taken from https://flax.readthedocs.io/en/latest/quick_start.html#define-network
+    """
+
+    num_classes: int = 10
+
+    @flax.linen.compact
+    def __call__(self, x: jax.Array):
+        x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)
+        x = flax.linen.relu(x)
+        x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+        x = flax.linen.Conv(features=64, kernel_size=(3, 3))(x)
+        x = flax.linen.relu(x)
+        x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+        x = x.reshape((x.shape[0], -1))  # flatten
+        x = flax.linen.Dense(features=256)(x)
+        x = flax.linen.relu(x)
+        x = flax.linen.Dense(features=self.num_classes)(x)
+        return x
+
+
+class FcNet(flax.linen.Module):
+    num_classes: int = 10
+
+    @flax.linen.compact
+    def __call__(self, x: jax.Array):
+        x = x.reshape((x.shape[0], -1))  # flatten
+        x = flax.linen.Dense(features=256)(x)
+        x = flax.linen.relu(x)
+        x = flax.linen.Dense(features=self.num_classes)(x)
+        return x


def jit[**P, Out](
@@ -76,6 +76,29 @@ def _parameter_to_jax_array(value: torch.nn.Parameter) -> jax.Array:
return torch_to_jax(value.data)


+def is_channels_first(shape: tuple[int, int, int] | tuple[int, int, int, int]) -> bool:
+    if len(shape) == 4:
+        return is_channels_first(shape[1:])
+    return (shape[0] in (1, 3) and shape[1] not in {1, 3} and shape[2] not in {1, 3}) or (
+        shape[0] < min(shape[1], shape[2])
+    )
+
+
+def to_channels_last[T: jax.Array | torch.Tensor](tensor: T) -> T:
+    shape = tuple(tensor.shape)
+    assert len(shape) == 3 or len(shape) == 4
+    if not is_channels_first(shape):
+        return tensor
+    if isinstance(tensor, jax.Array):
+        if len(shape) == 3:
+            return tensor.transpose(1, 2, 0)
+        return tensor.transpose(0, 2, 3, 1)
+    else:
+        if len(shape) == 3:
+            return tensor.transpose(0, 2)
+        return tensor.transpose(1, 3)


class JaxAlgorithm(Algorithm):
"""Example of an algorithm where the forward / backward passes are written in Jax."""

@@ -88,26 +111,22 @@ class HParams(Algorithm.HParams):
def __init__(
self,
*,
+        network: flax.linen.Module,
datamodule: ImageClassificationDataModule,
hp: HParams | None = None,
):
super().__init__(datamodule=datamodule, hp=hp or self.HParams())
-        input_dims = int(np.prod(datamodule.dims))
-        output_dims = datamodule.num_classes
self.hp: JaxAlgorithm.HParams
key = jax.random.key(self.hp.seed)
-        # todo: Extract out the "network" portion, and probably use something like flax for it.
-        params = ParamsTuple(
-            w1=jax.random.uniform(key=jax.random.fold_in(key, 1), shape=(input_dims, 128)),
-            b1=jax.random.uniform(key=jax.random.fold_in(key, 2), shape=(128,)),
-            w2=jax.random.uniform(key=jax.random.fold_in(key, 3), shape=(128, output_dims)),
-            b2=jax.random.uniform(key=jax.random.fold_in(key, 4), shape=(output_dims,)),
-        )
-        parameters, self.params_treedef = jax.tree.flatten(params)
+        self.network = network
+        x = jax.random.uniform(key, shape=(datamodule.batch_size, *datamodule.dims))
+        x = to_channels_last(x)
+        params = self.network.init(key, x=x)
+        params_list, self.params_treedef = jax.tree.flatten(params)
# Need to call .clone() when doing distributed training, otherwise we get a RuntimeError:
# Invalid device pointer when trying to share the CUDA memory.
self.params = torch.nn.ParameterList(
-            map(operator.methodcaller("clone"), map(jax_to_torch, parameters))
+            map(operator.methodcaller("clone"), map(jax_to_torch, params_list))
)

# We will do the backward pass ourselves, and PL will only be used to synchronize stuff
@@ -117,31 +136,52 @@ def __init__(
def on_fit_start(self):
# Setting those here, because otherwise we get pickling errors when running with multiple
# GPUs.
-        self.forward_pass = forward_pass
+
+        def loss_fn(params: VariableDict, x: jax.Array, y: jax.Array):
+            logits = self.network.apply(params, x)
+            assert isinstance(logits, jax.Array)
+            loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
+            assert isinstance(loss, jax.Array)
+            return loss, logits
+
+        self.forward_pass = loss_fn
self.backward_pass = value_and_grad(self.forward_pass)

if not self.hp.debug:
self.forward_pass = jit(self.forward_pass)
self.backward_pass = jit(self.backward_pass)

-    def jax_params(self) -> ParamsTuple[jax.Array]:
+    def jax_params(self) -> VariableDict:
# View the torch parameters as jax Arrays
-        jax_parameters = jax.tree.map(torch_to_jax, list(self.params))
+        jax_parameters = jax.tree.map(torch_to_jax, list(self.parameters()))
# Reconstruct the original object structure.
jax_params_tuple = jax.tree.unflatten(self.params_treedef, jax_parameters)
return jax_params_tuple

+    # def on_before_batch_transfer(
+    #     self, batch: tuple[torch.Tensor, torch.Tensor], dataloader_idx: int
+    # ):
+    #     # Convert the batch to jax Arrays.
+    #     x, y = batch
+    #     # Seems like jax likes channels last tensors: jax.from_dlpack doesn't work with
+    #     # channels-first tensors, so we have to do a transpose here.
+    #     x = to_channels_last(x)
+    #     # View the torch inputs as jax Arrays.
+    #     x, y = torch_to_jax(x), torch_to_jax(y)
+    #     return x, y

def shared_step(
self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int, phase: PhaseStr
):
x, y = batch
-        # Note: flattening the input also gets rid of the stride issues in jax.from_dlpack.
-        x = x.flatten(start_dim=1)
-        # View/"convert" the torch inputs to jax Arrays.
+        # Convert the batch to jax Arrays.
+        # Seems like jax likes channels last tensors: jax.from_dlpack doesn't work with
+        # channels-first tensors, so we have to do a transpose here.
+        x = to_channels_last(x)
+        # View the torch inputs as jax Arrays.
x, y = torch_to_jax(x), torch_to_jax(y)

jax_params = self.jax_params()

if phase != "train":
# Only use the forward pass.
loss, logits = self.forward_pass(jax_params, x, y)
@@ -183,7 +223,7 @@ def configure_optimizers(self):
def main():
trainer = Trainer(devices="auto", accelerator="auto")
datamodule = MNISTDataModule(num_workers=4)
-    model = JaxAlgorithm(datamodule=datamodule)
+    model = JaxAlgorithm(network=CNN(num_classes=datamodule.num_classes), datamodule=datamodule)
trainer.fit(model, datamodule=datamodule)

...
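
For context, a minimal self-contained sketch (not part of the commit) of the flax.linen pattern the new JaxAlgorithm relies on: `init` builds the parameter VariableDict, `apply` runs the forward pass, and `jax.tree.flatten`/`unflatten` convert between that structure and the flat list of arrays stored as torch Parameters. The seed, batch shape, and trimmed one-block CNN here are illustrative assumptions; only the API calls mirror the diff.

import flax.linen
import jax
import optax


class CNN(flax.linen.Module):
    """Smaller stand-in for the CNN in the diff, trimmed to one conv block."""

    num_classes: int = 10

    @flax.linen.compact
    def __call__(self, x: jax.Array):
        x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)
        x = flax.linen.relu(x)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = flax.linen.Dense(features=self.num_classes)(x)
        return x


network = CNN(num_classes=10)
key = jax.random.key(0)
# Channels-last dummy input, as expected after to_channels_last (shape is illustrative).
x = jax.random.uniform(key, shape=(8, 28, 28, 1))
params = network.init(key, x=x)  # VariableDict holding the parameters
logits = network.apply(params, x)  # forward pass
labels = jax.numpy.zeros((8,), dtype=jax.numpy.int32)
loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
# Flatten to a list of arrays (wrapped as torch Parameters in the algorithm) and rebuild.
params_list, treedef = jax.tree.flatten(params)
params_roundtrip = jax.tree.unflatten(treedef, params_list)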