
Running the executable compiled directly from jax.jit is more than three times slower than jax.jit itself. #25023

Open
caixiiaoyang opened this issue Nov 21, 2024 · 1 comment
Labels: bug (Something isn't working)

Comments

@caixiiaoyang

Description

I want to save the executable file generated by jax.jit in the main process and execute it in another process. However, I have found that running the saved executable is much slower than calling the jitted function directly. I would like to know why this is the case.

1. The execution code and results of jax.jit are as follows.

1.1 The execution code


from flax import linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tqdm

class CNN(nn.Module):
  """A simple CNN model."""
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x
  
@jax.jit
def train_step(state, images, labels):
  """Computes gradients, loss and accuracy for a single batch."""
  def loss_fn(params):
    logits = state.apply_fn({'params': params}, images)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  state = state.apply_gradients(grads=grads)
  return state, loss, accuracy

def train_epoch(state, steps_per_epoch):
  """Train for a single epoch."""
  rng_key = jax.random.PRNGKey(42)
  train_input = jax.random.uniform(rng_key, shape=(128, 64, 64, 3))
  train_label = jnp.ones(shape=(128,))
  
  for i in tqdm.tqdm(range(steps_per_epoch)):
    state, loss, accuracy = train_step(state, train_input, train_label)
    
def create_train_state():
  """Creates initial `TrainState`."""
  cnn = CNN()
  rng = jax.random.PRNGKey(42)
  params = cnn.init(rng, jnp.ones([1, 64, 64, 3]))['params']
  tx = optax.sgd(0.9)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)
    
if __name__ == "__main__":
  state = create_train_state()
  for i in range(5):
    train_epoch(state, 10000)

1.2 The execution results


Training runs for 5 epochs with 10,000 steps per epoch, and each epoch takes 15 seconds.
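For reference, one way the xla_executable.exe and input_bufs.pkl files consumed in section 2 could be produced is sketched below. This is only a minimal sketch, not the reporter's code: it assumes Compiled.runtime_executable() exposes the underlying PjRt LoadedExecutable and that backend.serialize_executable() exists as the counterpart of the backend.deserialize_executable() call used in the second script; the file names simply mirror those used there.


import pickle

import jax
import jax.numpy as jnp
from jax.lib import xla_bridge as xb

# train_step and state come from the training script above; the arrays mirror
# the ones built inside train_epoch().
train_input = jax.random.uniform(jax.random.PRNGKey(42), shape=(128, 64, 64, 3))
train_label = jnp.ones(shape=(128,))

backend = xb.get_backend("cuda")

# Ahead-of-time lowering and compilation of the jitted step for these concrete arguments.
compiled = train_step.lower(state, train_input, train_label).compile()

# Assumption: runtime_executable() returns the PjRt LoadedExecutable and
# serialize_executable() produces bytes that deserialize_executable() can load.
serialized = backend.serialize_executable(compiled.runtime_executable())
with open("./xla_executable.exe", "wb") as f:
    f.write(serialized)

# The raw executable consumes flattened input buffers, so a flattened copy of
# the example arguments is saved alongside it.
flat_inputs, _ = jax.tree_util.tree_flatten((state, train_input, train_label))
with open("./input_bufs.pkl", "wb") as f:
    pickle.dump([jax.device_get(x) for x in flat_inputs], f)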

2. The code and results from running the executable file generated by jax.jit are as follows.

2.1 The code


import os
import tqdm

os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["JAX_ENABLE_X64"] = "1"

import jax

jax.config.update("jax_platform_name", "gpu")

import jax.numpy as jnp
import numpy as np
import jaxlib.xla_extension as xe
from jax.lib import (xla_bridge as xb, xla_client as xc)

print(xb.get_backend().platform)
import pickle

backend = xb.get_backend("cuda")

with open("./xla_executable.exe", "rb") as f:
    a = f.read()
    compiled = backend.deserialize_executable(a)

print(compiled.hlo_modules())

with open("./input_bufs.pkl", "rb") as f:
    input_array0 = pickle.load(f)

input_array = jax.device_put(input_array0)
jax.block_until_ready(input_array)
for epoch in range(5):
    for step in tqdm.tqdm(range(10000)):
        out_ = compiled.execute_sharded(input_array)

2.2 The results


Training runs for 5 epochs with 10,000 steps per epoch, and each epoch takes 50 seconds.
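For comparison, a higher-level round trip is sketched below. It assumes the installed JAX version ships jax.experimental.serialize_executable with serialize() and deserialize_and_load(); the exact signatures are assumptions and may differ between versions. Unlike the raw backend API above, it is expected to rebuild a callable Compiled object, so the restored step can be invoked with the original pytree arguments rather than flattened buffers.


import jax
from jax.experimental import serialize_executable  # assumed available in this JAX version

# Producing process: serialize() is assumed to return the payload plus the
# input/output pytree definitions, which must be shipped along with the bytes.
compiled = train_step.lower(state, train_input, train_label).compile()
payload, in_tree, out_tree = serialize_executable.serialize(compiled)

# Consuming process: deserialize_and_load() is assumed to rebuild a callable
# Compiled object, so it can be called like the original jitted function.
restored = serialize_executable.deserialize_and_load(payload, in_tree, out_tree)
new_state, loss, accuracy = restored(state, train_input, train_label)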

System info (python version, jaxlib version, accelerator, etc.)

(system info provided as a screenshot)

caixiiaoyang added the bug (Something isn't working) label on Nov 21, 2024
@yashk2810
Collaborator
