nnx.jit(aux_fn) is slower than directly using nnx.jit(model.__call__) #4218

Open

JunhongXu opened this issue Sep 23, 2024 · 3 comments

@JunhongXu

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 22.04
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax: 0.9.0, jax: 0.4.30, jaxlib: 0.4.30
  • Python version: 3.11
  • GPU/TPU model and memory: GPU: Nvidia RTX 4090
  • CUDA version (if applicable): 12.2

Problem you have encountered:

nnx.jit(aux_fn) is slower than directly using nnx.jit(model.__call__), where aux_fn is defined by:

def aux_fn(model, x):
    return model(x)

From my understanding, using an auxiliary function with nnx.jit is common practice and is required if we want to modify the internal state of the model (#3998); a sketch of why follows below. However, it is noticeably slower than directly wrapping model.__call__ with nnx.jit.
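
For context, a minimal sketch (not from this issue; the Counter module is hypothetical) of why the model must be an argument when state is mutated: nnx.jit propagates in-place mutations only for modules it receives as inputs.

import jax.numpy as jnp
from flax import nnx

class Count(nnx.Variable):
    pass

class Counter(nnx.Module):
    def __init__(self):
        self.count = Count(jnp.array(0))

@nnx.jit
def increment(counter: Counter):
    counter.count.value += 1  # written back because counter is an argument

counter = Counter()
increment(counter)
assert counter.count.value == 1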

See the colab link below to reproduce.

Steps to reproduce:

Colab link: https://colab.research.google.com/drive/1cGpcaBaJABUxhZuywgLZELZRwFsT5zve?usp=sharing

For completeness, the code is copied here as well:

import time
import jax
from flax import nnx


class MLP(nnx.Module):
    def __init__(self, din: int, dout: int, rngs: nnx.Rngs) -> None:
        self.fc1 = nnx.Linear(din, 128, rngs=rngs)
        self.fc2 = nnx.Linear(128, 128, rngs=rngs)
        self.fc3 = nnx.Linear(128, 128, rngs=rngs)
        self.out = nnx.Linear(128, dout, rngs=rngs)

    def __call__(self, x):
        x = self.fc1(x)
        x = nnx.relu(x)
        x = self.fc2(x)
        x = nnx.relu(x)
        x = self.fc3(x)
        x = nnx.relu(x)
        x = self.out(x)
        return x


def nn_forward(model, x):
    # Auxiliary function: calls the model instead of returning the tuple (model, x).
    return model(x)


def benchmark_jax():
    rngs = nnx.Rngs(0)
    din, dout = 29, 7  # example dimensions
    mlp = MLP(din, dout, rngs)

    # JIT the bound method: `mlp` is captured via `self`, not passed as an argument.
    nn_forward_call_no_aux = nnx.jit(mlp.__call__)

    # Prepare data
    x = jax.random.normal(rngs(), shape=(1, din))
    num_iterations = 1000
    warmup_iters = 100

    for _ in range(warmup_iters):
        _ = nn_forward_call_no_aux(x)

    start_time = time.time()
    for _ in range(num_iterations):
        _ = nn_forward_call_no_aux(x)
    end_time = time.time()

    print(f"JAX forward pass time for {num_iterations} iterations: {end_time - start_time:.5f} seconds")
    print(f"JAX forward pass average time: {(end_time - start_time) / num_iterations:.5f} seconds")

    print("-------------------")

    # JIT the auxiliary function: `mlp` is passed as an argument on every call.
    nn_forward_jit = nnx.jit(nn_forward)
    for _ in range(warmup_iters):
        _ = nn_forward_jit(mlp, x)

    start_time = time.time()
    for _ in range(num_iterations):
        _ = nn_forward_jit(mlp, x)
    end_time = time.time()
    print(f"JAX forward pass time while using auxiliary functions for {num_iterations} iterations: {end_time - start_time:.5f} seconds")
    print(f"JAX forward pass average time while using auxiliary functions: {(end_time - start_time) / num_iterations:.5f} seconds")

The outputs on an RTX 4090 are:

JAX forward pass time for 1000 iterations: 0.10531 seconds
JAX forward pass average time: 0.00011 seconds
-------------------
JAX forward pass time while using auxiliary functions for 1000 iterations: 0.59596 seconds
JAX forward pass average time while using auxiliary functions: 0.00060 seconds
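
One caveat not raised in the thread: JAX dispatch is asynchronous, so loops timed with time.time() partly measure only dispatch. A stricter variant would synchronize on the last result before stopping the clock, e.g.:

result = None
start_time = time.time()
for _ in range(num_iterations):
    result = nn_forward_jit(mlp, x)
jax.block_until_ready(result)  # wait for the device to finish before timing
end_time = time.time()
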
@cgarciae (Collaborator)

mlp.__call__ is not recommended because you are passing self as a capture. Try MLP.__call__ and pass mlp as the first input.
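
A minimal sketch of this suggestion, reusing mlp and x from the benchmark above:

forward = nnx.jit(MLP.__call__)
y = forward(mlp, x)  # mlp is passed as an input instead of being captured via self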

@cgarciae (Collaborator)

Just to clarify: what is happening is that mlp.__call__ does not traverse self, so it's faster, a lot faster in this case.
We are going to be developing a Rust extension (see #4196), so in the future nnx.jit should be fast. For now, consider using this pattern to remove the Python overhead.
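
For cases that still need the module as an argument (for example, to mutate state), a hedged sketch of one alternative using Flax's functional split/merge API, again reusing mlp and x from above; the per-call work then goes through jax.jit's pytree flattening rather than nnx.jit's Python graph traversal:

graphdef, state = nnx.split(mlp)

@jax.jit
def forward(state, x):
    model = nnx.merge(graphdef, state)  # rebuild the module inside the traced function
    return model(x)

y = forward(state, x)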

@cgarciae (Collaborator)

I've created this mini guide to clarify the situation around performance: #4224.
