Provide as much information as possible. At least, this should include a description of your issue and steps to reproduce the problem. If possible also provide a summary of what steps or workarounds you have already tried.
System information
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 22.04
Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax: 0.9.0, jax: 0.4.30, jaxlib: 0.4.30
Python version: 3.11
GPU/TPU model and memory: GPU: Nvidia RTX 4090
CUDA version (if applicable): 12.2
Problem you have encountered:
nnx.jit(aux_fn) is slower than directly using nnx.jit(model.__call__), where aux_fn is defined by:

```python
def aux_fn(model, x):
    return model(x)
```
From my understanding, using an auxiliary function with nnx.jit is common practice and is required if we want to modify the internal state of the model (#3998); a small sketch of that case is included after the benchmark output below. However, it seems slower than directly wrapping the model.__call__ function with nnx.jit. See the Colab link below to reproduce.

Steps to reproduce:
Colab link: https://colab.research.google.com/drive/1cGpcaBaJABUxhZuywgLZELZRwFsT5zve?usp=sharing

For completeness, I also copy the code here:
```python
import time

import jax
from flax import nnx as nnx


class MLP(nnx.Module):
    def __init__(self, din: int, dout: int, rngs: nnx.Rngs) -> None:
        # super().__init__()
        self.fc1 = nnx.Linear(din, 128, rngs=rngs)
        self.fc2 = nnx.Linear(128, 128, rngs=rngs)
        self.fc3 = nnx.Linear(128, 128, rngs=rngs)
        self.out = nnx.Linear(128, dout, rngs=rngs)

    def __call__(self, x):
        x = self.fc1(x)
        x = nnx.relu(x)
        x = self.fc2(x)
        x = nnx.relu(x)
        x = self.fc3(x)
        x = nnx.relu(x)
        x = self.out(x)
        return x


def nn_forward(model, x):
    # auxiliary function: the model is passed as an explicit argument
    return model(x)


def benchmark_jax():
    rngs = nnx.Rngs(0)
    din, dout = 29, 7  # example dimensions
    mlp = MLP(din, dout, rngs)

    # Variant 1: jit the bound method directly
    nn_forward_call_no_aux = nnx.jit(mlp.__call__)

    # Prepare data
    x = jax.random.normal(rngs(), shape=(1, din))

    num_iterations = 1000
    warmup_iters = 100
    for _ in range(warmup_iters):
        _ = nn_forward_call_no_aux(x)
    start_time = time.time()
    for _ in range(num_iterations):
        _ = nn_forward_call_no_aux(x)
    end_time = time.time()
    print(f"JAX forward pass time for {num_iterations} iterations: {end_time - start_time:.5f} seconds")
    print(f"JAX forward pass average time: {(end_time - start_time) / num_iterations:.5f} seconds")
    print("-------------------")

    # Variant 2: jit the auxiliary function that takes the model as an argument
    nn_forward_jit = nnx.jit(nn_forward)
    for _ in range(warmup_iters):
        _ = nn_forward_jit(mlp, x)
    start_time = time.time()
    for _ in range(num_iterations):
        _ = nn_forward_jit(mlp, x)
    end_time = time.time()
    print(f"JAX forward pass time while using auxiliary functions for {num_iterations} iterations: {end_time - start_time:.5f} seconds")
    print(f"JAX forward pass average while using auxiliary functions time: {(end_time - start_time) / num_iterations:.5f} seconds")


benchmark_jax()
```
The outputs using an RTX 4090 are:

```
JAX forward pass time for 1000 iterations: 0.10531 seconds
JAX forward pass average time: 0.00011 seconds
-------------------
JAX forward pass time while using auxiliary functions for 1000 iterations: 0.59596 seconds
JAX forward pass average while using auxiliary functions time: 0.00060 seconds
```
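For context, here is a minimal sketch of the state-mutation case mentioned above (the Count, Counter, and apply_fn names are hypothetical, not taken from the issue or from #3998): when the module is passed as an explicit argument to an nnx.jit-wrapped auxiliary function, nnx updates the original object with the state mutated inside the call, which is what makes the auxiliary-function form necessary.

```python
import jax.numpy as jnp
from flax import nnx


class Count(nnx.Variable):
    # custom Variable type to mark the mutable call-count state
    pass


class Counter(nnx.Module):
    def __init__(self):
        self.count = Count(jnp.array(0))

    def __call__(self, x):
        self.count.value += 1  # in-place state update inside the call
        return x * 2.0


@nnx.jit
def apply_fn(model, x):
    # auxiliary function: the module is an explicit argument, so nnx.jit
    # propagates its state updates back to the original object
    return model(x)


counter = Counter()
y = apply_fn(counter, jnp.ones((3,)))
print(counter.count.value)  # 1 -> the mutation is visible outside the jit
```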
Just to clarify, what is happening is that mlp.__call__ is not traversing self, so it's faster (a lot faster in this case).
We are going to be developing a Rust extension (see #4196), so in the future nnx.jit should be fast. For now, consider using this pattern to remove the Python overhead.
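For reference, here is a minimal sketch of that kind of staging, assuming the standard nnx.split / nnx.merge API and reusing the mlp and x from the script above (this is a reconstruction, not necessarily the exact snippet the comment refers to): flatten the module once outside the hot loop, jit a plain function over the resulting state, and rebuild the module inside the traced function.

```python
import jax
from flax import nnx

# Flatten the module once, outside the hot loop, so the per-call Python
# graph traversal that nnx.jit performs is avoided.
graphdef, state = nnx.split(mlp)

@jax.jit
def forward(state, x):
    model = nnx.merge(graphdef, state)  # rebuild the module inside the trace
    return model(x)

y = forward(state, x)
```

The graphdef is static and captured by closure, so after the first compilation each call only pays the usual jax.jit dispatch cost rather than a per-call traversal of the module.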