I want to save the executable generated by jax.jit in the main process and then run that saved executable in another process. However, I have found that running the deserialized executable is much slower than calling the jitted function directly, and I would like to know why.
1. The execution code and results of jax.jit are as follows.
1.1 The execution code
from flax import linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tqdm


class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x


@jax.jit
def train_step(state, images, labels):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, images)
        one_hot = jax.nn.one_hot(labels, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    state = state.apply_gradients(grads=grads)
    return state, loss, accuracy


def train_epoch(state, steps_per_epoch):
    """Train for a single epoch."""
    rng_key = jax.random.PRNGKey(42)
    train_input = jax.random.uniform(rng_key, shape=(128, 64, 64, 3))
    train_label = jax.numpy.ones(shape=(128,))
    for i in tqdm.tqdm(range(steps_per_epoch)):
        state, loss, accuracy = train_step(state, train_input, train_label)


def create_train_state():
    """Creates initial `TrainState`."""
    cnn = CNN()
    rng = jax.random.PRNGKey(42)
    params = cnn.init(rng, jnp.ones([1, 64, 64, 3]))['params']
    tx = optax.sgd(0.9)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx)


if __name__ == "__main__":
    state = create_train_state()
    for i in range(5):
        train_epoch(state, 10000)
1.2 The execution results
Training runs for 5 epochs with 10,000 steps per epoch; each epoch takes about 15 seconds.
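For completeness: the report never shows how ./xla_executable.exe and ./input_bufs.pkl were written in the main process. One plausible saving step is sketched below; runtime_executable() and serialize_executable() are version-dependent JAX/jaxlib internals, so the exact method names here are assumptions rather than confirmed API.

import pickle
import jax
from jax.lib import xla_bridge as xb

# Ahead-of-time lower and compile the jitted train_step for the concrete
# inputs defined in section 1.1.
lowered = train_step.lower(state, train_input, train_label)
compiled = lowered.compile()

# Serialize the underlying XLA executable. runtime_executable() and
# serialize_executable() are internals that vary across jax/jaxlib versions.
backend = xb.get_backend("cuda")
serialized = backend.serialize_executable(compiled.runtime_executable())
with open("./xla_executable.exe", "wb") as f:
    f.write(serialized)

# Flatten the inputs to match the executable's flat calling convention and
# pickle them as host arrays for the other process to device_put.
flat_inputs = jax.tree_util.tree_leaves((state, train_input, train_label))
with open("./input_bufs.pkl", "wb") as f:
    pickle.dump([jax.device_get(x) for x in flat_inputs], f)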
2. The code and results from running the executable file generated by jax.jit are as follows.
2.1 The code
import os
import tqdm

# Note: neither of these environment variables is set in the jax.jit script
# above. The "platform" allocator in particular allocates and frees GPU
# memory on every call instead of using JAX's default memory pool, which
# JAX's documentation describes as very slow.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["JAX_ENABLE_X64"] = "1"
import jax

jax.config.update("jax_platform_name", "gpu")
import jax.numpy as jnp
import numpy as np
import jaxlib.xla_extension as xe
from jax.lib import (xla_bridge as xb, xla_client as xc)

print(xb.get_backend().platform)
import pickle

# Load and deserialize the executable saved by the main process.
backend = xb.get_backend("cuda")
with open("./xla_executable.exe", "rb") as f:
    a = f.read()
compiled = backend.deserialize_executable(a)
print(compiled.hlo_modules())

# Load the pickled inputs and move them onto the device.
with open("./input_bufs.pkl", "rb") as f:
    input_array0 = pickle.load(f)
input_array = jax.device_put(input_array0)
jax.block_until_ready(input_array)

# Unlike the train_step loop above, the same input buffers are reused on
# every iteration; the outputs are never threaded back in as the next state.
for i in range(5):
    for j in tqdm.tqdm(range(10000)):
        out_ = compiled.execute_sharded(input_array)
2.2 The results
Training runs for 5 epochs with 10,000 steps per epoch; each epoch takes about 50 seconds.
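As an aside, recent JAX releases also ship an experimental helper that serializes a compiled function together with its input/output pytree structure, so the reloaded object is called exactly like the original jitted function, which makes the two timings more directly comparable. A minimal sketch, assuming jax.experimental.serialize_executable is available in the installed version:

import jax
from jax.experimental import serialize_executable

# Main process: AOT-compile train_step and serialize it together with the
# pytree definitions of its inputs and outputs (experimental API; exact
# signatures may vary across JAX versions).
compiled = train_step.lower(state, train_input, train_label).compile()
payload, in_tree, out_tree = serialize_executable.serialize(compiled)

# Other process: reload and invoke it like the original jitted function,
# threading the new state back in each step as train_step does.
restored = serialize_executable.deserialize_and_load(payload, in_tree, out_tree)
state, loss, accuracy = restored(state, train_input, train_label)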
System info (python version, jaxlib version, accelerator, etc.)