Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable JetStream Standalone Server #94

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/unit_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install Dependencies
run: make install-deps
run: |
make install-deps
make install-submodules
- name: Run all unit tests in JetStream (jetstream/tests)
run: make unit-tests
- name: Create test coverage report
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "jetstream/engine/implementations/maxtext"]
path = jetstream/engine/implementations/maxtext
url = https://github.com/google/maxtext.git
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@ GRPC_TOOLS_VERSION := 1.62.1
all: install-deps generate-protos format check

# Dependency management targets

install-deps:
$(PIP) install pytype pylint pyink -r requirements.txt -r benchmarks/requirements.in

install-submodules:
git submodule update --init --recursive
- $(PIP) install -r ./jetstream/engine/implementations/maxtext/requirements.txt
- $(PIP) install jetstream_pt@git+https://github.com/google/[email protected]#egg=jetstream_pt

# Code generation/formatting targets
generate-protos: generate-and-prepend-preambles format

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ make install-deps
Use the following commands to run a server locally:
```
# Start a server
python -m jetstream.core.implementations.mock.server
python -m jetstream.entrypoints.mock.server

# Test local mock server
python -m jetstream.tools.requester
Expand Down
8 changes: 8 additions & 0 deletions jetstream/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,11 @@
except ImportError as e:
print("Proxy backend support is not added")
pass

import os
import sys

submodule_path = os.path.join(
os.path.dirname(__file__), "implementations/maxtext/MaxText"
)
sys.path.append(submodule_path)
1 change: 1 addition & 0 deletions jetstream/engine/implementations/maxtext
Submodule maxtext added at 2a6154
10 changes: 5 additions & 5 deletions jetstream/engine/mock_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def prefill(
samples_per_slot=self.generate_cache_batch // self.prefill_cache_batch,
)

return (prefill_cache, first_step), first_token
return (prefill_cache.astype(jnp.float32), first_step), first_token

@functools.partial(jax.jit, static_argnums=(0,))
def generate(
Expand All @@ -152,7 +152,7 @@ def generate(

# Update generate cache
generate_cache = jax.lax.dynamic_update_slice_in_dim(
generate_cache,
generate_cache.astype(jnp.float32),
previous_timestep,
start_index=generate_cache_index,
axis=1,
Expand Down Expand Up @@ -198,7 +198,7 @@ def generate(
)
return DecodeState(
prefill_cache=prefill_cache,
generate_cache=generate_cache,
generate_cache=generate_cache.astype(jnp.float32),
generate_cache_index=generate_cache_index,
generate_lengths=new_lengths,
generate_tokens=new_timestep,
Expand Down Expand Up @@ -230,7 +230,7 @@ def insert(
)
generate_cache = jax.lax.dynamic_update_slice_in_dim(
decode_state.generate_cache,
jnp.zeros((1, self.cache_length)),
jnp.zeros((1, self.cache_length), dtype=jnp.float32),
slot,
axis=0,
)
Expand All @@ -243,7 +243,7 @@ def insert(
)
generate_tokens = jax.lax.dynamic_update_slice_in_dim(
decode_state.generate_tokens,
previous_timestep,
previous_timestep.astype(jnp.float32),
slot * samples_per_slot,
axis=0,
)
Expand Down
38 changes: 36 additions & 2 deletions jetstream/entrypoints/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,49 @@

"""Config for JetStream Server (including engine init)."""

from typing import Type
import functools
import os
from typing import Sequence, Type

import jax
from jetstream.core import config_lib
from jetstream.engine.implementations.maxtext.MaxText import maxengine, pyconfig
from jetstream_pt import config


def get_server_config(
config_str: str,
config_str: str, argv: Sequence[str]
) -> config_lib.ServerConfig | Type[config_lib.ServerConfig]:
match config_str:
case "MaxtextInterleavedServer":
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
pyconfig.initialize(argv)
server_config = config_lib.ServerConfig(
prefill_slices=(),
generate_slices=(),
interleaved_slices=("tpu=" + str(jax.device_count()),),
prefill_engine_create_fns=(),
generate_engine_create_fns=(),
interleaved_engine_create_fns=(
functools.partial(
maxengine.MaxEngine(config), config=pyconfig.config
),
),
)
case "PyTorchInterleavedServer":
os.environ["XLA_FLAGS"] = (
"--xla_dump_to=/tmp/xla_logs --xla_dump_hlo_as_text"
)
engine = config.create_engine_from_config_flags()
server_config = config_lib.ServerConfig(
prefill_slices=(),
generate_slices=(),
interleaved_slices=("tpu=" + str(jax.device_count()),),
prefill_engine_create_fns=(),
generate_engine_create_fns=(),
interleaved_engine_create_fns=(lambda a: engine,),
)
case "InterleavedCPUTestServer":
server_config = config_lib.InterleavedCPUTestServer
case "CPUTestServer":
Expand Down
62 changes: 62 additions & 0 deletions jetstream/entrypoints/grpc/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runs a JetStream Server."""

from typing import Sequence

from absl import app
from absl import flags

from jetstream.entrypoints import config
from jetstream.core import config_lib, server_lib


flags.DEFINE_integer("port", 9000, "port to listen on")
flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
flags.DEFINE_string(
"config",
"InterleavedCPUTestServer",
"available servers",
)
flags.DEFINE_integer("prometheus_port", 0, "")


def main(argv: Sequence[str]):
devices = server_lib.get_devices()
print(f"devices: {devices}")
server_config = config.get_server_config(flags.FLAGS.config, argv)
print(f"server_config: {server_config}")
del argv

metrics_server_config: config_lib.MetricsServerConfig | None = None
if flags.FLAGS.prometheus_port != 0:
metrics_server_config = config_lib.MetricsServerConfig(
port=flags.FLAGS.prometheus_port
)
# We separate credential from run so that we can unit test it with local
# credentials.
# TODO: Add grpc credentials for OSS.
jetstream_server = server_lib.run(
threads=flags.FLAGS.threads,
port=flags.FLAGS.port,
config=server_config,
devices=devices,
metrics_server_config=metrics_server_config,
)
jetstream_server.wait_for_termination()


if __name__ == "__main__":
app.run(main)
6 changes: 3 additions & 3 deletions jetstream/entrypoints/http/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def server(argv: Sequence[str]):
app.include_router(router)

# Init LLMOrchestrator which would be the main handler in the api endpoints.
devices = server_lib.get_devices()
print(f"devices: {devices}")
server_config = get_server_config(flags.FLAGS.config)
server_config = get_server_config(flags.FLAGS.config, argv)
print(f"server_config: {server_config}")
del argv
devices = server_lib.get_devices()
print(f"devices: {devices}")

metrics_server_config: config_lib.MetricsServerConfig | None = None
# Setup Prometheus server
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from absl import app
from absl import flags

from jetstream.core.implementations.mock import config as mock_config
from jetstream.entrypoints.mock import config as mock_config
from jetstream.core import server_lib


Expand Down
5 changes: 5 additions & 0 deletions jetstream/tests/entrypoints/http/test_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Tests http server end-to-end."""

import os
import subprocess
import sys
import time
Expand All @@ -29,13 +30,17 @@ class HTTPServerTest(unittest.IsolatedAsyncioTestCase):
def setUpClass(cls):
"""Sets up a JetStream http server for unit tests."""
cls.base_url = "http://localhost:8080"
my_env = os.environ.copy() # Create a copy of the current environment
my_env["JAX_PLATFORMS"] = "cpu"
my_env["JAX_TRACEBACK_FILTERING"] = "off"
cls.server = subprocess.Popen(
[
"python",
"-m",
"jetstream.entrypoints.http.api_server",
"--config=InterleavedCPUTestServer",
],
env=my_env,
stdout=sys.stdout,
stderr=sys.stderr,
)
Expand Down
27 changes: 27 additions & 0 deletions requirements-standalone.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# jetstream library
absl-py
coverage
flax
grpcio
jax
jaxlib
numpy
portpicker
prometheus-client
pytest
seqio
tiktoken
blobfile
parameterized
shortuuid
# jetstream benchmarks
nltk
evaluate
rouge-score
tqdm
# jetstream profiling
tensorboard-plugin-profile
# engines
# maxtext @ git+https://github.com/google/[email protected]#egg=maxtext
# maxtext @ {root:uri}/jetstream/engine/implementations/maxtext
jetstream_pt @ git+https://github.com/google/[email protected]#egg=jetstream_pt
Loading