diff --git a/project/algorithms/jax_algo.py b/project/algorithms/jax_algo.py
index 3f5ee954..33db020a 100644
--- a/project/algorithms/jax_algo.py
+++ b/project/algorithms/jax_algo.py
@@ -1,10 +1,12 @@
 import dataclasses
+import operator
 from collections.abc import Callable
 from typing import Concatenate, Literal, NamedTuple
 
 import jax
 import numpy as np
 import torch
+import torch.distributed
 from lightning import Trainer
 from torch_jax_interop import jax_to_torch, torch_to_jax
 
@@ -35,11 +37,11 @@ def loss_fn(
     logits: jax.Array,
     labels: jax.Array,
 ) -> jax.Array:
-    probs = jax.nn.log_softmax(logits)
-    assert isinstance(probs, jax.Array)
+    log_probs = jax.nn.log_softmax(logits, axis=-1)
+    assert isinstance(log_probs, jax.Array)
     one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
     assert isinstance(one_hot_labels, jax.Array)
-    loss = -(one_hot_labels * probs).sum(axis=-1).mean()
+    loss = -(one_hot_labels * log_probs).sum(axis=-1).mean()
     return loss
 
 
@@ -95,13 +97,26 @@ def __init__(
         self.hp: JaxAlgorithm.HParams
         key = jax.random.key(self.hp.seed)
         # todo: Extract out the "network" portion, and probably use something like flax for it.
-        self.jax_params = ParamsTuple(
+        params = ParamsTuple(
            w1=jax.random.uniform(key=jax.random.fold_in(key, 1), shape=(input_dims, 128)),
            b1=jax.random.uniform(key=jax.random.fold_in(key, 2), shape=(128,)),
            w2=jax.random.uniform(key=jax.random.fold_in(key, 3), shape=(128, output_dims)),
            b2=jax.random.uniform(key=jax.random.fold_in(key, 4), shape=(output_dims,)),
        )
+        parameters, self.params_treedef = jax.tree.flatten(params)
+        # Need to call .clone() when doing distributed training, otherwise we get a RuntimeError:
+        # Invalid device pointer when trying to share the CUDA memory.
+        self.params = torch.nn.ParameterList(
+            map(operator.methodcaller("clone"), map(jax_to_torch, parameters))
+        )
+
+        # We will do the backward pass ourselves, and PL will only be used to synchronize stuff
+        # between workers, do logging, etc.
+        self.automatic_optimization = False
 
+    def on_fit_start(self):
+        # Setting those here, because otherwise we get pickling errors when running with multiple
+        # GPUs.
        self.forward_pass = forward_pass
        self.backward_pass = value_and_grad(self.forward_pass)
 
@@ -109,47 +124,42 @@ def __init__(
            self.forward_pass = jit(self.forward_pass)
            self.backward_pass = jit(self.backward_pass)
 
-        # We will do the backward pass ourselves, and PL will synchronize stuff between workers, etc.
-        self.automatic_optimization = False
-
-    @property
-    def jax_params(self):
+    def jax_params(self) -> ParamsTuple[jax.Array]:
         # View the torch parameters as jax Arrays
-        return ParamsTuple(**{k: torch_to_jax(p.data) for k, p in self.named_parameters()})
-
-    @jax_params.setter
-    def jax_params(self, value: ParamsTuple[jax.Array]):
-        for k, jax_v in value._asdict().items():
-            assert isinstance(jax_v, jax.Array)
-            torch_v = jax_to_torch(jax_v)
-            p: torch.nn.Parameter = torch.nn.Parameter(torch_v, requires_grad=True)
-            self.register_parameter(k, p)
+        jax_parameters = jax.tree.map(torch_to_jax, list(self.params))
+        # Reconstruct the original object structure.
+        jax_params_tuple = jax.tree.unflatten(self.params_treedef, jax_parameters)
+        return jax_params_tuple
 
     def shared_step(
         self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int, phase: PhaseStr
     ):
-        torch_x, torch_y = batch
+        x, y = batch
         # Note: flattening the input also gets rid of the stride issues in jax.from_dlpack.
-        torch_x = torch_x.flatten(start_dim=1)
+        x = x.flatten(start_dim=1)
         # View/"convert" the torch inputs to jax Arrays.
-        jax_x, jax_y = torch_to_jax(torch_x), torch_to_jax(torch_y)
+        x, y = torch_to_jax(x), torch_to_jax(y)
 
-        jax_params = self.jax_params
+        jax_params = self.jax_params()
 
         if phase != "train":
            # Only use the forward pass.
-            loss, logits = self.forward_pass(jax_params, jax_x, jax_y)
+            loss, logits = self.forward_pass(jax_params, x, y)
         else:
            optimizer = self.optimizers()
            assert isinstance(optimizer, torch.optim.Optimizer)
 
            # Perform the backward pass
-            (loss, logits), jax_grads = self.backward_pass(jax_params, jax_x, jax_y)
+            (loss, logits), jax_grads = self.backward_pass(jax_params, x, y)
+            distributed = torch.distributed.is_initialized()
+
            with torch.no_grad():
                # 'convert' the gradients to pytorch
-                torch_grads = map(jax_to_torch, jax_grads)
+                torch_grads = jax.tree.map(jax_to_torch, jax_grads)
                # Update the torch parameters tensors in-place using the jax grads.
-                for param, grad in zip(self.parameters(), torch_grads):
+                for param, grad in zip(self.parameters(), jax.tree.leaves(torch_grads)):
+                    if distributed:
+                        torch.distributed.all_reduce(grad, op=torch.distributed.ReduceOp.AVG)
                    if param.grad is None:
                        param.grad = grad
                    else:
@@ -157,10 +167,13 @@ def shared_step(
            optimizer.step()
            optimizer.zero_grad()
 
+        # IDEA: What about a hacky .backward method on a torch Tensor, that calls the backward pass
+        # and sets the grads? Could we then use automatic optimization?
         torch_loss = jax_to_torch(loss)
+        torch_y = batch[1]
         accuracy = jax_to_torch(logits).argmax(-1).eq(torch_y).float().mean()
-        self.log(f"{phase}/accuracy", accuracy, prog_bar=True)
-        self.log(f"{phase}/loss", torch_loss, prog_bar=True)
+        self.log(f"{phase}/accuracy", accuracy, prog_bar=True, sync_dist=True)
+        self.log(f"{phase}/loss", torch_loss, prog_bar=True, sync_dist=True)
         return torch_loss
 
     def configure_optimizers(self):
@@ -168,8 +181,8 @@ def main():
-    trainer = Trainer(devices=1, accelerator="auto")
-    datamodule = MNISTDataModule()
+    trainer = Trainer(devices="auto", accelerator="auto")
+    datamodule = MNISTDataModule(num_workers=4)
     model = JaxAlgorithm(datamodule=datamodule)
     trainer.fit(model, datamodule=datamodule)
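
Note (not part of the patch above): the core pattern in this diff is to flatten the JAX parameter pytree, view each leaf as a torch tensor with `torch_jax_interop.jax_to_torch`, keep clones in a `torch.nn.ParameterList`, and rebuild the pytree with `jax.tree.unflatten` whenever the jitted forward/backward pass needs it. The sketch below illustrates that round-trip in isolation; `make_params` and the dict-based pytree are hypothetical stand-ins for the repository's `ParamsTuple`, and wrapping each leaf explicitly in `torch.nn.Parameter` is a choice made here for clarity rather than something the patch itself does.

```python
# Illustrative sketch of the jax-params-as-torch-Parameters round-trip (see caveats above).
import jax
import torch
from torch_jax_interop import jax_to_torch, torch_to_jax


def make_params(key: jax.Array) -> dict[str, jax.Array]:
    # Hypothetical stand-in for the repo's ParamsTuple: any pytree of jax Arrays works.
    return {
        "w1": jax.random.uniform(jax.random.fold_in(key, 1), (784, 128)),
        "b1": jax.random.uniform(jax.random.fold_in(key, 2), (128,)),
        "w2": jax.random.uniform(jax.random.fold_in(key, 3), (128, 10)),
        "b2": jax.random.uniform(jax.random.fold_in(key, 4), (10,)),
    }


params = make_params(jax.random.key(0))

# Flatten the pytree so each leaf can be stored as a torch parameter.
leaves, treedef = jax.tree.flatten(params)
# .clone() gives torch-owned storage (the diff does the same before sharing across workers).
torch_params = torch.nn.ParameterList(
    torch.nn.Parameter(jax_to_torch(leaf).clone()) for leaf in leaves
)

# Later: view the torch parameters as jax Arrays again and rebuild the original structure.
jax_leaves = [torch_to_jax(p.data) for p in torch_params]
jax_params = jax.tree.unflatten(treedef, jax_leaves)
```

Keeping the treedef around is what lets the module hand a properly structured parameter object back to the JAX functions while Lightning only ever sees a flat list of torch parameters to synchronize and optimize.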