Able to use jax in an intermediate node in the graph!
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jun 6, 2024
1 parent 8feece6 commit 8d79e67
Showing 1 changed file with 91 additions and 70 deletions.
project/algorithms/jax_algo.py (161 changes: 91 additions, 70 deletions)
@@ -1,11 +1,11 @@
import dataclasses
import operator
import typing
from collections.abc import Callable
from typing import ClassVar, Concatenate, Literal

import flax.linen
import jax
import optax
import torch
import torch.distributed
from chex import PyTreeDef
@@ -16,7 +16,11 @@
from project.algorithms.bases.algorithm import Algorithm
from project.datamodules.image_classification.base import ImageClassificationDataModule
from project.datamodules.image_classification.mnist import MNISTDataModule
from project.utils.types import PhaseStr
from project.utils.types import PhaseStr, is_sequence_of


def flatten(x: jax.Array) -> jax.Array:
return x.reshape((x.shape[0], -1))
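
# Note: the `torch_to_jax` / `jax_to_torch` helpers used throughout this file come from the
# project's utils (their definitions are outside this diff). As a rough illustration only,
# such conversions are commonly implemented with DLPack so that both frameworks share the
# same device buffer; the two sketches below are an assumption, not the project's actual code.
import jax.dlpack  # noqa: E402 (imports kept local to this illustrative sketch)
import torch.utils.dlpack  # noqa: E402


def _torch_to_jax_sketch(tensor: torch.Tensor) -> jax.Array:
    # Hand the tensor's memory over to jax without copying.
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(tensor.contiguous()))


def _jax_to_torch_sketch(array: jax.Array) -> torch.Tensor:
    # Hand the array's memory over to torch without copying.
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(array))
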


class CNN(flax.linen.Module):
@@ -35,6 +39,7 @@ def __call__(self, x: jax.Array):
x = flax.linen.Conv(features=64, kernel_size=(3, 3))(x)
x = flax.linen.relu(x)
x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

x = x.reshape((x.shape[0], -1)) # flatten
x = flax.linen.Dense(features=256)(x)
x = flax.linen.relu(x)
@@ -100,66 +105,86 @@ def to_channels_last[T: jax.Array | torch.Tensor](tensor: T) -> T:
return tensor.transpose(1, 3)
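
# Quick illustrative check: `transpose(1, 3)` swaps the channel and width axes, so an
# (N, C, H, W) batch becomes (N, W, H, C). For the square MNIST images used below this is
# effectively a channels-last layout, although H and W are swapped as well.
_example_batch = torch.zeros(2, 1, 28, 28)  # a hypothetical NCHW batch
assert to_channels_last(_example_batch).shape == (2, 28, 28, 1)
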


class JaxOperation(torch.nn.Module):
def __init__(
self,
jax_function: Callable[[VariableDict, jax.Array], jax.Array],
jax_params_dict: VariableDict,
):
super().__init__()
self.jax_function = jax.jit(jax_function)
params_list, self.params_treedef = jax.tree.flatten(jax_params_dict)
# Register the parameters.
# Need to call .clone() when doing distributed training, otherwise we get a RuntimeError:
# Invalid device pointer when trying to share the CUDA memory.
self.params = torch.nn.ParameterList(
map(operator.methodcaller("clone"), map(jax_to_torch, params_list))
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
out = JaxFunction.apply(
input,
self.params_treedef,
self.jax_function,
*self.params,
)
return out

if typing.TYPE_CHECKING:
__call__ = forward


class JaxFunction(torch.autograd.Function):
"""Wrapper for a jax function."""
params_treedef: ClassVar

@staticmethod
def forward(
ctx: torch.autograd.function.NestedIOFunction,
x: torch.Tensor,
y: torch.Tensor,
input: torch.Tensor,
params_treedef: PyTreeDef,
loss_fn: Callable[[VariableDict, jax.Array, jax.Array], tuple[jax.Array, jax.Array]],
loss_value_and_grad_fn: Callable[
[VariableDict, jax.Array, jax.Array], tuple[jax.Array, jax.Array]
],
*params: torch.Tensor,
jax_function: Callable[[VariableDict, jax.Array], jax.Array],
*params: torch.Tensor, # need to flatten the params for autograd to understand that they need a gradient.
):
jax_x = torch_to_jax(x)
jax_y = torch_to_jax(y)
jax_input = torch_to_jax(input)
jax_params = tuple(map(torch_to_jax, params))
jax_params = jax.tree.unflatten(params_treedef, jax_params)

needs_grad: tuple[bool, ...] = ctx.needs_input_grad # type: ignore
x_needs_grad, y_needs_grad, _, _, _, *params_need_grad = needs_grad
# todo: broaden a bit:
assert not x_needs_grad
assert not y_needs_grad
if all(params_need_grad):
# We're going to need to do the backward pass, so do it right away and save the grads
# in the context.
(loss, logits), param_grads = loss_value_and_grad_fn(jax_params, jax_x, jax_y)
flattened_param_grads = jax.tree.leaves(param_grads)
torch_grads = tuple(map(jax_to_torch, flattened_param_grads))
ctx.save_for_backward(*torch_grads)
input_needs_grad, _, _, _, *params_need_grad = needs_grad
if any(params_need_grad) or input_needs_grad:
output, jvp_function = jax.vjp(jax_function, jax_params, jax_input)
ctx.jvp_function = jvp_function
else:
assert not any(params_need_grad)
loss, logits = loss_fn(jax_params, jax_x, jax_y)
loss = jax_to_torch(loss)
logits = jax_to_torch(logits)
return loss, logits
# Forward pass without gradient computation.
output = jax_function(jax_params, jax_input)
output = jax_to_torch(output)
return output

@staticmethod
def backward(
ctx: torch.autograd.function.NestedIOFunction,
grad_loss: torch.Tensor,
grad_logits: torch.Tensor,
grad_output: torch.Tensor,
):
x_needs_grad, y_needs_grad, _, _, _, *params_needs_grad = ctx.needs_input_grad
input_need_grad, _, _, *params_needs_grad = ctx.needs_input_grad
# todo: broaden this a bit in case we need the grad of the input.
# todo: Figure out how to do jax.grad for a function that outputs a matrix or vector.
assert not x_needs_grad
assert not y_needs_grad

assert not input_need_grad
grad_input = None
grad_y = None
if all(params_needs_grad):
params_grads = ctx.saved_tensors
if input_need_grad or any(params_needs_grad):
assert all(params_needs_grad)
jvp_function = ctx.jvp_function
jax_grad_output = torch_to_jax(grad_output)
jax_grad_params, jax_input_grad = jvp_function(jax_grad_output)
params_grads = jax.tree.map(jax_to_torch, jax.tree.leaves(jax_grad_params))
assert is_sequence_of(params_grads, torch.Tensor)

if input_need_grad:
grad_input = jax_to_torch(jax_input_grad)
else:
assert not any(params_needs_grad)
params_grads = tuple(None for _ in params_needs_grad)

return grad_input, grad_y, None, None, None, *params_grads
return grad_input, None, None, *params_grads
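
# Minimal standalone refresher on the mechanism used above (illustrative only): jax.vjp(f, *args)
# runs the forward pass and returns its output together with a function that maps an output
# cotangent (here, the gradient arriving from torch) to gradients w.r.t. each of `args`. Note
# that this returned function is a VJP (backward) function, even though it is stored above
# under the name `jvp_function`.
def _vjp_example() -> None:
    def f(w, x):
        return w * x

    out, vjp_fn = jax.vjp(f, 3.0, 2.0)  # out == 6.0
    dw, dx = vjp_fn(1.0)  # dw == 2.0 (= x), dx == 3.0 (= w)
    assert float(dw) == 2.0 and float(dx) == 3.0
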


class JaxAlgorithm(Algorithm):
@@ -180,40 +205,39 @@ def __init__(
):
super().__init__(datamodule=datamodule, hp=hp or self.HParams())
self.hp: JaxAlgorithm.HParams
torch.zeros(1, device="cuda") # weird cuda errors!
key = jax.random.key(self.hp.seed)
self.network = network
x = jax.random.uniform(key, shape=(datamodule.batch_size, *datamodule.dims))
x = to_channels_last(x)
params = self.network.init(key, x=x)
params_list, self.params_treedef = jax.tree.flatten(params)
jax_net = CNN()
params = jax_net.init(key, x=x)
# Need to call .clone() when doing distributed training, otherwise we get a RuntimeError:
# Invalid device pointer when trying to share the CUDA memory.
self.params = torch.nn.ParameterList(
map(operator.methodcaller("clone"), map(jax_to_torch, params_list))
)

self.network = JaxOperation(jax_function=jax_net.apply, jax_params_dict=params)

# self.params = torch.nn.ParameterList(
# map(operator.methodcaller("clone"), map(jax_to_torch, params_list))
# )

self.automatic_optimization = True

def on_fit_start(self):
# Setting those here, because otherwise we get pickling errors when running with multiple
# GPUs.
def loss_function(
params: VariableDict,
x: jax.Array,
y: jax.Array,
):
logits = self.network.apply(params, x)
assert isinstance(logits, jax.Array)
loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
assert isinstance(loss, jax.Array)
return loss, logits

self.forward_pass = loss_function
self.backward_pass = value_and_grad(self.forward_pass, argnums=0, has_aux=True)

if not self.hp.debug:
self.forward_pass = jit(self.forward_pass)
self.backward_pass = jit(self.backward_pass)
pass
# # Setting those here, because otherwise we get pickling errors when running with multiple
# # GPUs.
# def loss_function(
# params: VariableDict,
# x: jax.Array,
# y: jax.Array,
# ):

# self.forward_pass = loss_function
# self.backward_pass = value_and_grad(self.forward_pass, argnums=0, has_aux=True)

# if not self.hp.debug:
# self.forward_pass = jit(self.forward_pass)
# self.backward_pass = jit(self.backward_pass)

def shared_step(
self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int, phase: PhaseStr
@@ -225,13 +249,10 @@ def shared_step(

x = to_channels_last(x)

loss: torch.Tensor
logits: torch.Tensor
loss, logits = JaxFunction.apply( # type: ignore
x, y, self.params_treedef, self.forward_pass, self.backward_pass, *self.parameters()
)

logits = self.network(x)
assert isinstance(logits, torch.Tensor)
loss = torch.nn.functional.cross_entropy(logits, target=y).mean()
assert isinstance(loss, torch.Tensor)
if phase == "train":
assert loss.requires_grad
self.log(f"{phase}/loss", loss, prog_bar=True, sync_dist=True)
@@ -245,7 +266,7 @@ def configure_optimizers(self):

def main():
trainer = Trainer(devices=1, accelerator="auto")
datamodule = MNISTDataModule(num_workers=4)
datamodule = MNISTDataModule(num_workers=4, batch_size=2)
model = JaxAlgorithm(network=CNN(num_classes=datamodule.num_classes), datamodule=datamodule)
trainer.fit(model, datamodule=datamodule)
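
# Optional sanity check (an illustrative sketch, not part of this commit): verify that a torch
# loss computed downstream of JaxOperation back-propagates into the wrapped jax parameters,
# i.e. that gradients really flow through the jax intermediate node.
def _check_gradients_flow_through_jax() -> None:
    key = jax.random.key(0)
    jax_net = CNN(num_classes=10)
    x = to_channels_last(torch.rand(2, 1, 28, 28))  # NCHW -> channels-last for the flax CNN
    params = jax_net.init(key, x=torch_to_jax(x))
    jax_block = JaxOperation(jax_function=jax_net.apply, jax_params_dict=params)
    loss = torch.nn.functional.cross_entropy(jax_block(x), torch.randint(0, 10, (2,)))
    loss.backward()
    assert all(p.grad is not None for p in jax_block.params)
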
