Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add docs and a utility function for packing JAX models #181

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/website/pages/docs/packing/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ Pick a framework to get started:
{ /* <ButtonLink href="/docs/packing/tensorflow">TensorFlow</ButtonLink> */ }
<ButtonLink href="/docs/packing/pytorch">PyTorch</ButtonLink>
<ButtonLink href="/docs/packing/torchscript">TorchScript</ButtonLink>
{ /* <ButtonLink href="/docs/packing/ludwig">Ludwig</ButtonLink>
<ButtonLink href="/docs/packing/jax">JAX</ButtonLink>
{ /* <ButtonLink href="/docs/packing/ludwig">Ludwig</ButtonLink>
<ButtonLink href="/docs/packing/keras">Keras</ButtonLink>
<ButtonLink href="/docs/packing/tensorrt">TensorRT</ButtonLink>
<ButtonLink href="/docs/packing/xgboost">XGBoost</ButtonLink>
Expand All @@ -27,7 +27,6 @@ Or explore [options that are available for any model type](/docs/packing/options
We're currently in the process of implementing support for the following frameworks:

- TensorFlow
- JAX
- Keras
- Ludwig
- TensorRT
Expand Down
106 changes: 106 additions & 0 deletions docs/website/pages/docs/packing/jax.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Pack a JAX model

Carton supports packing any JAX model that works with [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) in a way that does not require Python at runtime. It runs these models using [XLA](https://github.com/openxla/xla), which is the optimizing ML compiler that JAX uses under the hood. The main constraint is that the input shapes and types of your model must be known at packing time.

For models that don't work with `jax.jit` or need to work with dynamic shapes, you can package them as [arbitrary Python code](/docs/packing/python).

import LanguageSelect, {LanguageSwitch, LanguageItem} from "@/components/languageselect";

<LanguageSwitch>
<LanguageItem forLang='python'>

### Packing a model

Let's create a simple model to pack:

```python forLang="python" withLineNumbers
import asyncio
import jax.numpy as jnp

# The "model" we want to package
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

async def main():
# ...

asyncio.run(main())
```

As mentioned above, this can be anything that works with `jax.jit`.

Next, we can use the `pack_jax_for_xla` utility to package it:

```python forLang="python" withLineNumbers highlight='4,5,14-40'
import asyncio
import jax.numpy as jnp

from cartonml import TensorSpec
from cartonml.utils.jax import pack_jax_for_xla

# The "model" we want to package
def selu(x, alpha=1.67, lambda_=1.05):
return {
"out": lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
}

async def main():
packed_model_path = await pack_jax_for_xla(
model = selu,

# A mapping of arg names to shapes and datatypes
# Note that shapes need to be concrete (i.e `None` and symbols are not allowed). This
# means the shape and datatypes of your inputs at runtime must match the shapes and
# datatypes below.
inputs = [
TensorSpec(
name = "x",
dtype = "float32",
shape = [1000],
)
],

# All the output tensors need to be specified, but you can use `None` and symbols in
# `shape`
outputs = [
TensorSpec(
name = "out",
dtype = "float32",
shape = [None],
)
],
# ...
# See the link below for a list of other information you can provide when packing a model
)

asyncio.run(main())
```

The model at `packed_model_path` can now be loaded from any programming language supported by Carton!

For more details on `inputs` and `outputs` along with several other options you can provide when packing a model (e.g. description, examples, etc.), see [here](/docs/packing/options).

### Differences from `carton.pack`

If you've used the `carton.pack` function for other model types, the `pack_jax_for_xla` utility is very similar. The major differences are:

* It takes a jittable JAX function as a `model` input (instead of a path)
* `inputs` and `outputs` are required arguments and all shapes must be concrete (i.e `None` and symbols are not allowed)

It calls `carton.pack` internally after it exports the JAX model and supports all the optional arguments that `pack` does.

</LanguageItem>
<LanguageItem>
<div class="bg-gradient-to-r from-pink-500 via-red-500 to-yellow-500 rounded-md drop-shadow-lg p-5">
<div class="flex items-center flex-col">
<span class="text-white pb-5">This guide is currently only available for Python. Please change your language selection to see it</span>
<div className="flex not-prose">
<LanguageSelect/>
</div>
</div>
</div>
</LanguageItem>
</LanguageSwitch>

import DocsLayout from '@/components/docslayout'
export default ({children}) => <DocsLayout>{children}</DocsLayout>
67 changes: 67 additions & 0 deletions source/carton-bindings-py/python/cartonml/utils/jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2023 Vivek Panyam
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile
from typing import List

import cartonml as carton
from cartonml import TensorSpec

import jax
import jax.numpy as jnp

async def pack_jax_for_xla(model, inputs: List[TensorSpec], outputs: List[TensorSpec], **kwargs) -> str:
"""
Pack a JAX model supported by `jax.jit` into a Carton.
"""
# Create a temp dir
with tempfile.TemporaryDirectory() as temp_dir:
# Wrap the JAX model in a function with args in the correct order
# Based on https://github.com/google/jax/blob/main/jax/tools/jax_to_ir.py (Apache 2.0)
args = [jnp.zeros(s.shape, s.dtype) for _, s in inputs]

# Wrapper that takes in args in the order of `input_shapes` and converts them
# to kwargs for calling `fn`.
# It then transforms the outputs to return in order as well
def ordered_wrapper(*args):
arg_names = [item.name for item in inputs]
out_dict = model(**dict(zip(arg_names, args)))

return [out_dict[item.name] for item in outputs]

# Export to HLO
comp = jax.xla_computation(ordered_wrapper)(*args)
serialized_proto = comp.as_serialized_hlo_module_proto()

with open(os.path.join(temp_dir, "model.pb"), 'wb') as f:
f.write(serialized_proto)

# Write metadata that stores the arg ordering
with open(os.path.join(temp_dir, "model.json"), 'w') as f:
json.dump(dict(
input_ordering=[item.name for item in inputs],
output_ordering=[item.name for item in outputs],
), f)

return await carton.pack(
temp_dir,
inputs = inputs,
runner_name = "xla",
# TODO: Allow users to set the version of XLA they want to run with
# How is XLA versioned? I don't see anything standard version numbers with the prebuilts
required_framework_version = ">= 0.0.0"
**kwargs
)