Reuse input_to_device() for continuous batching implementation
PiperOrigin-RevId: 649185731
Change-Id: I96dd898ab06fb5f7f6b9bedf03f6a92d2e62b88b
changlan authored and copybara-github committed Jul 3, 2024
1 parent fc8cf18 commit a066674
Showing 7 changed files with 133 additions and 96 deletions.
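
This commit replaces the loose (shape, dtype) tuples previously passed around as ShapesAndDtypes with jax.ShapeDtypeStruct values (aliased ShapeDType), annotated with jaxtyping PyTree types. It threads that metadata and the dummy per-device buffers as explicit arguments through _input_to_device_buffers() and _device_buffers_to_jax_arrays(), and splits device-function setup out of _register_for_input_shape() into a new _initialize_device_fn() gated by an initialize_device_fn flag, so that input_to_device() can be reused by the continuous batching implementation. Below is a minimal standalone sketch of the metadata change only; the shapes and values are made up for illustration.

import jax
import numpy as np

# Before: shape/dtype metadata was a plain tuple, unpacked positionally.
old_style = ((8, 128), np.dtype(np.int32))
shape, dtype = old_style

# After: the same metadata travels as a jax.ShapeDtypeStruct (ShapeDType),
# accessed by attribute; tree_map builds one per leaf from the host dummy
# inputs, prepending a leading num_cores dimension.
host_dummy = np.zeros((128,), dtype=np.int32)
num_cores = 8
new_style = jax.ShapeDtypeStruct((num_cores,) + host_dummy.shape, host_dummy.dtype)
assert new_style.shape == (8, 128)
assert new_style.dtype == np.dtype(np.int32)
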
1 change: 1 addition & 0 deletions saxml/server/jax/BUILD
@@ -60,6 +60,7 @@ pytype_strict_library(
"//third_party/py/absl-py/logging",
"//third_party/py/jax",
"//third_party/py/jax:experimental",
"//third_party/py/jaxtyping",
"//third_party/py/numpy",
"//third_party/py/paxml:host_callback",
"//third_party/py/praxis:pytypes",
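
The new //third_party/py/jaxtyping dependency backs the PyTree[...] annotations introduced in servable_model.py below (PyTree = jt.PyTree, PyTree[ShapeDType], PyTree[Shapes]). The following is a hedged sketch of how such an annotation reads; the function and sample values are illustrative and not part of the change.

import jax
import jaxtyping as jt
import numpy as np

PyTree = jt.PyTree
ShapeDType = jax.ShapeDtypeStruct

def describe_inputs(inputs: PyTree[ShapeDType]) -> None:
  # Any nested structure (tuples, dicts, lists) whose leaves are
  # ShapeDtypeStruct values satisfies the PyTree[ShapeDType] annotation.
  for leaf in jax.tree_util.tree_leaves(inputs):
    print(leaf.shape, leaf.dtype)

describe_inputs(
    (ShapeDType((8, 128), np.int32), {'step': ShapeDType((8,), np.int32)})
)
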
185 changes: 109 additions & 76 deletions saxml/server/jax/servable_model.py
@@ -26,6 +26,7 @@
from jax import experimental as jax_exp
from jax import numpy as jnp
from jax.experimental import pjit
import jaxtyping as jt
import numpy as np
from paxml import host_callback as paxml_hcb
from praxis import pytypes
@@ -39,10 +39,12 @@
HostTensors = Any
DeviceTensors = Any
PSpecs = Any
ShapesAndDtypes = Any
Shapes = Any
Shapes = tuple[int, ...]
DType = np.dtype
ShapeDType = jax.ShapeDtypeStruct
JaxTensors = Any
InputShapeInfo = servable_model.InputShapeInfo
PyTree = jt.PyTree


def remove_padding(x: jnp.ndarray, shape: Sequence[int]) -> jnp.ndarray:
@@ -72,7 +75,7 @@ class ServableModelState:
# Model variables.
mdl_vars: DeviceTensors
# Shapes of model variables without GSPMD padding.
mdl_var_unpadded_shapes: Shapes
mdl_var_unpadded_shapes: PyTree[Shapes] | None

# Whether the current host is the primary in a multi-jax-client setup. It is
# set to True for Pathways.
@@ -98,7 +101,7 @@ class MethodInputInfo:
# Partition specs for the inputs of the device function.
input_pspecs: PSpecs
# Global shape and dtype for the inputs of the device function.
global_inputs_shape_dtype: ShapesAndDtypes
global_inputs_shape_dtypes: PyTree[ShapeDType]
# Dummy input tensors used for secondary hosts.
dummy_inputs: Optional[DeviceTensors] = None
# Dummy input device buffers (on the local devices)
@@ -240,62 +243,21 @@ def _create_aval(x):
).compile(compiler_options=compiler_options)
return compiled

def _register_for_input_shape(self, input_shape: InputShapeInfo) -> None:
batched_host_dummy = self.get_dummy_inputs(input_shape)
batched_host_dummy = self.update_extra_inputs(
batched_host_dummy,
input_shape.batch_size,
[self.default_extra_inputs] * input_shape.batch_size,
)

def _assert_type(x):
assert isinstance(x, np.ndarray) or isinstance(
x, jnp.ndarray
), f'Output of pre_processing contained an invalid type: {type(x)}'
return x

dummy_step = np.array(0, dtype=np.int32)
dummy_prng_key = jax.random.PRNGKey(0)
host_dummy = (
dummy_step,
dummy_prng_key,
batched_host_dummy,
self.get_nonbatch_inputs(batched_host_dummy),
)
host_dummy = jax.tree_util.tree_map(_assert_type, host_dummy)

def _get_pspec(x):
# Add a `cores` dimension.
return jax.sharding.PartitionSpec(
self.model_state.global_mesh.axis_names, *(None,) * (len(x.shape))
)

input_pspecs = jax.tree_util.tree_map(_get_pspec, host_dummy)
num_cores = len(self.model_state.global_mesh.devices.flat)

global_inputs_shape_dtype = jax.tree_util.tree_map(
lambda x: ((num_cores,) + x.shape, x.dtype), host_dummy
)
def _initialize_device_fn(
self, input_shape: InputShapeInfo, info: MethodInputInfo
) -> None:
# Initialize the device function.
input_pspecs = info.input_pspecs
global_inputs_shape_dtypes = info.global_inputs_shape_dtypes
logging.info('global_inputs_shape_dtypes: %s', global_inputs_shape_dtypes)
global_inputs_shaped_arrays = jax.tree_util.tree_map(
lambda x: jax.core.ShapedArray((num_cores,) + x.shape, x.dtype),
host_dummy,
)

info = MethodInputInfo(
input_pspecs=input_pspecs,
global_inputs_shape_dtype=global_inputs_shape_dtype,
)

info.dummy_inputs_per_device_buffers = self._input_to_device_buffers(
batched_host_dummy, input_shape, info, is_dummy=True
)
info.dummy_inputs = self._device_buffers_to_jax_arrays(
info.dummy_inputs_per_device_buffers, info
lambda x: jax.core.ShapedArray(x.shape, x.dtype),
global_inputs_shape_dtypes,
)

# Initialize the device function.
dummy_inputs = info.dummy_inputs
dummy_inputs_per_device_buffers = info.dummy_inputs_per_device_buffers
device_fn = self._pjit_device_fn(
input_pspecs, input_shape.batch_size, info.dummy_inputs
input_pspecs, input_shape.batch_size, dummy_inputs
)

logging.info(
@@ -341,15 +303,11 @@ def _get_pspec(x):
self._model_state.mdl_vars = resharded_mdl_vars
self._model_state.mdl_var_pspecs = mdl_var_pspecs

info.device_fn = device_fn

if self.model_state.precompile and not self._mutable:
# Compute with dummy to trigger compilation. Only use this option for
# immutable methods to prevent side effects.
with self.model_state.global_mesh:
init_dummy_outputs = info.device_fn(
self.model_state.mdl_vars, info.dummy_inputs
)
init_dummy_outputs = device_fn(self.model_state.mdl_vars, dummy_inputs)

# Only warm up post processing on primary host.
if self.model_state.is_primary_host:
@@ -368,7 +326,69 @@ def _get_pspec(x):
outs = self.output_to_host(init_dummy_outputs, self.batch_size)
# Warm up post processor.
self.post_processing(outs)
info.device_fn = device_fn

def _register_for_input_shape(
self, input_shape: InputShapeInfo, initialize_device_fn=True
) -> None:
batched_host_dummy = self.get_dummy_inputs(input_shape)
batched_host_dummy = self.update_extra_inputs(
batched_host_dummy,
input_shape.batch_size,
[self.default_extra_inputs] * input_shape.batch_size,
)

def _assert_type(x):
assert isinstance(x, np.ndarray) or isinstance(
x, jnp.ndarray
), f'Output of pre_processing contained an invalid type: {type(x)}'
return x

dummy_step = np.array(0, dtype=np.int32)
dummy_prng_key = jax.random.PRNGKey(0)
host_dummy = (
dummy_step,
dummy_prng_key,
batched_host_dummy,
self.get_nonbatch_inputs(batched_host_dummy),
)
host_dummy = jax.tree_util.tree_map(_assert_type, host_dummy)

def _get_pspec(x):
# Add a `cores` dimension.
return jax.sharding.PartitionSpec(
self.model_state.global_mesh.axis_names, *(None,) * (len(x.shape))
)

input_pspecs = jax.tree_util.tree_map(_get_pspec, host_dummy)
num_cores = len(self.model_state.global_mesh.devices.flat)

global_inputs_shape_dtypes = jax.tree_util.tree_map(
lambda x: ShapeDType((num_cores,) + x.shape, x.dtype), host_dummy
)

dummy_inputs_per_device_buffers = self._input_to_device_buffers(
batched_host_dummy,
input_shape,
global_inputs_shape_dtypes,
dummy_inputs_per_device_buffers=None,
is_dummy=True,
)
dummy_inputs = self._device_buffers_to_jax_arrays(
dummy_inputs_per_device_buffers,
input_pspecs,
global_inputs_shape_dtypes,
)

info = MethodInputInfo(
input_pspecs=input_pspecs,
global_inputs_shape_dtypes=global_inputs_shape_dtypes,
dummy_inputs=dummy_inputs,
dummy_inputs_per_device_buffers=dummy_inputs_per_device_buffers,
)

if initialize_device_fn:
self._initialize_device_fn(input_shape, info)
self._per_bs_infos[input_shape] = info

@property
@@ -392,7 +412,7 @@ def get_nonbatch_inputs(self, one_core_inputs: HostTensors) -> HostTensors:
def resize_host_array(
self,
x: np.ndarray,
global_input_shape_dtype: ShapesAndDtypes,
global_input_shape_dtype: ShapeDType,
unpadded_input_shape: InputShapeInfo,
):
"""Checks the shape of x and resizes to the desired shape.
@@ -405,7 +425,10 @@
Returns:
host array after padding or slice of x.
"""
global_shape, global_dtype = global_input_shape_dtype
global_shape, global_dtype = (
global_input_shape_dtype.shape,
global_input_shape_dtype.dtype,
)
assert x.dtype == global_dtype, (x.dtype, global_dtype)
assert x.shape[1:] == global_shape[2:], (x.shape, global_shape)
b = x.shape[0]
@@ -423,7 +446,8 @@ def _input_to_device_buffers(
self,
one_core_inputs: HostTensors,
unpadded_input_shape: InputShapeInfo,
info: MethodInputInfo,
global_inputs_shape_dtypes: PyTree[ShapeDType],
dummy_inputs_per_device_buffers: DeviceTensors | None,
is_dummy: bool,
) -> DeviceTensors:
step = np.array(self._step.next(), dtype=np.int32)
@@ -433,7 +457,7 @@
),
one_core_inputs,
# Only the batched inputs.
info.global_inputs_shape_dtype[2],
global_inputs_shape_dtypes[2],
)
host_inputs = (
step,
@@ -459,7 +483,7 @@ def _pad_for_devices(x):
return jax.tree_util.tree_map(pad_fn, host_inputs)
else:
if is_dummy:
assert info.dummy_inputs_per_device_buffers is None
assert dummy_inputs_per_device_buffers is None

def _to_buffers(x):
if self.model_state.is_primary_host:
@@ -473,7 +497,7 @@ def _to_buffers(x):

return jax.tree_util.tree_map(_to_buffers, host_inputs)
else:
assert info.dummy_inputs_per_device_buffers is not None
assert dummy_inputs_per_device_buffers is not None

def _update_buffers(x, buffers):
x = np.expand_dims(x, axis=0)
@@ -482,17 +506,20 @@ def _update_buffers(x, buffers):
return [jax.device_put(x, self._local_devices[0])] + buffers[1:]

return jax.tree_util.tree_map(
_update_buffers, host_inputs, info.dummy_inputs_per_device_buffers
_update_buffers, host_inputs, dummy_inputs_per_device_buffers
)

def _device_buffers_to_jax_arrays(
self, buffers: Any, info: MethodInputInfo
self,
buffers: DeviceTensors,
input_pspecs: PSpecs,
global_inputs_shape_dtypes: PyTree[ShapeDType],
) -> DeviceTensors:
if not self.model_state.input_prefetch:
return buffers

def _to_jax_array(pspec, bufs, shape_dtype):
shape, _ = shape_dtype
def _to_jax_array(pspec, bufs, shape_dtype: ShapeDType):
shape = shape_dtype.shape
return jax.make_array_from_single_device_arrays(
shape,
jax.sharding.NamedSharding(self.model_state.global_mesh, pspec),
@@ -501,9 +528,9 @@ def _to_jax_array(pspec, bufs, shape_dtype):

return jax.tree_util.tree_map(
_to_jax_array,
info.input_pspecs,
input_pspecs,
buffers,
info.global_inputs_shape_dtype,
global_inputs_shape_dtypes,
)

def input_to_device(
@@ -515,9 +542,15 @@
"""Transfers host inputs to device. Pads incomplete shapes."""
info = self._per_bs_infos[padded_shape]
buffers = self._input_to_device_buffers(
one_core_inputs, unpadded_shape, info, is_dummy=False
one_core_inputs,
unpadded_shape,
info.global_inputs_shape_dtypes,
info.dummy_inputs_per_device_buffers,
is_dummy=False,
)
return self._device_buffers_to_jax_arrays(
buffers, info.input_pspecs, info.global_inputs_shape_dtypes
)
return self._device_buffers_to_jax_arrays(buffers, info)

def output_to_host(
self, output_tensors: DeviceTensors, unpadded_batch_size: int
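
_device_buffers_to_jax_arrays() above assembles the per-device buffers into globally sharded arrays with jax.make_array_from_single_device_arrays, using the input partition specs and the ShapeDType tree for the global shapes. Here is a minimal standalone sketch of that JAX pattern over a 1-D device mesh; the mesh axis name and array contents are made up, and this is not the servable-model code itself.

import jax
import numpy as np

devices = jax.local_devices()
mesh = jax.sharding.Mesh(np.array(devices), axis_names=('cores',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('cores'))

# One buffer per device, each holding that device's shard of the global array.
per_device = [
    jax.device_put(np.full((1, 4), i, dtype=np.int32), d)
    for i, d in enumerate(devices)
]
global_shape = (len(devices), 4)
global_array = jax.make_array_from_single_device_arrays(
    global_shape, sharding, per_device
)
print(global_array.shape)  # (num_devices, 4)
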
6 changes: 3 additions & 3 deletions saxml/server/pax/custom/servable_custom_model.py
@@ -41,7 +41,7 @@
NestedMap = py_utils.NestedMap
InputShapeInfo = servable_model.InputShapeInfo
HostTensors = servable_model.HostTensors
ShapesAndDtypes = servable_model.ShapesAndDtypes
ShapeDType = servable_model.ShapeDType


FetchOutputFn = Callable[[NestedJTensor, NestedJTensor], NestedJTensor]
@@ -93,7 +93,7 @@ def __call__(
GetPaddedInputShapeFn = Callable[[Any], Any]
GetUnpaddedInputShapeFn = Callable[[int, HostTensors], Any]
DeserializeInputShapeFn = Callable[[str], Any]
ResizeHostArrayFn = Callable[[np.ndarray, ShapesAndDtypes, Any], HostTensors]
ResizeHostArrayFn = Callable[[np.ndarray, ShapeDType, Any], HostTensors]


class CustomMethodName:
@@ -311,7 +311,7 @@ def deserialize_input_shape(self, unpadded_shape_str: str) -> InputShapeInfo:
def resize_host_array(
self,
x: np.ndarray,
global_input_shape_dtype: ShapesAndDtypes,
global_input_shape_dtype: ShapeDType,
unpadded_input_shape: InputShapeInfo,
) -> HostTensors:
"""Resizes x to the desired shape.
6 changes: 3 additions & 3 deletions saxml/server/pax/lm/servable_lm_common.py
@@ -34,7 +34,7 @@
NestedTfTensor = pytypes.Nested[tf.Tensor]
NestedNpOrTfTensor = Union[NestedNpTensor, NestedTfTensor]
HostTensors = servable_model.HostTensors
ShapesAndDtypes = servable_model.ShapesAndDtypes
ShapeDType = servable_model.ShapeDType
TensorSpec = Union[tf.TensorSpec, oex.TensorSpecWithDefault]
NpOrTfTensor = Union[pytypes.NpTensor, tf.Tensor]

@@ -825,7 +825,7 @@ def _slice_fn(x):

def resize_host_array(
x: np.ndarray,
global_input_shape_dtype: ShapesAndDtypes,
global_input_shape_dtype: ShapeDType,
unpadded_input_shape: InputShapeInfo,
) -> HostTensors:
"""Resize host array to the deired shape.
@@ -838,7 +838,7 @@
Returns:
host array after padding or slice of x.
"""
global_shape, _ = global_input_shape_dtype
global_shape = global_input_shape_dtype.shape
if unpadded_input_shape.seq_len != -1 and (
len(x.shape) == 2 or len(x.shape) == 3
):
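
resize_host_array() in this file pads or slices incoming host arrays so that a request matches the padded sequence length and batch size of a registered input shape. The helper below is a generic, hedged sketch of that pad-or-slice step along a single axis; it is not the library's exact implementation.

import numpy as np

def pad_or_slice(x: np.ndarray, target: int, axis: int = 0) -> np.ndarray:
  """Zero-pads or slices x along `axis` so that its length becomes `target`."""
  current = x.shape[axis]
  if current == target:
    return x
  if current < target:
    pad = [(0, 0)] * x.ndim
    pad[axis] = (0, target - current)
    return np.pad(x, pad)
  slices = [slice(None)] * x.ndim
  slices[axis] = slice(0, target)
  return x[tuple(slices)]

# E.g. grow a batch of 3 requests to a padded batch size of 4, or shrink it.
x = np.ones((3, 16), dtype=np.int32)
assert pad_or_slice(x, 4).shape == (4, 16)
assert pad_or_slice(x, 2).shape == (2, 16)
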
(Diffs for the remaining changed files are not shown.)
