Tweak algo a bit (again)
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jun 4, 2024
1 parent 1a71d48 commit 4deb156
Showing 1 changed file with 43 additions and 30 deletions.
73 changes: 43 additions & 30 deletions project/algorithms/jax_algo.py
@@ -1,10 +1,12 @@
 import dataclasses
+import operator
 from collections.abc import Callable
 from typing import Concatenate, Literal, NamedTuple

 import jax
 import numpy as np
 import torch
+import torch.distributed
 from lightning import Trainer
 from torch_jax_interop import jax_to_torch, torch_to_jax

@@ -35,11 +37,11 @@ def loss_fn(
     logits: jax.Array,
     labels: jax.Array,
 ) -> jax.Array:
-    probs = jax.nn.log_softmax(logits)
-    assert isinstance(probs, jax.Array)
+    log_probs = jax.nn.log_softmax(logits, axis=-1)
+    assert isinstance(log_probs, jax.Array)
     one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
     assert isinstance(one_hot_labels, jax.Array)
-    loss = -(one_hot_labels * probs).sum(axis=-1).mean()
+    loss = -(one_hot_labels * log_probs).sum(axis=-1).mean()
     return loss


@@ -95,81 +97,92 @@ def __init__(
         self.hp: JaxAlgorithm.HParams
         key = jax.random.key(self.hp.seed)
         # todo: Extract out the "network" portion, and probably use something like flax for it.
-        self.jax_params = ParamsTuple(
+        params = ParamsTuple(
             w1=jax.random.uniform(key=jax.random.fold_in(key, 1), shape=(input_dims, 128)),
             b1=jax.random.uniform(key=jax.random.fold_in(key, 2), shape=(128,)),
             w2=jax.random.uniform(key=jax.random.fold_in(key, 3), shape=(128, output_dims)),
             b2=jax.random.uniform(key=jax.random.fold_in(key, 4), shape=(output_dims,)),
         )
+        parameters, self.params_treedef = jax.tree.flatten(params)
+        # Need to call .clone() when doing distributed training, otherwise we get a RuntimeError:
+        # Invalid device pointer when trying to share the CUDA memory.
+        self.params = torch.nn.ParameterList(
+            map(operator.methodcaller("clone"), map(jax_to_torch, parameters))
+        )

-        # We will do the backward pass ourselves, and PL will only be used to synchronize stuff
-        # between workers, do logging, etc.
-        self.automatic_optimization = False

     def on_fit_start(self):
         # Setting those here, because otherwise we get pickling errors when running with multiple
         # GPUs.
         self.forward_pass = forward_pass
         self.backward_pass = value_and_grad(self.forward_pass)

         if not self.hp.debug:
             self.forward_pass = jit(self.forward_pass)
             self.backward_pass = jit(self.backward_pass)

+        # We will do the backward pass ourselves, and PL will synchronize stuff between workers, etc.
+        self.automatic_optimization = False

-    @property
-    def jax_params(self):
+    def jax_params(self) -> ParamsTuple[jax.Array]:
         # View the torch parameters as jax Arrays
-        return ParamsTuple(**{k: torch_to_jax(p.data) for k, p in self.named_parameters()})
-
-    @jax_params.setter
-    def jax_params(self, value: ParamsTuple[jax.Array]):
-        for k, jax_v in value._asdict().items():
-            assert isinstance(jax_v, jax.Array)
-            torch_v = jax_to_torch(jax_v)
-            p: torch.nn.Parameter = torch.nn.Parameter(torch_v, requires_grad=True)
-            self.register_parameter(k, p)
+        jax_parameters = jax.tree.map(torch_to_jax, list(self.params))
+        # Reconstruct the original object structure.
+        jax_params_tuple = jax.tree.unflatten(self.params_treedef, jax_parameters)
+        return jax_params_tuple

     def shared_step(
         self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int, phase: PhaseStr
     ):
-        torch_x, torch_y = batch
+        x, y = batch
         # Note: flattening the input also gets rid of the stride issues in jax.from_dlpack.
-        torch_x = torch_x.flatten(start_dim=1)
+        x = x.flatten(start_dim=1)
         # View/"convert" the torch inputs to jax Arrays.
-        jax_x, jax_y = torch_to_jax(torch_x), torch_to_jax(torch_y)
+        x, y = torch_to_jax(x), torch_to_jax(y)

-        jax_params = self.jax_params
+        jax_params = self.jax_params()

         if phase != "train":
             # Only use the forward pass.
-            loss, logits = self.forward_pass(jax_params, jax_x, jax_y)
+            loss, logits = self.forward_pass(jax_params, x, y)
         else:
             optimizer = self.optimizers()
             assert isinstance(optimizer, torch.optim.Optimizer)

             # Perform the backward pass
-            (loss, logits), jax_grads = self.backward_pass(jax_params, jax_x, jax_y)
+            (loss, logits), jax_grads = self.backward_pass(jax_params, x, y)
+            distributed = torch.distributed.is_initialized()

             with torch.no_grad():
                 # 'convert' the gradients to pytorch
-                torch_grads = map(jax_to_torch, jax_grads)
+                torch_grads = jax.tree.map(jax_to_torch, jax_grads)
                 # Update the torch parameters tensors in-place using the jax grads.
-                for param, grad in zip(self.parameters(), torch_grads):
+                for param, grad in zip(self.parameters(), jax.tree.leaves(torch_grads)):
+                    if distributed:
+                        torch.distributed.all_reduce(grad, op=torch.distributed.ReduceOp.AVG)
                     if param.grad is None:
                         param.grad = grad
                     else:
                         param.grad += grad
             optimizer.step()
             optimizer.zero_grad()

+        # IDEA: What about a hacky .backward method on a torch Tensor, that calls the backward pass
+        # and sets the grads? Could we then use automatic optimization?
         torch_loss = jax_to_torch(loss)
+        torch_y = batch[1]
         accuracy = jax_to_torch(logits).argmax(-1).eq(torch_y).float().mean()
-        self.log(f"{phase}/accuracy", accuracy, prog_bar=True)
-        self.log(f"{phase}/loss", torch_loss, prog_bar=True)
+        self.log(f"{phase}/accuracy", accuracy, prog_bar=True, sync_dist=True)
+        self.log(f"{phase}/loss", torch_loss, prog_bar=True, sync_dist=True)
         return torch_loss

     def configure_optimizers(self):
         return torch.optim.SGD(self.parameters(), lr=self.hp.lr)


 def main():
-    trainer = Trainer(devices=1, accelerator="auto")
-    datamodule = MNISTDataModule()
+    trainer = Trainer(devices="auto", accelerator="auto")
+    datamodule = MNISTDataModule(num_workers=4)
     model = JaxAlgorithm(datamodule=datamodule)
     trainer.fit(model, datamodule=datamodule)
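
For readers unfamiliar with the pattern, here is a minimal sketch (not part of the commit; it assumes torch_jax_interop is installed and uses a trimmed-down two-field ParamsTuple) of the flatten/unflatten round-trip that the new __init__ and jax_params rely on: the NamedTuple of jax arrays is a pytree, so its leaves can be stored as torch parameters and the original structure rebuilt on demand.

    from typing import NamedTuple

    import jax
    import torch
    from torch_jax_interop import jax_to_torch, torch_to_jax


    class ParamsTuple(NamedTuple):
        w1: jax.Array
        b1: jax.Array


    key = jax.random.key(0)
    params = ParamsTuple(
        w1=jax.random.uniform(jax.random.fold_in(key, 1), shape=(4, 8)),
        b1=jax.random.uniform(jax.random.fold_in(key, 2), shape=(8,)),
    )

    # Flatten the pytree into a list of leaf arrays plus a treedef describing the structure.
    leaves, treedef = jax.tree.flatten(params)
    # View each jax array as a torch tensor and collect them as trainable parameters.
    torch_params = torch.nn.ParameterList(torch.nn.Parameter(jax_to_torch(leaf)) for leaf in leaves)

    # Rebuild the original ParamsTuple by viewing the torch parameters as jax arrays again.
    rebuilt = jax.tree.unflatten(treedef, [torch_to_jax(p.data) for p in torch_params])
    assert isinstance(rebuilt, ParamsTuple)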

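The on_fit_start hook wraps the forward pass with the project's value_and_grad and jit helpers, whose definitions are outside this diff. A rough sketch of the underlying jax pattern, assuming those helpers behave like jax.value_and_grad with has_aux=True and jax.jit:

    import jax
    import jax.numpy as jnp


    def forward(w, x, y):
        logits = x @ w
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        loss = -jnp.take_along_axis(log_probs, y[:, None], axis=-1).mean()
        # Return the loss plus auxiliary outputs (here, the logits).
        return loss, logits


    # has_aux=True tells jax that forward returns (loss, aux); gradients are taken w.r.t. the loss only.
    backward = jax.jit(jax.value_and_grad(forward, has_aux=True))

    w = jnp.zeros((4, 3))
    x = jnp.ones((2, 4))
    y = jnp.array([0, 2])
    (loss, logits), grads = backward(w, x, y)
    print(loss.shape, logits.shape, grads.shape)  # () (2, 3) (4, 3)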

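
Finally, a toy sketch of the manual optimization performed in shared_step (again assuming torch_jax_interop; the real code iterates over the ParameterList and averages gradients across workers when distributed): a gradient computed by jax is viewed as a torch tensor, accumulated into param.grad, and a regular torch optimizer applies the update.

    import jax
    import torch
    from torch_jax_interop import jax_to_torch, torch_to_jax

    # A single torch parameter standing in for the model's ParameterList.
    param = torch.nn.Parameter(torch.zeros(4, 3))
    optimizer = torch.optim.SGD([param], lr=0.1)


    def toy_loss(w: jax.Array) -> jax.Array:
        return (w**2).sum()


    # Compute the gradient in jax, using the torch parameter viewed as a jax array.
    jax_grad = jax.grad(toy_loss)(torch_to_jax(param.data))

    with torch.no_grad():
        grad = jax_to_torch(jax_grad)
        # Accumulate into .grad, mirroring how autograd accumulates gradients, then step the optimizer.
        param.grad = grad if param.grad is None else param.grad + grad
    optimizer.step()
    optimizer.zero_grad()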