Skip to content

Commit

Permalink
[nnx] jit cache
Browse files: browse the repository at this point in the history
  • Loading branch information
cgarciae committed Dec 19, 2024
1 parent fc38f21 commit 1becd1d
Show file tree
Hide file tree
Showing 8 changed files with 406 additions and 88 deletions.
63 changes: 43 additions & 20 deletions benchmarks/nnx_simple_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
from absl import app

FLAGS = flags.FLAGS
flags.DEFINE_enum('mode', 'nnx', ['nnx', 'jax'], 'Mode to run the script in')
flags.DEFINE_enum(
'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
)
flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
flags.DEFINE_integer('batch_size', 32, 'Batch size')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
Expand All @@ -46,6 +48,13 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
def __call__(self, x):
return x @ self.w + self.b

class Block(nnx.Module):
  """Linear layer followed by batch normalization and a ReLU activation."""

  def __init__(self, din, dhidden, *, rngs: nnx.Rngs):
    # Sublayers are built in this order (linear, then batch norm) so that
    # `rngs` is consumed in the same sequence on every construction.
    self.linear = Linear(din, dhidden, rngs=rngs)
    self.bn = nnx.BatchNorm(dhidden, rngs=rngs)

  def __call__(self, x):
    """Apply linear -> batch norm -> ReLU to `x` and return the result."""
    projected = self.linear(x)
    normalized = self.bn(projected)
    return nnx.relu(normalized)

class Count(nnx.Variable):
  """Variable subclass with no added behavior; MLP stores its call counter in one."""
Expand All @@ -54,11 +63,11 @@ class Count(nnx.Variable):
class MLP(nnx.Module):
def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
self.count = Count(jnp.array(0))
self.linear_in = Linear(din, dhidden, rngs=rngs)
self.linear_in = Block(din, dhidden, rngs=rngs)
self.intermediates = [
Linear(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
]
self.linear_out = Linear(dhidden, dout, rngs=rngs)
self.linear_out = Block(dhidden, dout, rngs=rngs)

def __call__(self, x):
self.count.value += 1
Expand All @@ -79,18 +88,14 @@ def main(argv):

print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}')

if mode not in ['nnx', 'jax']:
raise ValueError(f'Invalid mode: {mode}')

X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)

model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

if mode == 'nnx':
if mode == 'nnx' or mode == 'all':
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

@nnx.jit
def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch):
Expand All @@ -110,16 +115,30 @@ def test_step_nnx(model: MLP, batch):
loss = jnp.mean((y - y_pred) ** 2)
return {'loss': loss}

print('### NNX ###')
for step, batch in enumerate(dataset(X, Y, batch_size)):
train_step_nnx(model, optimizer, batch)

if step % 1000 == 0:
logs = test_step_nnx(model, (X, Y))
print(f"step: {step}, loss: {logs['loss']}")

if step >= total_steps - 1:
break
else:

total_time = time() - t0
time_per_step = total_time / total_steps
time_per_layer = time_per_step / depth

print(f"step: {step}, loss: {logs['loss']}")
print('total time:', total_time)
print(f'time per step: {time_per_step * 1e6:.2f} µs')
print(f'time per layer: {time_per_layer * 1e6:.2f} µs')

if mode == 'jax' or mode == 'all':
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)
t0 = time()

@jax.jit
def train_step_jax(graphdef, state, batch):
Expand All @@ -146,22 +165,26 @@ def test_step_jax(graphdef, state, batch):

graphdef, state = nnx.split((model, optimizer))

print('### JAX ###')
for step, batch in enumerate(dataset(X, Y, batch_size)):
state = train_step_jax(graphdef, state, batch)

if step % 1000 == 0:
state, logs = test_step_jax(graphdef, state, (X, Y))
print(f"step: {step}, loss: {logs['loss']}")

if step >= total_steps - 1:
break

model, optimizer = nnx.merge(graphdef, state)

total_time = time() - t0
print('total time:', total_time)
print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')
print('times called:', model.count.value)
total_time = time() - t0
time_per_step = total_time / total_steps
time_per_layer = time_per_step / depth

print(f"step: {step}, loss: {logs['loss']}")
print('total time:', total_time)
print(f'time per step: {time_per_step * 1e6:.2f} µs')
print(f'time per layer: {time_per_layer * 1e6:.2f} µs')


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit 1becd1d

Please sign in to comment.