diff --git a/pdm.lock b/pdm.lock
index b8891c73..751af386 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -1188,7 +1188,7 @@ files = [
 
 [[package]]
 name = "jaxlib"
-version = "0.4.28+cuda12.cudnn89"
+version = "0.4.28"
 requires_python = ">=3.9"
 summary = "XLA library for JAX"
 groups = ["default"]
@@ -2823,7 +2823,7 @@ files = [
 
 [[package]]
 name = "torch-jax-interop"
-version = "0.0.2"
+version = "0.0.3"
 requires_python = "<4.0,>=3.11"
 summary = "Utility to convert Tensors from Jax to Torch and vice-versa"
 groups = ["default"]
@@ -2834,8 +2834,8 @@ dependencies = [
     "torch<3.0.0,>=2.3.0",
 ]
 files = [
-    {file = "torch_jax_interop-0.0.2-py3-none-any.whl", hash = "sha256:a155c3dd3b00017040755b77e7ff87599fa1d909d8be0f6a9854ed57e94a04d7"},
-    {file = "torch_jax_interop-0.0.2.tar.gz", hash = "sha256:cb1e8d8e8195652f244031cb21391483bfa6d790931f257b0566dc522c1b0e75"},
+    {file = "torch_jax_interop-0.0.3-py3-none-any.whl", hash = "sha256:fcae8b45304ef4f8d8cd6b40f078c05147d8bafec65dec3a79cee3595aa8dc4e"},
+    {file = "torch_jax_interop-0.0.3.tar.gz", hash = "sha256:44fdc37fe89be32de85f648e0604056dd1af46bb869aab4ddcb2267198c6cf86"},
 ]
 
 [[package]]
diff --git a/project/algorithms/jax_algo.py b/project/algorithms/jax_algo.py
index 15329924..d5085ffc 100644
--- a/project/algorithms/jax_algo.py
+++ b/project/algorithms/jax_algo.py
@@ -1,26 +1,33 @@
 import dataclasses
 from collections.abc import Callable
+from typing import Concatenate, Literal, NamedTuple
 
 import jax
+import numpy as np
 import torch
 from lightning import Trainer
 from torch_jax_interop import jax_to_torch, torch_to_jax
 
 from project.algorithms.bases.algorithm import Algorithm
+from project.datamodules.image_classification.base import ImageClassificationDataModule
 from project.datamodules.image_classification.mnist import MNISTDataModule
 from project.utils.types import PhaseStr
-from project.utils.types.protocols import DataModule
 
-type ParamsTuple = tuple[jax.Array, ...]
+# type ParamsTuple = tuple[jax.Array, ...]
 
 
-def fcnet(
-    input: jax.Array, w1: jax.Array, b1: jax.Array, w2: jax.Array, b2: jax.Array
-) -> jax.Array:
+class ParamsTuple[T: torch.Tensor | jax.Array](NamedTuple):
+    w1: T
+    b1: T
+    w2: T
+    b2: T
+
+
+def fcnet(input: jax.Array, params: ParamsTuple) -> jax.Array:
     """Forward pass of a simple two-layer fully-connected neural network with relu activation."""
-    z1 = jax.numpy.matmul(input, w1) + b1
+    z1 = jax.numpy.matmul(input, params.w1) + params.b1
     a1 = jax.nn.relu(z1)
-    logits = jax.numpy.matmul(a1, w2) + b2
+    logits = jax.numpy.matmul(a1, params.w2) + params.b2
     return logits
 
 
@@ -29,60 +36,80 @@ def loss_fn(
     labels: jax.Array,
 ) -> jax.Array:
     probs = jax.nn.log_softmax(logits)
+    assert isinstance(probs, jax.Array)
     one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
     assert isinstance(one_hot_labels, jax.Array)
-    assert isinstance(probs, jax.Array)
     loss = -(one_hot_labels * probs).sum(axis=-1).mean()
     return loss
 
 
-def forward_pass(params: ParamsTuple, x: jax.Array, y: jax.Array) -> tuple[jax.Array, jax.Array]:
-    logits = fcnet(x, *params)
-    return loss_fn(logits, y), logits
+def forward_pass(
+    params: ParamsTuple[jax.Array], x: jax.Array, y: jax.Array
+) -> tuple[jax.Array, jax.Array]:
+    logits = fcnet(x, params)
+    loss = loss_fn(logits, y)
+    return loss, logits
+
+def jit[**P, Out](
+    fn: Callable[P, Out],
+) -> Callable[P, Out]:
+    """Small type hint fix for jax's `jit` (preserves the signature of the callable)."""
+    return jax.jit(fn)  # type: ignore
 
 
-backward_pass: Callable[
-    [ParamsTuple, jax.Array, jax.Array], tuple[tuple[jax.Array, jax.Array], ParamsTuple]
-] = jax.value_and_grad(forward_pass, argnums=0, has_aux=True)
+
+def value_and_grad[In, **P, Out, Aux](
+    fn: Callable[Concatenate[In, P], tuple[Out, Aux]],
+    argnums: Literal[0] = 0,
+    has_aux: Literal[True] = True,
+) -> Callable[Concatenate[In, P], tuple[tuple[Out, Aux], In]]:
+    """Small type hint fix for jax's `value_and_grad` (preserves the signature of the callable)."""
+    return jax.value_and_grad(fn, argnums=argnums, has_aux=has_aux)  # type: ignore
+
+
+# Register a handler for "converting" nn.Parameters to jax Arrays: they can be viewed as jax Arrays
+# by just viewing their data as a jax array.
+@torch_to_jax.register(torch.nn.Parameter)
+def _parameter_to_jax_array(value: torch.nn.Parameter) -> jax.Array:
+    return torch_to_jax(value.data)
 
 
 class JaxAlgorithm(Algorithm):
     """Example of an algorithm where the forward / backward passes are written in Jax."""
 
-    @dataclasses.dataclass()
+    @dataclasses.dataclass
     class HParams(Algorithm.HParams):
         lr: float = 1e-3
         seed: int = 123
-        debug: bool = True
+        debug: bool = False
 
     def __init__(
         self,
         *,
-        datamodule: DataModule | None = None,
+        datamodule: ImageClassificationDataModule,
         hp: HParams | None = None,
     ):
         super().__init__(datamodule=datamodule, hp=hp or self.HParams())
+        input_dims = int(np.prod(datamodule.dims))
+        output_dims = datamodule.num_classes
         self.hp: JaxAlgorithm.HParams
         key = jax.random.key(self.hp.seed)
-        self.w1 = torch.nn.Parameter(
-            jax_to_torch(jax.random.uniform(key=jax.random.fold_in(key, 1), shape=(784, 128))),
-            requires_grad=True,
+        # todo: Extract out the "network" portion, and probably use something like flax for it.
+        params = ParamsTuple(
+            w1=jax.random.uniform(key=jax.random.fold_in(key, 1), shape=(input_dims, 128)),
+            b1=jax.random.uniform(key=jax.random.fold_in(key, 2), shape=(128,)),
+            w2=jax.random.uniform(key=jax.random.fold_in(key, 3), shape=(128, output_dims)),
+            b2=jax.random.uniform(key=jax.random.fold_in(key, 4), shape=(output_dims,)),
         )
-        self.b1 = torch.nn.Parameter(
-            jax_to_torch(jax.random.uniform(key=jax.random.fold_in(key, 2), shape=(128,))),
-            requires_grad=True,
-        )
-        self.w2 = torch.nn.Parameter(
-            jax_to_torch(jax.random.uniform(key=jax.random.fold_in(key, 3), shape=(128, 10))),
-            requires_grad=True,
-        )
-        self.b2 = torch.nn.Parameter(
-            jax_to_torch(jax.random.uniform(key=jax.random.fold_in(key, 4), shape=(10,))),
-            requires_grad=True,
+        self.params = torch.nn.ParameterList(
+            [torch.nn.Parameter(v, requires_grad=True) for v in map(jax_to_torch, params)]
         )
+        self.forward_pass = forward_pass
+        self.backward_pass = value_and_grad(self.forward_pass)
 
-        self.forward_pass = jax.jit(forward_pass) if not self.hp.debug else forward_pass
-        self.backward_pass = jax.jit(backward_pass) if not self.hp.debug else backward_pass
+        if not self.hp.debug:
+            self.forward_pass = jit(self.forward_pass)
+            self.backward_pass = jit(self.backward_pass)
 
         # We will do the backward pass ourselves, and PL will synchronize stuff between workers, etc.
         self.automatic_optimization = False
@@ -91,14 +118,13 @@ def shared_step(
         self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int, phase: PhaseStr
     ):
         torch_x, torch_y = batch
-        # note: Also gets rid of the stride issues. in jax.from_dlpack.
+        # Note: flattening the input also gets rid of the stride issues in jax.from_dlpack.
        torch_x = torch_x.flatten(start_dim=1)
-        jax_x, jax_y = jax.tree.map(torch_to_jax, [torch_x, torch_y])
-        assert isinstance(jax_x, jax.Array)
-        assert isinstance(jax_y, jax.Array)
+        # View/"convert" the torch inputs to jax Arrays.
+        jax_x, jax_y = torch_to_jax(torch_x), torch_to_jax(torch_y)
 
-        torch_params = tuple(p.data for p in self.parameters())
-        jax_params: ParamsTuple = jax.tree.map(torch_to_jax, torch_params)
+        # View the parameters as jax Arrays
+        jax_params = ParamsTuple(*map(torch_to_jax, self.parameters()))
 
         if phase != "train":
             # Only use the forward pass.
@@ -109,9 +135,8 @@
 
         # Perform the backward pass
         (loss, logits), jax_grads = self.backward_pass(jax_params, jax_x, jax_y)
-        torch_grads = jax.tree.map(jax_to_torch, jax_grads)
-
         with torch.no_grad():
+            torch_grads = map(jax_to_torch, jax_grads)
             for param, grad in zip(self.parameters(), torch_grads):
                 if param.grad is None:
                     param.grad = grad
@@ -120,9 +145,8 @@
             optimizer.step()
             optimizer.zero_grad()
 
-        torch_logits = jax_to_torch(logits)
         torch_loss = jax_to_torch(loss)
-        accuracy = torch_logits.argmax(-1).eq(torch_y).float().mean()
+        accuracy = jax_to_torch(logits).argmax(-1).eq(torch_y).float().mean()
         self.log(f"{phase}/accuracy", accuracy, prog_bar=True)
         self.log(f"{phase}/loss", torch_loss, prog_bar=True)
         return torch_loss
@@ -133,8 +157,9 @@ def configure_optimizers(self):
 
 def main():
     trainer = Trainer(devices=1, accelerator="auto")
-    model = JaxAlgorithm()
-    trainer.fit(model, datamodule=MNISTDataModule())
+    datamodule = MNISTDataModule()
+    model = JaxAlgorithm(datamodule=datamodule)
+    trainer.fit(model, datamodule=datamodule)
 
     ...
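# --- Example (not part of the diff): the torch <-> jax round trip used above ---
# A minimal standalone sketch of what `JaxAlgorithm.__init__` and `shared_step` do,
# reusing `ParamsTuple`, `forward_pass` and `value_and_grad` from jax_algo.py. The
# batch size, shapes and random data below are illustrative assumptions, not repo code.
import torch
from torch_jax_interop import jax_to_torch, torch_to_jax

from project.algorithms.jax_algo import ParamsTuple, forward_pass, value_and_grad

# Toy torch parameters with the same layout as the network above (784 -> 128 -> 10).
shapes = [(784, 128), (128,), (128, 10), (10,)]
torch_params = [torch.nn.Parameter(torch.randn(shape)) for shape in shapes]

# View the parameters as jax Arrays; `torch_to_jax` accepts nn.Parameter directly
# thanks to the handler registered in the diff above.
jax_params = ParamsTuple(*map(torch_to_jax, torch_params))

# A toy flattened batch, also viewed as jax Arrays.
x = torch_to_jax(torch.randn(32, 784))
y = torch_to_jax(torch.randint(0, 10, (32,), dtype=torch.int32))

# Jax computes the loss, the logits and the parameter gradients...
(loss, logits), jax_grads = value_and_grad(forward_pass)(jax_params, x, y)

# ...and the gradients are viewed back as torch tensors and attached to the
# parameters, which is what the manual optimizer step in `shared_step` consumes.
for param, grad in zip(torch_params, map(jax_to_torch, jax_grads)):
    param.grad = grad
print(jax_to_torch(loss))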
diff --git a/project/datamodules/vision/base.py b/project/datamodules/vision/base.py
index 14a43adc..88b85fcb 100644
--- a/project/datamodules/vision/base.py
+++ b/project/datamodules/vision/base.py
@@ -21,7 +21,11 @@
 P = ParamSpec("P")
 
 SLURM_TMPDIR: Path | None = (
-    Path(os.environ["SLURM_TMPDIR"]) if "SLURM_TMPDIR" in os.environ else None
+    Path(os.environ["SLURM_TMPDIR"])
+    if "SLURM_TMPDIR" in os.environ
+    else tmp
+    if "SLURM_JOB_ID" in os.environ and (tmp := Path("/tmp")).exists()
+    else None
 )
 logger = get_logger(__name__)
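# --- Note (not part of the diff): what the new SLURM_TMPDIR expression does ---
# The chained conditional above is equivalent to this more explicit helper; the
# function name `_infer_slurm_tmpdir` is made up here for illustration only.
import os
from pathlib import Path


def _infer_slurm_tmpdir() -> Path | None:
    if "SLURM_TMPDIR" in os.environ:
        # Usual case: the scheduler exports the job's temporary directory.
        return Path(os.environ["SLURM_TMPDIR"])
    if "SLURM_JOB_ID" in os.environ and Path("/tmp").exists():
        # Inside a SLURM job that doesn't export SLURM_TMPDIR: fall back to /tmp.
        return Path("/tmp")
    # Not running inside a SLURM job at all.
    return None


SLURM_TMPDIR: Path | None = _infer_slurm_tmpdir()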