update orbax handler to use bulk read APIs.
PiperOrigin-RevId: 711859145
Pathways-on-Cloud Team authored and copybara-github committed Jan 3, 2025
1 parent 2a494e4 commit a716b9b
Showing 4 changed files with 274 additions and 41 deletions.
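
In short, the handler now issues one bulk persistence call per mesh group instead of one call per array. A minimal before/after sketch of the read path, using the helper module added in this commit (ckpt_dir, array_names, dtypes, shapes, shardings, mesh, and timeout are illustrative placeholders, not names from the commit):

# Before: one plugin program per array, each compiled and awaited separately.
# arrays = [
#     helper.read_one_array(
#         location=ckpt_dir, name=name, dtype=dtype, shape=shape,
#         shardings=sharding, devices=mesh.devices, timeout=timeout,
#     )
#     for name, dtype, shape, sharding in zip(array_names, dtypes, shapes, shardings)
# ]

# After: one bulk plugin program covering every array in the mesh group.
arrays, read_future = helper.read_arrays(
    ckpt_dir, array_names, dtypes, shapes, shardings, mesh.devices, timeout
)
read_future.result()  # A single await for the whole bulk read.
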
131 changes: 123 additions & 8 deletions pathwaysutils/persistence/helper.py
@@ -14,9 +14,10 @@
"""Helper functions for persistence."""

import base64
import concurrent.futures
import datetime
import json
from typing import Sequence, Union
from typing import Any, Sequence, Tuple, Union

import jax
from jax import core
@@ -25,6 +26,34 @@
from pathwaysutils import plugin_executable


def dtype_to_etype(dtype: np.dtype) -> xc.PrimitiveType:
"""Converts a numpy dtype to an xla PrimitiveType."""
if dtype == np.dtype("bfloat16"):
return xc.PrimitiveType.BF16
elif dtype == np.dtype("float32"):
return xc.PrimitiveType.F32
elif dtype == np.dtype("float64"):
return xc.PrimitiveType.F64
elif dtype == np.dtype("int8"):
return xc.PrimitiveType.S8
elif dtype == np.dtype("int16"):
return xc.PrimitiveType.S16
elif dtype == np.dtype("int32"):
return xc.PrimitiveType.S32
elif dtype == np.dtype("int64"):
return xc.PrimitiveType.S64
elif dtype == np.dtype("uint8"):
return xc.PrimitiveType.U8
elif dtype == np.dtype("uint16"):
return xc.PrimitiveType.U16
elif dtype == np.dtype("uint32"):
return xc.PrimitiveType.U32
elif dtype == np.dtype("uint64"):
return xc.PrimitiveType.U64
else:
raise ValueError(f"Unsupported dtype: {dtype}")


def base64_utf8_stringify(bs: bytes) -> str:
"""Converts bytes to a base64-encoded utf-8 string.
@@ -69,7 +98,7 @@ def get_shape_string(
"""Serializes the shape, encodes it to base64 and returns the base-64 as an utf-8 string."""
return base64_utf8_stringify(
xc.Shape.array_shape(
xc.PrimitiveType(xc.dtype_to_etype(dtype)),
xc.PrimitiveType(dtype_to_etype(dtype)),
shape,
)
.with_major_to_minor_layout_if_absent()
@@ -82,7 +111,8 @@ def get_write_request(
name: str,
jax_array: jax.Array,
timeout: datetime.timedelta,
) -> str:
return_dict: bool = False,
) -> Union[str, dict[str, Any]]:
"""Returns a string representation of the plugin program which writes the given jax_array to the given location."""
sharding = jax_array.sharding
assert isinstance(sharding, jax.sharding.Sharding), sharding
@@ -91,7 +121,7 @@
timeout.total_seconds(), 1
)
timeout_nanoseconds = timeout_fractional_seconds * 1e9
return json.dumps({
d = {
"persistenceWriteRequest": {
"b64_location": string_to_base64(location_path),
"b64_name": string_to_base64(name),
@@ -112,7 +142,29 @@
"nanos": int(timeout_nanoseconds),
},
}
})
}

if return_dict:
return d
return json.dumps(d)


def get_bulk_write_request(
location_path: str,
names: Sequence[str],
jax_arrays: Sequence[jax.Array],
timeout: datetime.timedelta,
) -> str:
"""Returns a string representation of a bulk write request, writes multiple arrays with one call."""
write_requests = [
get_write_request(location_path, name, jax_array, timeout, True)[
"persistenceWriteRequest"
]
for name, jax_array in zip(names, jax_arrays)
]
return json.dumps(
{"bulk_persistence_write_request": {"write_requests": write_requests}}
)
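
For reference, a minimal call sketch (ckpt_dir, array_names, and arrays are illustrative placeholders; the resulting string is the JSON form {"bulk_persistence_write_request": {"write_requests": [...]}} with one persistenceWriteRequest entry per array):

request = get_bulk_write_request(
    location_path=ckpt_dir,
    names=array_names,
    jax_arrays=arrays,
    timeout=datetime.timedelta(minutes=10),
)
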


def get_read_request(
@@ -123,7 +175,8 @@ def get_read_request(
sharding: jax.sharding.Sharding,
devices: Sequence[jax.Device],
timeout: datetime.timedelta,
) -> str:
return_dict: bool = False,
) -> Union[str, dict[str, Any]]:
"""Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
if not isinstance(devices, np.ndarray):
devices = np.array(devices)
@@ -132,7 +185,7 @@
timeout.total_seconds(), 1
)
timeout_nanoseconds = timeout_fractional_seconds * 1e9
return json.dumps({
d = {
"persistenceReadRequest": {
"b64_location": string_to_base64(location_path),
"b64_shape_proto_string": get_shape_string(dtype, shape),
@@ -148,7 +201,32 @@
"nanos": int(timeout_nanoseconds),
},
}
})
}

if return_dict:
return d
return json.dumps(d)


def get_bulk_read_request(
location_path: str,
names: Sequence[str],
dtypes: Sequence[np.dtype],
shapes: Sequence[Sequence[int]],
shardings: Sequence[jax.sharding.Sharding],
devices: Sequence[jax.Device],
timeout: datetime.timedelta,
) -> str:
"""Returns a string representation of a bulk read request, reads multiple arrays with one call."""
read_requests = [
get_read_request(
location_path, name, dtype, shape, sharding, devices, timeout, True
)["persistenceReadRequest"]
for name, dtype, shape, sharding in zip(names, dtypes, shapes, shardings)
]
return json.dumps(
{"bulk_persistence_read_request": {"read_requests": read_requests}}
)


def write_one_array(
@@ -164,6 +242,19 @@ def write_one_array(
return write_future


def write_arrays(
location: str,
names: Sequence[str],
values: Sequence[jax.Array],
timeout: datetime.timedelta,
) -> concurrent.futures.Future[None]:
"""Creates the write array plugin program string, compiles it to an executable, calls it and returns an awaitable future."""
bulk_write_request = get_bulk_write_request(location, names, values, timeout)
bulk_write_executable = plugin_executable.PluginExecutable(bulk_write_request)
_, bulk_write_future = bulk_write_executable.call(values)
return bulk_write_future
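
A minimal usage sketch of the bulk write path (ckpt_dir, array_names, and arrays are illustrative placeholders):

write_future = write_arrays(
    location=ckpt_dir,
    names=array_names,
    values=arrays,
    timeout=datetime.timedelta(minutes=10),
)
write_future.result()  # Block until the bulk write completes.
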


def read_one_array(
location: str,
name: str,
@@ -190,3 +281,27 @@ def read_one_array(
)
read_future.result()
return read_array[0]


def read_arrays(
location: str,
names: Sequence[str],
dtypes: Sequence[np.dtype],
shapes: Sequence[Sequence[int]],
shardings: Sequence[jax.sharding.Sharding],
devices: Union[Sequence[jax.Device], np.ndarray],
timeout: datetime.timedelta,
) -> Tuple[Sequence[jax.Array], concurrent.futures.Future[None]]:
"""Creates the read array plugin program string, compiles it to an executable, calls it and returns the result."""

bulk_read_request = get_bulk_read_request(
location, names, dtypes, shapes, shardings, devices, timeout
)
bulk_read_executable = plugin_executable.PluginExecutable(bulk_read_request)
out_avals = [
core.ShapedArray(shape, dtype) for shape, dtype in zip(shapes, dtypes)
]
arrays, read_future = bulk_read_executable.call(
out_shardings=shardings, out_avals=out_avals
)
return (arrays, read_future)
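
A minimal usage sketch of the bulk read path, assuming every array shares one mesh and uses a NamedSharding (mesh, specs, ckpt_dir, array_names, dtypes, and shapes are illustrative placeholders):

shardings = [jax.sharding.NamedSharding(mesh, spec) for spec in specs]
arrays, read_future = read_arrays(
    location=ckpt_dir,
    names=array_names,
    dtypes=dtypes,
    shapes=shapes,
    shardings=shardings,
    devices=mesh.devices,
    timeout=datetime.timedelta(minutes=10),
)
read_future.result()  # The returned jax.Arrays are ready once the future resolves.
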
54 changes: 21 additions & 33 deletions pathwaysutils/persistence/pathways_orbax_handler.py
@@ -60,7 +60,7 @@ def __init__(
self._read_timeout = read_timeout

if use_ocdbt:
raise ValueError('OCDBT not supported for Pathways.')
raise ValueError("OCDBT not supported for Pathways.")
super().__init__()

async def serialize(
@@ -73,12 +73,10 @@ async def serialize(
type_handlers.check_input_arguments(values, infos, args)

if any([arg.dtype is not None for arg in args]):
raise ValueError('Casting during save not supported for Pathways.')
raise ValueError("Casting during save not supported for Pathways.")

locations, names = extract_parent_dir_and_name(infos)
f = functools.partial(
helper.write_one_array, timeout=self._read_timeout
)
f = functools.partial(helper.write_one_array, timeout=self._read_timeout)
return list(map(f, locations, names, values))

async def deserialize(
@@ -88,7 +86,7 @@ async def deserialize(
) -> Sequence[jax.Array]:
"""Uses Pathways Persistence API to deserialize a jax array."""
if args is None:
raise ValueError('Must provide ArrayRestoreArgs to restore as jax.Array.')
raise ValueError("Must provide ArrayRestoreArgs to restore as jax.Array.")
type_handlers.check_input_arguments(infos, args)

global_meshes = []
@@ -101,14 +99,14 @@
for arg in args:
if not isinstance(arg, ArrayRestoreArgs):
raise ValueError(
'To restore jax.Array, provide ArrayRestoreArgs; found'
f' {type(arg).__name__}'
"To restore jax.Array, provide ArrayRestoreArgs; found"
f" {type(arg).__name__}"
)
arg = typing.cast(ArrayRestoreArgs, arg)
if arg.sharding is None and (arg.mesh is None or arg.mesh_axes is None):
raise ValueError(
'Sharding of jax.Array cannot be None. Provide `mesh`'
' and `mesh_axes` OR `sharding`.'
"Sharding of jax.Array cannot be None. Provide `mesh`"
" and `mesh_axes` OR `sharding`."
)
if arg.sharding is None:
global_meshes.append(arg.mesh)
@@ -118,15 +116,15 @@
)
else:
if not isinstance(arg.sharding, jax.sharding.NamedSharding):
raise ValueError('Pathways only supports jax.sharding.NamedSharding.')
raise ValueError("Pathways only supports jax.sharding.NamedSharding.")
sharding = typing.cast(jax.sharding.NamedSharding, arg.sharding)
global_meshes.append(sharding.mesh)
mesh_axes.append(sharding.spec)
shardings.append(sharding)
if arg.global_shape is None or arg.dtype is None:
logger.warning(
'Shape or dtype not provided for restoration. Provide these'
' properties for improved performance.'
"Shape or dtype not provided for restoration. Provide these"
" properties for improved performance."
)
should_open_metadata = True
global_shapes.append(arg.global_shape)
@@ -153,27 +151,17 @@
grouped_dtypes = [dtypes[idx] for idx in idxs]
grouped_shardings = [shardings[idx] for idx in idxs]
locations, names = extract_parent_dir_and_name(grouped_infos)
f = functools.partial(
helper.read_one_array,
devices=global_mesh.devices,
grouped_arrays, read_future = helper.read_arrays(
locations[0],
names,
grouped_dtypes,
grouped_global_shapes,
grouped_shardings,
global_mesh.devices,
timeout=self._read_timeout,
)
grouped_arrays = [
f(
location=location,
name=name,
dtype=dtype,
shape=shape,
shardings=sharding,
)
for location, name, dtype, shape, sharding in zip(
locations,
names,
grouped_dtypes,
grouped_global_shapes,
grouped_shardings,
)
]
# Each bulk persistence call (one per mesh group) is awaited before the next group is read.
read_future.result()
for idx, arr in zip(idxs, grouped_arrays):
results[idx] = arr
return results # pytype: disable=bad-return-type
@@ -184,7 +172,7 @@ def register_pathways_handlers(
):
"""Function that must be called before saving or restoring with Pathways."""
logger.debug(
'Registering CloudPathwaysArrayHandler (Pathways Persistence API).'
"Registering CloudPathwaysArrayHandler (Pathways Persistence API)."
)
type_handlers.register_type_handler(
jax.Array,