Hacky: Wrap jax fn into a torch.autograd.Function
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jun 4, 2024
1 parent 6306ccd commit a52972c
Showing 1 changed file with 101 additions and 23 deletions.
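The pattern introduced by this commit: expose a JAX computation to PyTorch autograd by subclassing torch.autograd.Function, convert tensors at the boundary with torch_to_jax / jax_to_torch, and compute gradients in the backward pass with jax.grad. A minimal sketch of the same idea, not taken from this repository (the toy jax_fn, ToyJaxFunction and the shapes are made up; torch_to_jax / jax_to_torch are assumed to round-trip CPU tensors the way they are used in the diff below):

import jax
import jax.numpy as jnp
import torch
from torch_jax_interop import jax_to_torch, torch_to_jax


def jax_fn(x: jax.Array) -> jax.Array:
    # Toy scalar-valued function standing in for the flax CNN + cross-entropy used below.
    return jnp.sin(x).sum()


class ToyJaxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx: torch.autograd.function.FunctionCtx, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)
        # detach() so the tensor can be handed to jax without torch tracking it.
        return jax_to_torch(jax_fn(torch_to_jax(x.detach())))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (x,) = ctx.saved_tensors
        # Recompute the gradient with jax.grad and hand it back to torch autograd.
        grad_x = jax_to_torch(jax.grad(jax_fn)(torch_to_jax(x.detach())))
        return grad_output * grad_x


x = torch.randn(3, requires_grad=True)
ToyJaxFunction.apply(x).backward()
print(x.grad)  # should match torch.cos(x)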
124 changes: 101 additions & 23 deletions project/algorithms/jax_algo.py
@@ -1,13 +1,14 @@
import dataclasses
import operator
from collections.abc import Callable
-from typing import Concatenate, Literal
+from typing import ClassVar, Concatenate, Literal

import flax.linen
import jax
import optax
import torch
import torch.distributed
+from chex import PyTreeDef
from flax.typing import VariableDict
from lightning import Trainer
from torch_jax_interop import jax_to_torch, torch_to_jax
@@ -99,6 +100,75 @@ def to_channels_last[T: jax.Array | torch.Tensor](tensor: T) -> T:
return tensor.transpose(1, 3)


class JaxFunction(torch.autograd.Function):
params_treedef: ClassVar

@staticmethod
def loss_function(
params: VariableDict,
x: jax.Array,
y: jax.Array,
):
logits = CNN().apply(params, x)
assert isinstance(logits, jax.Array)
loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
assert isinstance(loss, jax.Array)
return loss, logits

@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
x: torch.Tensor,
y: torch.Tensor,
params_treedef: PyTreeDef,
*params: torch.Tensor,
):
ctx.save_for_backward(x, y, *params)
ctx.params_treedef = params_treedef # type: ignore
jax_x = torch_to_jax(x)
jax_y = torch_to_jax(y)
jax_params = tuple(map(torch_to_jax, params))
jax_params = jax.tree.unflatten(params_treedef, jax_params)
jax_loss, jax_logits = JaxFunction.loss_function(jax_params, x=jax_x, y=jax_y)
loss = jax_to_torch(jax_loss)
logits = jax_to_torch(jax_logits)
return loss, logits

@staticmethod
def backward(
ctx: torch.autograd.function.NestedIOFunction,
grad_loss: torch.Tensor,
grad_logits: torch.Tensor,
):
x: torch.Tensor
params: tuple[torch.Tensor, ...]
x, y, *params = ctx.saved_tensors # type: ignore
params_treedef: PyTreeDef = ctx.params_treedef # type: ignore
jax_x = torch_to_jax(x)
jax_y = torch_to_jax(y)
# jax_grad_output = torch_to_jax(grad_output) # TODO: Can we even pass this to jax.grad?

structured_params = jax.tree.unflatten(params_treedef, params)
jax_params = jax.tree.map(torch_to_jax, structured_params)

grad_input = None
grad_y = None

# todo: broaden this a bit in case we need the grad of the input.
assert ctx.needs_input_grad == (
False, # input
False, # y
False, # params_treedef
*(True for _ in params),
), ctx.needs_input_grad
jax_params_grad, logits = jax.grad(JaxFunction.loss_function, argnums=0, has_aux=True)(
jax_params, jax_x, jax_y
)
torch_params_grad = jax.tree.map(jax_to_torch, jax_params_grad)
torch_flat_params_grad = jax.tree.leaves(torch_params_grad)
return grad_input, grad_y, None, *torch_flat_params_grad


class JaxAlgorithm(Algorithm):
"""Example of an algorithm where the forward / backward passes are written in Jax."""

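The backward of the JaxFunction class added above asserts that only the flattened parameters need gradients (see the TODO next to ctx.needs_input_grad). If the gradient of the input were ever needed, the same jax.grad call could be broadened. A hypothetical variant of that staticmethod, not part of this commit, keeping the commit's implicit assumptions that grad_loss is 1 and grad_logits can be ignored:

    @staticmethod
    def backward(ctx, grad_loss: torch.Tensor, grad_logits: torch.Tensor):
        x, y, *params = ctx.saved_tensors
        jax_x, jax_y = torch_to_jax(x), torch_to_jax(y)
        jax_params = jax.tree.unflatten(ctx.params_treedef, [torch_to_jax(p) for p in params])
        # Differentiate w.r.t. the params, and also w.r.t. x when torch asks for it.
        needs_x_grad = ctx.needs_input_grad[0]
        argnums = (0, 1) if needs_x_grad else (0,)
        grads, _logits = jax.grad(JaxFunction.loss_function, argnums=argnums, has_aux=True)(
            jax_params, jax_x, jax_y
        )
        grad_x = jax_to_torch(grads[1]) if needs_x_grad else None
        flat_params_grad = jax.tree.leaves(jax.tree.map(jax_to_torch, grads[0]))
        # One output per input of forward: x, y, params_treedef, then each flattened parameter.
        return grad_x, None, None, *flat_params_grad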
@@ -131,32 +201,32 @@ def __init__(

# We will do the backward pass ourselves, and PL will only be used to synchronize stuff
# between workers, do logging, etc.
-        self.automatic_optimization = False
+        self.automatic_optimization = True

def on_fit_start(self):
# Setting those here, because otherwise we get pickling errors when running with multiple
# GPUs.
# def on_fit_start(self):
# Setting those here, because otherwise we get pickling errors when running with multiple
# GPUs.

def loss_fn(params: VariableDict, x: jax.Array, y: jax.Array):
logits = self.network.apply(params, x)
assert isinstance(logits, jax.Array)
loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
assert isinstance(loss, jax.Array)
return loss, logits
# def loss_fn(params: VariableDict, x: jax.Array, y: jax.Array):
# logits = self.network.apply(params, x)
# assert isinstance(logits, jax.Array)
# loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
# assert isinstance(loss, jax.Array)
# return loss, logits

self.forward_pass = loss_fn
self.backward_pass = value_and_grad(self.forward_pass)
# self.forward_pass = loss_fn
# self.backward_pass = value_and_grad(self.forward_pass)

if not self.hp.debug:
self.forward_pass = jit(self.forward_pass)
self.backward_pass = jit(self.backward_pass)
# if not self.hp.debug:
# self.forward_pass = jit(self.forward_pass)
# self.backward_pass = jit(self.backward_pass)

def jax_params(self) -> VariableDict:
# View the torch parameters as jax Arrays
jax_parameters = jax.tree.map(torch_to_jax, list(self.parameters()))
# Reconstruct the original object structure.
jax_params_tuple = jax.tree.unflatten(self.params_treedef, jax_parameters)
return jax_params_tuple
# def jax_params(self) -> VariableDict:
# # View the torch parameters as jax Arrays
# jax_parameters = jax.tree.map(torch_to_jax, list(self.parameters()))
# # Reconstruct the original object structure.
# jax_params_tuple = jax.tree.unflatten(self.params_treedef, jax_parameters)
# return jax_params_tuple

# def on_before_batch_transfer(
# self, batch: tuple[torch.Tensor, torch.Tensor], dataloader_idx: int
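For reference: shared_step below passes self.params_treedef and *self.parameters() to JaxFunction.apply, but neither is created in this diff. A hypothetical sketch of the kind of setup in __init__ that would produce them (dummy_input, the PRNG seed and the ParameterList are assumptions, not the repository's actual code; CNN is the flax module from this file):

import jax
import torch
from torch_jax_interop import jax_to_torch

network = CNN(num_classes=10)
# Initialize the flax parameters once, with a channels-last dummy batch (as in shared_step).
dummy_input = jax.numpy.zeros((1, 28, 28, 1))
jax_params = network.init(jax.random.PRNGKey(0), dummy_input)
flat_params, params_treedef = jax.tree.flatten(jax_params)
# Register every leaf as a torch Parameter so the Lightning optimizer can update it;
# the treedef is what JaxFunction uses to rebuild the flax VariableDict.
torch_params = torch.nn.ParameterList(
    [torch.nn.Parameter(jax_to_torch(p), requires_grad=True) for p in flat_params]
)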
@@ -178,6 +248,14 @@ def shared_step(
# Seems like jax likes channels last tensors: jax.from_dlpack doesn't work with
# channels-first tensors, so we have to do a transpose here.
x = to_channels_last(x)
+        loss, logits = JaxFunction.apply(x, y, self.params_treedef, *self.parameters())  # type: ignore
+        assert isinstance(logits, torch.Tensor)
+        if phase == "train":
+            assert loss.requires_grad
+        self.log(f"{phase}/loss", loss, prog_bar=True, sync_dist=True)
+        acc = logits.argmax(-1).eq(y).float().mean()
+        self.log(f"{phase}/acc", acc, prog_bar=True, sync_dist=True)
+        return loss
# View the torch inputs as jax Arrays.
x, y = torch_to_jax(x), torch_to_jax(y)

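Given a setup like the sketch above, the new forward/backward path can be exercised without a Trainer. A hypothetical smoke test (MNIST-sized shapes; it also assumes the default CNN() built inside JaxFunction.loss_function matches the network whose parameters were flattened):

# Channels-last batch, as expected on the jax side after to_channels_last.
x = torch.randn(8, 28, 28, 1)
y = torch.randint(0, 10, (8,))
loss, logits = JaxFunction.apply(x, y, params_treedef, *torch_params)
loss.backward()
assert logits.shape == (8, 10)
assert all(p.grad is not None for p in torch_params)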
@@ -221,7 +299,7 @@ def configure_optimizers(self):


def main():
-    trainer = Trainer(devices="auto", accelerator="auto")
+    trainer = Trainer(devices=1, accelerator="auto")
datamodule = MNISTDataModule(num_workers=4)
model = JaxAlgorithm(network=CNN(num_classes=datamodule.num_classes), datamodule=datamodule)
trainer.fit(model, datamodule=datamodule)
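One thing the new wrapper drops, compared to the commented-out code in __init__ above, is jax.jit. A hypothetical way to reintroduce it so that repeated steps reuse the compiled functions (not part of this commit):

jitted_loss = jax.jit(JaxFunction.loss_function)
jitted_grad = jax.jit(jax.grad(JaxFunction.loss_function, argnums=0, has_aux=True))
# forward would then call jitted_loss and backward would call jitted_grad.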
