diff --git a/project/algorithms/jax_algo.py b/project/algorithms/jax_algo.py
index f355da59..8d1c9d68 100644
--- a/project/algorithms/jax_algo.py
+++ b/project/algorithms/jax_algo.py
@@ -1,11 +1,11 @@
 import dataclasses
 import operator
+import typing
 from collections.abc import Callable
 from typing import ClassVar, Concatenate, Literal

 import flax.linen
 import jax
-import optax
 import torch
 import torch.distributed
 from chex import PyTreeDef
@@ -16,7 +16,11 @@
 from project.algorithms.bases.algorithm import Algorithm
 from project.datamodules.image_classification.base import ImageClassificationDataModule
 from project.datamodules.image_classification.mnist import MNISTDataModule
-from project.utils.types import PhaseStr
+from project.utils.types import PhaseStr, is_sequence_of
+
+
+def flatten(x: jax.Array) -> jax.Array:
+    return x.reshape((x.shape[0], -1))


 class CNN(flax.linen.Module):
@@ -35,6 +39,7 @@ def __call__(self, x: jax.Array):
         x = flax.linen.Conv(features=64, kernel_size=(3, 3))(x)
         x = flax.linen.relu(x)
         x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+        x = x.reshape((x.shape[0], -1))  # flatten
         x = flax.linen.Dense(features=256)(x)
         x = flax.linen.relu(x)
@@ -100,66 +105,86 @@ def to_channels_last[T: jax.Array | torch.Tensor](tensor: T) -> T:
     return tensor.transpose(1, 3)


+class JaxOperation(torch.nn.Module):
+    def __init__(
+        self,
+        jax_function: Callable[[VariableDict, jax.Array], jax.Array],
+        jax_params_dict: VariableDict,
+    ):
+        super().__init__()
+        self.jax_function = jax.jit(jax_function)
+        params_list, self.params_treedef = jax.tree.flatten(jax_params_dict)
+        # Register the parameters.
+        # Need to call .clone() when doing distributed training, otherwise we get a RuntimeError:
+        # Invalid device pointer when trying to share the CUDA memory.
+        self.params = torch.nn.ParameterList(
+            map(operator.methodcaller("clone"), map(jax_to_torch, params_list))
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        out = JaxFunction.apply(
+            input,
+            self.params_treedef,
+            self.jax_function,
+            *self.params,
+        )
+        return out
+
+    if typing.TYPE_CHECKING:
+        __call__ = forward
+
+
 class JaxFunction(torch.autograd.Function):
+    """Wrapper for a jax function."""

     params_treedef: ClassVar

     @staticmethod
     def forward(
         ctx: torch.autograd.function.NestedIOFunction,
-        x: torch.Tensor,
-        y: torch.Tensor,
+        input: torch.Tensor,
         params_treedef: PyTreeDef,
-        loss_fn: Callable[[VariableDict, jax.Array, jax.Array], tuple[jax.Array, jax.Array]],
-        loss_value_and_grad_fn: Callable[
-            [VariableDict, jax.Array, jax.Array], tuple[jax.Array, jax.Array]
-        ],
-        *params: torch.Tensor,
+        jax_function: Callable[[VariableDict, jax.Array], jax.Array],
+        *params: torch.Tensor,  # need to flatten the params for autograd to understand that they need a gradient.
     ):
-        jax_x = torch_to_jax(x)
-        jax_y = torch_to_jax(y)
+        jax_input = torch_to_jax(input)
         jax_params = tuple(map(torch_to_jax, params))
         jax_params = jax.tree.unflatten(params_treedef, jax_params)

         needs_grad: tuple[bool, ...] = ctx.needs_input_grad  # type: ignore
-        x_needs_grad, y_needs_grad, _, _, _, *params_need_grad = needs_grad
-        # todo: broaden a bit:
-        assert not x_needs_grad
-        assert not y_needs_grad
-        if all(params_need_grad):
-            # We're going to need to do the backward pass, so do it right away and save the grads
-            # in the context.
-            (loss, logits), param_grads = loss_value_and_grad_fn(jax_params, jax_x, jax_y)
-            flattened_param_grads = jax.tree.leaves(param_grads)
-            torch_grads = tuple(map(jax_to_torch, flattened_param_grads))
-            ctx.save_for_backward(*torch_grads)
+        input_needs_grad, _, _, *params_need_grad = needs_grad
+        if any(params_need_grad) or input_needs_grad:
+            output, jvp_function = jax.vjp(jax_function, jax_params, jax_input)
+            ctx.jvp_function = jvp_function
         else:
-            assert not any(params_need_grad)
-            loss, logits = loss_fn(jax_params, jax_x, jax_y)
-        loss = jax_to_torch(loss)
-        logits = jax_to_torch(logits)
-        return loss, logits
+            # Forward pass without gradient computation.
+            output = jax_function(jax_params, jax_input)
+        output = jax_to_torch(output)
+        return output

     @staticmethod
     def backward(
         ctx: torch.autograd.function.NestedIOFunction,
-        grad_loss: torch.Tensor,
-        grad_logits: torch.Tensor,
+        grad_output: torch.Tensor,
     ):
-        x_needs_grad, y_needs_grad, _, _, _, *params_needs_grad = ctx.needs_input_grad
+        input_need_grad, _, _, *params_needs_grad = ctx.needs_input_grad
         # todo: broaden this a bit in case we need the grad of the input.
         # todo: Figure out how to do jax.grad for a function that outputs a matrix or vector.
-        assert not x_needs_grad
-        assert not y_needs_grad
-
+        assert not input_need_grad
         grad_input = None
-        grad_y = None
-        if all(params_needs_grad):
-            params_grads = ctx.saved_tensors
+        if input_need_grad or any(params_needs_grad):
+            assert all(params_needs_grad)
+            jvp_function = ctx.jvp_function
+            jax_grad_output = torch_to_jax(grad_output)
+            jax_grad_params, jax_input_grad = jvp_function(jax_grad_output)
+            params_grads = jax.tree.map(jax_to_torch, jax.tree.leaves(jax_grad_params))
+            assert is_sequence_of(params_grads, torch.Tensor)
+
+            if input_need_grad:
+                grad_input = jax_to_torch(jax_input_grad)
         else:
             assert not any(params_needs_grad)
             params_grads = tuple(None for _ in params_needs_grad)
-
-        return grad_input, grad_y, None, None, None, *params_grads
+        return grad_input, None, None, *params_grads


 class JaxAlgorithm(Algorithm):
@@ -180,40 +205,39 @@ def __init__(
     ):
         super().__init__(datamodule=datamodule, hp=hp or self.HParams())
         self.hp: JaxAlgorithm.HParams
+        torch.zeros(1, device="cuda")  # weird cuda errors!
         key = jax.random.key(self.hp.seed)
-        self.network = network
         x = jax.random.uniform(key, shape=(datamodule.batch_size, *datamodule.dims))
         x = to_channels_last(x)
-        params = self.network.init(key, x=x)
-        params_list, self.params_treedef = jax.tree.flatten(params)
+        jax_net = CNN()
+        params = jax_net.init(key, x=x)
         # Need to call .clone() when doing distributed training, otherwise we get a RuntimeError:
         # Invalid device pointer when trying to share the CUDA memory.
-        self.params = torch.nn.ParameterList(
-            map(operator.methodcaller("clone"), map(jax_to_torch, params_list))
-        )
+
+        self.network = JaxOperation(jax_function=jax_net.apply, jax_params_dict=params)
+
+        # self.params = torch.nn.ParameterList(
+        #     map(operator.methodcaller("clone"), map(jax_to_torch, params_list))
+        # )

         self.automatic_optimization = True

     def on_fit_start(self):
-        # Setting those here, because otherwise we get pickling errors when running with multiple
-        # GPUs.
-        def loss_function(
-            params: VariableDict,
-            x: jax.Array,
-            y: jax.Array,
-        ):
-            logits = self.network.apply(params, x)
-            assert isinstance(logits, jax.Array)
-            loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
-            assert isinstance(loss, jax.Array)
-            return loss, logits
-
-        self.forward_pass = loss_function
-        self.backward_pass = value_and_grad(self.forward_pass, argnums=0, has_aux=True)
-
-        if not self.hp.debug:
-            self.forward_pass = jit(self.forward_pass)
-            self.backward_pass = jit(self.backward_pass)
+        pass
+        # # Setting those here, because otherwise we get pickling errors when running with multiple
+        # # GPUs.
+        # def loss_function(
+        #     params: VariableDict,
+        #     x: jax.Array,
+        #     y: jax.Array,
+        # ):
+
+        # self.forward_pass = loss_function
+        # self.backward_pass = value_and_grad(self.forward_pass, argnums=0, has_aux=True)
+
+        # if not self.hp.debug:
+        #     self.forward_pass = jit(self.forward_pass)
+        #     self.backward_pass = jit(self.backward_pass)

     def shared_step(
         self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int, phase: PhaseStr
@@ -225,13 +249,10 @@ def shared_step(
         x = to_channels_last(x)

-        loss: torch.Tensor
-        logits: torch.Tensor
-        loss, logits = JaxFunction.apply(  # type: ignore
-            x, y, self.params_treedef, self.forward_pass, self.backward_pass, *self.parameters()
-        )
-
+        logits = self.network(x)
         assert isinstance(logits, torch.Tensor)
+        loss = torch.nn.functional.cross_entropy(logits, target=y).mean()
+        assert isinstance(loss, torch.Tensor)
         if phase == "train":
             assert loss.requires_grad
         self.log(f"{phase}/loss", loss, prog_bar=True, sync_dist=True)
@@ -245,7 +266,7 @@ def configure_optimizers(self):

 def main():
     trainer = Trainer(devices=1, accelerator="auto")
-    datamodule = MNISTDataModule(num_workers=4)
+    datamodule = MNISTDataModule(num_workers=4, batch_size=2)
     model = JaxAlgorithm(network=CNN(num_classes=datamodule.num_classes), datamodule=datamodule)
     trainer.fit(model, datamodule=datamodule)
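
Note (not part of the patch): the core pattern this diff introduces is a torch.autograd.Function whose forward stores the closure returned by jax.vjp and whose backward calls that closure on the incoming cotangent. The sketch below illustrates the same pattern in isolation; it is a minimal stand-alone example, not the project's code. WrappedJaxFunction, to_jax and to_torch are illustrative stand-ins for JaxFunction and the project's torch_to_jax/jax_to_torch helpers, and a toy scalar function replaces the flax CNN.

import jax
import jax.numpy as jnp
import numpy as np
import torch


def to_jax(t: torch.Tensor) -> jax.Array:
    # CPU-only stand-in for the project's torch_to_jax helper.
    return jnp.asarray(t.detach().cpu().numpy())


def to_torch(a: jax.Array) -> torch.Tensor:
    # CPU-only stand-in for the project's jax_to_torch helper.
    return torch.from_numpy(np.array(a))


class WrappedJaxFunction(torch.autograd.Function):
    """Toy version of the JaxFunction pattern: forward stores the vjp closure."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, jax_function):
        output, vjp_fn = jax.vjp(jax_function, to_jax(input))
        ctx.vjp_fn = vjp_fn
        return to_torch(output)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # vjp_fn returns one cotangent per primal argument (here: just the input).
        (grad_input,) = ctx.vjp_fn(to_jax(grad_output))
        # One return value per forward input: (input, jax_function).
        return to_torch(grad_input), None


x = torch.randn(3, requires_grad=True)
loss = WrappedJaxFunction.apply(x, lambda v: jnp.sum(v**2))
loss.backward()
print(torch.allclose(x.grad, 2 * x.detach()))  # True: d/dx sum(x**2) = 2x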