From 1439a2eb02f037306a0d9d48a393f30ea1e9c061 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Thu, 30 May 2024 19:24:37 +0000 Subject: [PATCH 01/23] Enable JetStream Standalone Server --- README.md | 2 +- jetstream/{core/implementations => entrypoints}/__init__.py | 0 jetstream/{core/implementations => entrypoints}/mock/README.md | 0 .../{core/implementations => entrypoints}/mock/__init__.py | 0 jetstream/{core/implementations => entrypoints}/mock/config.py | 0 jetstream/{core/implementations => entrypoints}/mock/server.py | 2 +- 6 files changed, 2 insertions(+), 2 deletions(-) rename jetstream/{core/implementations => entrypoints}/__init__.py (100%) rename jetstream/{core/implementations => entrypoints}/mock/README.md (100%) rename jetstream/{core/implementations => entrypoints}/mock/__init__.py (100%) rename jetstream/{core/implementations => entrypoints}/mock/config.py (100%) rename jetstream/{core/implementations => entrypoints}/mock/server.py (95%) diff --git a/README.md b/README.md index ee0b1eee..aaabec81 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ pip install -r requirements.txt Use the following commands to run a server locally: ``` # Start a server -python -m jetstream.core.implementations.mock.server +python -m jetstream.entrypoints.mock.server # Test local mock server python -m jetstream.tools.requester diff --git a/jetstream/core/implementations/__init__.py b/jetstream/entrypoints/__init__.py similarity index 100% rename from jetstream/core/implementations/__init__.py rename to jetstream/entrypoints/__init__.py diff --git a/jetstream/core/implementations/mock/README.md b/jetstream/entrypoints/mock/README.md similarity index 100% rename from jetstream/core/implementations/mock/README.md rename to jetstream/entrypoints/mock/README.md diff --git a/jetstream/core/implementations/mock/__init__.py b/jetstream/entrypoints/mock/__init__.py similarity index 100% rename from jetstream/core/implementations/mock/__init__.py rename to jetstream/entrypoints/mock/__init__.py diff --git a/jetstream/core/implementations/mock/config.py b/jetstream/entrypoints/mock/config.py similarity index 100% rename from jetstream/core/implementations/mock/config.py rename to jetstream/entrypoints/mock/config.py diff --git a/jetstream/core/implementations/mock/server.py b/jetstream/entrypoints/mock/server.py similarity index 95% rename from jetstream/core/implementations/mock/server.py rename to jetstream/entrypoints/mock/server.py index 6a0cee76..aca0c427 100644 --- a/jetstream/core/implementations/mock/server.py +++ b/jetstream/entrypoints/mock/server.py @@ -19,7 +19,7 @@ from absl import app from absl import flags -from jetstream.core.implementations.mock import config as mock_config +from jetstream.entrypoints.mock import config as mock_config from jetstream.core import server_lib From 7600cd484271a7e8c4670ecb433b51a46bf8d6fd Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Fri, 31 May 2024 22:42:13 +0000 Subject: [PATCH 02/23] Add server entrypoints and engine deps --- .github/workflows/unit_tests.yaml | 2 +- .gitignore | 3 ++ .gitmodules | 3 ++ jetstream/engine/implementations/maxtext | 1 + jetstream/entrypoints/config.py | 67 ++++++++++++++++++++++++ jetstream/entrypoints/server.py | 62 ++++++++++++++++++++++ requirements-standalone.txt | 21 ++++++++ 7 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 .gitmodules create mode 160000 jetstream/engine/implementations/maxtext create mode 100644 jetstream/entrypoints/config.py create mode 100644 jetstream/entrypoints/server.py create mode 100644 requirements-standalone.txt diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index 7b230dde..105bb0b8 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -50,7 +50,7 @@ jobs: pip install -r benchmarks/requirements.in - name: Typecheck the code with pytype run: | - pytype --jobs auto --disable=import-error,module-attr jetstream/ benchmarks/ + pytype --jobs auto --exclude "jetstream/engine/implementations/*" --disable=import-error,module-attr jetstream/ benchmarks/ - name: Analysing the code with pylint run: | pylint jetstream/ benchmarks/ diff --git a/.gitignore b/.gitignore index a13d13c3..a9ebd15f 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,6 @@ logs/ tmp/ venv/ .vscode/ + +# engine imple submodules +jetstream/engine/implementations/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..b4480d50 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "jetstream/engine/implementations/maxtext"] + path = jetstream/engine/implementations/maxtext + url = https://github.com/google/maxtext.git diff --git a/jetstream/engine/implementations/maxtext b/jetstream/engine/implementations/maxtext new file mode 160000 index 00000000..122db988 --- /dev/null +++ b/jetstream/engine/implementations/maxtext @@ -0,0 +1 @@ +Subproject commit 122db98858105f066171bdf67a96aa11033ea99e diff --git a/jetstream/entrypoints/config.py b/jetstream/entrypoints/config.py new file mode 100644 index 00000000..92b92734 --- /dev/null +++ b/jetstream/entrypoints/config.py @@ -0,0 +1,67 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Config for JetStream Server (including engine init).""" + +import functools +import os +from jetstream.engine.implementations.maxtext.MaxText.maxengine_config import create_maxengine +import pyconfig +from typing import Sequence, Type + +import jax + + +from jetstream.core import config_lib +from jetstream_pt import config + + +def get_server_config( + config_str: str, argv: Sequence[str] +) -> config_lib.ServerConfig | Type[config_lib.ServerConfig]: + match config_str: + case "MaxtextInterleavedServer": + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + pyconfig.initialize(argv) + server_config = config_lib.ServerConfig( + prefill_slices=(), + generate_slices=(), + interleaved_slices=("tpu=" + str(jax.device_count()),), + prefill_engine_create_fns=(), + generate_engine_create_fns=(), + interleaved_engine_create_fns=( + functools.partial(create_maxengine, config=pyconfig.config), + ), + ) + case "PyTorchInterleavedServer": + os.environ["XLA_FLAGS"] = ( + "--xla_dump_to=/tmp/xla_logs --xla_dump_hlo_as_text" + ) + engine = config.create_engine_from_config_flags() + server_config = config_lib.ServerConfig( + prefill_slices=(), + generate_slices=(), + interleaved_slices=("tpu=" + str(jax.device_count()),), + prefill_engine_create_fns=(), + generate_engine_create_fns=(), + interleaved_engine_create_fns=(lambda a: engine,), + ) + case "InterleavedCPUTestServer": + server_config = config_lib.InterleavedCPUTestServer + case "CPUTestServer": + server_config = config_lib.CPUTestServer + case _: + raise NotImplementedError + return server_config diff --git a/jetstream/entrypoints/server.py b/jetstream/entrypoints/server.py new file mode 100644 index 00000000..e54fb767 --- /dev/null +++ b/jetstream/entrypoints/server.py @@ -0,0 +1,62 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runs a JetStream Server.""" + +from typing import Sequence + +from absl import app +from absl import flags + +from jetstream.entrypoints import config +from jetstream.core import config_lib, server_lib + + +flags.DEFINE_integer("port", 9000, "port to listen on") +flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool") +flags.DEFINE_string( + "config", + "InterleavedCPUTestServer", + "available servers", +) +flags.DEFINE_integer("prometheus_port", 0, "") + + +def main(argv: Sequence[str]): + devices = server_lib.get_devices() + print(f"devices: {devices}") + server_config = config.get_server_config(flags.FLAGS.config, argv) + print(f"server_config: {server_config}") + del argv + + metrics_server_config: config_lib.MetricsServerConfig | None = None + if flags.FLAGS.prometheus_port != 0: + metrics_server_config = config_lib.MetricsServerConfig( + port=flags.FLAGS.prometheus_port + ) + # We separate credential from run so that we can unit test it with local + # credentials. + # TODO: Add grpc credentials for OSS. + jetstream_server = server_lib.run( + threads=flags.FLAGS.threads, + port=flags.FLAGS.port, + config=server_config, + devices=devices, + metrics_server_config=metrics_server_config, + ) + jetstream_server.wait_for_termination() + + +if __name__ == "__main__": + app.run(main) diff --git a/requirements-standalone.txt b/requirements-standalone.txt new file mode 100644 index 00000000..39380d3f --- /dev/null +++ b/requirements-standalone.txt @@ -0,0 +1,21 @@ +# jetstream library +absl-py +coverage +flax +grpcio +jax +jaxlib +numpy +portpicker +prometheus-client +pytest +seqio +tiktoken +blobfile +parameterized +shortuuid +# jetstream entrypoints +pyconfig +# engines +# maxtext @ git+https://github.com/google/maxtext.git@jetstream-v0.2.2#egg=maxtext +jetstream_pt @ git+https://github.com/google/jetstream-pytorch.git@jetstream-v0.2.2#egg=jetstream_pt From 3c1c3ce1a58a8557e75f301ca747b6289bc283d7 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Fri, 31 May 2024 22:54:23 +0000 Subject: [PATCH 03/23] format imports --- jetstream/entrypoints/config.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/jetstream/entrypoints/config.py b/jetstream/entrypoints/config.py index 92b92734..9299da4f 100644 --- a/jetstream/entrypoints/config.py +++ b/jetstream/entrypoints/config.py @@ -16,15 +16,13 @@ import functools import os -from jetstream.engine.implementations.maxtext.MaxText.maxengine_config import create_maxengine -import pyconfig from typing import Sequence, Type import jax - - from jetstream.core import config_lib +from jetstream.engine.implementations.maxtext.MaxText.maxengine_config import create_maxengine from jetstream_pt import config +import pyconfig def get_server_config( From 3d3431a68a04cb1f3cb5223ac0218550996421ef Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Mon, 3 Jun 2024 19:20:43 +0000 Subject: [PATCH 04/23] Use MaxText jetstream-v0.2.2 --- jetstream/engine/implementations/maxtext | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream/engine/implementations/maxtext b/jetstream/engine/implementations/maxtext index 122db988..34412d44 160000 --- a/jetstream/engine/implementations/maxtext +++ b/jetstream/engine/implementations/maxtext @@ -1 +1 @@ -Subproject commit 122db98858105f066171bdf67a96aa11033ea99e +Subproject commit 34412d44f23b9d2d3bb00e681435e1cfdf678a69 From 225a8a8b4a83e962db7068a1f5e83f9ca43afc32 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Fri, 7 Jun 2024 21:12:12 +0000 Subject: [PATCH 05/23] update deps --- jetstream/entrypoints/config.py | 7 ++++--- requirements-standalone.txt | 10 ++++++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/jetstream/entrypoints/config.py b/jetstream/entrypoints/config.py index 9299da4f..1f260fcb 100644 --- a/jetstream/entrypoints/config.py +++ b/jetstream/entrypoints/config.py @@ -20,9 +20,8 @@ import jax from jetstream.core import config_lib -from jetstream.engine.implementations.maxtext.MaxText.maxengine_config import create_maxengine +from jetstream.engine.implementations.maxtext.MaxText import maxengine_config, pyconfig from jetstream_pt import config -import pyconfig def get_server_config( @@ -40,7 +39,9 @@ def get_server_config( prefill_engine_create_fns=(), generate_engine_create_fns=(), interleaved_engine_create_fns=( - functools.partial(create_maxengine, config=pyconfig.config), + functools.partial( + maxengine_config.create_maxengine, config=pyconfig.config + ), ), ) case "PyTorchInterleavedServer": diff --git a/requirements-standalone.txt b/requirements-standalone.txt index 39380d3f..716e7ba2 100644 --- a/requirements-standalone.txt +++ b/requirements-standalone.txt @@ -14,8 +14,14 @@ tiktoken blobfile parameterized shortuuid -# jetstream entrypoints -pyconfig +# jetstream benchmarks +nltk +evaluate +rouge-score +tqdm +# jetstream profiling +tensorboard-plugin-profile # engines # maxtext @ git+https://github.com/google/maxtext.git@jetstream-v0.2.2#egg=maxtext +# maxtext @ {root:uri}/jetstream/engine/implementations/maxtext jetstream_pt @ git+https://github.com/google/jetstream-pytorch.git@jetstream-v0.2.2#egg=jetstream_pt From 5627ad1e369bf2716cb09c6a2ae011183a56ca2d Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Fri, 9 Aug 2024 21:34:46 +0000 Subject: [PATCH 06/23] update entrypoints --- jetstream/entrypoints/grpc/__init__.py | 13 +++++++++++++ jetstream/entrypoints/{ => grpc}/server.py | 0 jetstream/entrypoints/http/api_server.py | 2 +- 3 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 jetstream/entrypoints/grpc/__init__.py rename jetstream/entrypoints/{ => grpc}/server.py (100%) diff --git a/jetstream/entrypoints/grpc/__init__.py b/jetstream/entrypoints/grpc/__init__.py new file mode 100644 index 00000000..6d5e14bc --- /dev/null +++ b/jetstream/entrypoints/grpc/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jetstream/entrypoints/server.py b/jetstream/entrypoints/grpc/server.py similarity index 100% rename from jetstream/entrypoints/server.py rename to jetstream/entrypoints/grpc/server.py diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py index aaced235..c8dae1f2 100644 --- a/jetstream/entrypoints/http/api_server.py +++ b/jetstream/entrypoints/http/api_server.py @@ -96,7 +96,7 @@ def server(argv: Sequence[str]): # Init LLMOrchestrator which would be the main handler in the api endpoints. devices = server_lib.get_devices() print(f"devices: {devices}") - server_config = get_server_config(flags.FLAGS.config) + server_config = get_server_config(flags.FLAGS.config, argv) print(f"server_config: {server_config}") del argv From bfd3e308a40772e4f408105b8e2f53ed6dc050d6 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Fri, 9 Aug 2024 22:05:13 +0000 Subject: [PATCH 07/23] install submodules in CI --- .github/workflows/unit_tests.yaml | 4 +++- Makefile | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index 8db79fc3..03e4300e 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -65,7 +65,9 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install Dependencies - run: make install-deps + run: | + make install-deps + make install-submodules - name: Run all unit tests in JetStream (jetstream/tests) run: make unit-tests - name: Create test coverage report diff --git a/Makefile b/Makefile index a7699a53..d3c12e31 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ GRPC_TOOLS_VERSION := 1.62.1 all: update-and-install-deps generate-protos format check # Dependency management targets -update-and-install-deps: update-deps install-deps +update-and-install-deps: update-deps install-deps install-submodules update-deps: $(PIP) install pip-tools @@ -14,6 +14,9 @@ update-deps: install-deps: $(PIP) install pytype pylint pyink -r requirements.txt -r benchmarks/requirements.in +install-submodules: + git submodule update --init --recursive + # Code generation/formatting targets generate-protos: generate-and-prepend-preambles format From 406ee343b37b77c9d5b6e7d8bb05629dd2bd02b7 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Fri, 9 Aug 2024 22:19:21 +0000 Subject: [PATCH 08/23] fix import --- jetstream/entrypoints/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jetstream/entrypoints/config.py b/jetstream/entrypoints/config.py index 1f260fcb..644cc38f 100644 --- a/jetstream/entrypoints/config.py +++ b/jetstream/entrypoints/config.py @@ -20,7 +20,7 @@ import jax from jetstream.core import config_lib -from jetstream.engine.implementations.maxtext.MaxText import maxengine_config, pyconfig +from jetstream.engine.implementations.maxtext.MaxText import maxengine, pyconfig from jetstream_pt import config @@ -40,7 +40,7 @@ def get_server_config( generate_engine_create_fns=(), interleaved_engine_create_fns=( functools.partial( - maxengine_config.create_maxengine, config=pyconfig.config + maxengine.MaxEngine(config), config=pyconfig.config ), ), ) From b527e32b721111ba4344ba0e5bb45c959010eda6 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Sat, 10 Aug 2024 00:47:12 +0000 Subject: [PATCH 09/23] add submodule to py path --- Makefile | 1 + jetstream/engine/__init__.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/Makefile b/Makefile index d3c12e31..7aba1e24 100644 --- a/Makefile +++ b/Makefile @@ -16,6 +16,7 @@ install-deps: install-submodules: git submodule update --init --recursive + bash jetstream/engine/implementations/maxtext/setup.sh # Code generation/formatting targets generate-protos: generate-and-prepend-preambles format diff --git a/jetstream/engine/__init__.py b/jetstream/engine/__init__.py index ee979964..b6b6f1f9 100644 --- a/jetstream/engine/__init__.py +++ b/jetstream/engine/__init__.py @@ -21,3 +21,9 @@ except ImportError as e: print("Proxy backend support is not added") pass + +import os +import sys + +submodule_path = os.path.join(os.path.dirname(__file__), 'implementations/maxtext/MaxText') +sys.path.append(submodule_path) From 84506f340f6e1479fb729e6f550124019a62050b Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Sat, 10 Aug 2024 00:55:43 +0000 Subject: [PATCH 10/23] fix Makefile --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 7aba1e24..5f9797cd 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ install-deps: install-submodules: git submodule update --init --recursive - bash jetstream/engine/implementations/maxtext/setup.sh + ./jetstream/engine/implementations/maxtext/setup.sh # Code generation/formatting targets generate-protos: generate-and-prepend-preambles format From f0b255a31168c038541b426a7df7f864a8649231 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Sat, 10 Aug 2024 00:58:54 +0000 Subject: [PATCH 11/23] fix Makefile permission --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 5f9797cd..92af027b 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ install-deps: install-submodules: git submodule update --init --recursive - ./jetstream/engine/implementations/maxtext/setup.sh + chmod +x ./jetstream/engine/implementations/maxtext/setup.sh # Code generation/formatting targets generate-protos: generate-and-prepend-preambles format From a4031588afef9004aca184b394bf818ef1f1b7d5 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Sat, 10 Aug 2024 01:02:12 +0000 Subject: [PATCH 12/23] fix Makefile permission and run --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index 92af027b..d033cac8 100644 --- a/Makefile +++ b/Makefile @@ -17,6 +17,7 @@ install-deps: install-submodules: git submodule update --init --recursive chmod +x ./jetstream/engine/implementations/maxtext/setup.sh + ./jetstream/engine/implementations/maxtext/setup.sh # Code generation/formatting targets generate-protos: generate-and-prepend-preambles format From 0bc032e5215d24b149c0647a42f7d4f79eeed0a7 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Sat, 10 Aug 2024 01:10:28 +0000 Subject: [PATCH 13/23] change correct path in Makefile --- Makefile | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index d033cac8..d5e62b7a 100644 --- a/Makefile +++ b/Makefile @@ -16,8 +16,9 @@ install-deps: install-submodules: git submodule update --init --recursive - chmod +x ./jetstream/engine/implementations/maxtext/setup.sh - ./jetstream/engine/implementations/maxtext/setup.sh + cd jetstream/engine/implementations/maxtext + chmod +x ./setup.sh + ./setup.sh # Code generation/formatting targets generate-protos: generate-and-prepend-preambles format From da1e7b05affb1df0b0b94c61969b0c80fd4699fa Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Sat, 10 Aug 2024 01:18:12 +0000 Subject: [PATCH 14/23] fix Makefile --- Makefile | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/Makefile b/Makefile index d5e62b7a..e7207f29 100644 --- a/Makefile +++ b/Makefile @@ -16,9 +16,7 @@ install-deps: install-submodules: git submodule update --init --recursive - cd jetstream/engine/implementations/maxtext - chmod +x ./setup.sh - ./setup.sh + cd ./jetstream/engine/implementations/maxtext && chmod +x ./setup.sh && ./setup.sh # Code generation/formatting targets generate-protos: generate-and-prepend-preambles format From e1a621658964eb99a1c492980013d9494698e5d2 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Sat, 10 Aug 2024 01:35:40 +0000 Subject: [PATCH 15/23] update deps --- Makefile | 1 + requirements.in | 19 --- requirements.txt | 370 +++-------------------------------------------- 3 files changed, 20 insertions(+), 370 deletions(-) delete mode 100644 requirements.in diff --git a/Makefile b/Makefile index e7207f29..9788b422 100644 --- a/Makefile +++ b/Makefile @@ -17,6 +17,7 @@ install-deps: install-submodules: git submodule update --init --recursive cd ./jetstream/engine/implementations/maxtext && chmod +x ./setup.sh && ./setup.sh + $(PIP) jetstream_pt @ git+https://github.com/google/jetstream-pytorch.git@jetstream-v0.2.2#egg=jetstream_pt # Code generation/formatting targets generate-protos: generate-and-prepend-preambles format diff --git a/requirements.in b/requirements.in deleted file mode 100644 index 86841a57..00000000 --- a/requirements.in +++ /dev/null @@ -1,19 +0,0 @@ -absl-py -coverage -flax -grpcio -jax -jaxlib -numpy -portpicker -prometheus-client -pytest -seqio -tiktoken -blobfile -parameterized -shortuuid -fastapi -uvicorn -# For profiling -tensorboard-plugin-profile \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 67e31fdd..86841a57 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,351 +1,19 @@ -# -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: -# -# pip-compile requirements.in -# -absl-py==1.4.0 - # via - # -r requirements.in - # array-record - # chex - # clu - # etils - # ml-collections - # optax - # orbax-checkpoint - # seqio - # tensorboard - # tensorflow - # tensorflow-metadata - # tfds-nightly -anyio==3.7.1 - # via - # fastapi - # starlette -array-record==0.5.0 - # via tfds-nightly -astunparse==1.6.3 - # via tensorflow -blobfile==2.1.1 - # via -r requirements.in -cachetools==5.3.2 - # via google-auth -certifi==2024.7.4 - # via requests -charset-normalizer==3.3.2 - # via requests -chex==0.1.7 - # via optax -click==8.1.7 - # via - # tfds-nightly - # uvicorn -clu==0.0.10 - # via seqio -contextlib2==21.6.0 - # via ml-collections -coverage==7.4.4 - # via -r requirements.in -dm-tree==0.1.8 - # via - # chex - # tfds-nightly -docstring-parser==0.15 - # via pyglove -editdistance==0.6.2 - # via seqio -etils[array-types,enp,epath,epy,etqdm,etree]==1.6.0 - # via - # array-record - # clu - # orbax-checkpoint - # tfds-nightly -exceptiongroup==1.2.0 - # via - # anyio - # pytest -fastapi==0.103.2 - # via -r requirements.in -filelock==3.14.0 - # via blobfile -flatbuffers==23.5.26 - # via tensorflow -flax==0.8.0 - # via - # -r requirements.in - # clu -fsspec==2023.12.2 - # via etils -gast==0.4.0 - # via tensorflow -google-auth==2.27.0 - # via - # google-auth-oauthlib - # tensorboard -google-auth-oauthlib==1.0.0 - # via tensorboard -google-pasta==0.2.0 - # via tensorflow -googleapis-common-protos==1.62.0 - # via tensorflow-metadata -grpcio==1.60.1 - # via - # -r requirements.in - # tensorboard - # tensorflow -gviz-api==1.10.0 - # via tensorboard-plugin-profile -h11==0.14.0 - # via uvicorn -h5py==3.10.0 - # via tensorflow -idna==3.7 - # via - # anyio - # requests -importlib-resources==6.1.1 - # via etils -iniconfig==2.0.0 - # via pytest -jax==0.4.23 - # via - # -r requirements.in - # chex - # clu - # flax - # optax - # orbax-checkpoint - # seqio -jaxlib==0.4.23 - # via - # -r requirements.in - # chex - # clu - # optax - # orbax-checkpoint - # seqio -keras==2.13.1 - # via tensorflow -libclang==16.0.6 - # via tensorflow -lxml==4.9.4 - # via blobfile -markdown==3.5.2 - # via tensorboard -markdown-it-py==3.0.0 - # via rich -markupsafe==2.1.5 - # via werkzeug -mdurl==0.1.2 - # via markdown-it-py -ml-collections==0.1.1 - # via clu -ml-dtypes==0.3.2 - # via - # jax - # jaxlib - # tensorstore -msgpack==1.0.7 - # via - # flax - # orbax-checkpoint -nest-asyncio==1.6.0 - # via orbax-checkpoint -numpy==1.23.1 - # via - # -r requirements.in - # chex - # clu - # etils - # flax - # h5py - # jax - # jaxlib - # ml-dtypes - # opt-einsum - # optax - # orbax-checkpoint - # scipy - # seqio - # tensorboard - # tensorflow - # tensorflow-hub - # tensorstore - # tfds-nightly -oauthlib==3.2.2 - # via requests-oauthlib -opt-einsum==3.3.0 - # via - # jax - # tensorflow -optax==0.1.8 - # via flax -orbax-checkpoint==0.5.2 - # via flax -packaging==23.2 - # via - # clu - # pytest - # seqio - # tensorflow -parameterized==0.9.0 - # via -r requirements.in -pluggy==1.4.0 - # via pytest -portpicker==1.6.0 - # via -r requirements.in -prometheus-client==0.20.0 - # via -r requirements.in -promise==2.3 - # via tfds-nightly -protobuf==3.20.3 - # via - # googleapis-common-protos - # orbax-checkpoint - # seqio - # tensorboard - # tensorboard-plugin-profile - # tensorflow - # tensorflow-hub - # tensorflow-metadata - # tfds-nightly -psutil==5.9.8 - # via - # portpicker - # tfds-nightly -pyasn1==0.5.1 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 - # via google-auth -pycryptodomex==3.20.0 - # via blobfile -pydantic==1.10.17 - # via fastapi -pyglove==0.4.4 - # via seqio -pygments==2.17.2 - # via rich -pytest==8.1.1 - # via -r requirements.in -pyyaml==6.0.1 - # via - # flax - # ml-collections - # orbax-checkpoint -regex==2024.4.28 - # via tiktoken -requests==2.32.0 - # via - # requests-oauthlib - # tensorboard - # tfds-nightly - # tiktoken -requests-oauthlib==1.3.1 - # via google-auth-oauthlib -rich==13.7.0 - # via flax -rsa==4.9 - # via google-auth -scipy==1.12.0 - # via - # jax - # jaxlib -sentencepiece==0.1.99 - # via seqio -seqio==0.0.19 - # via -r requirements.in -shortuuid==1.0.13 - # via -r requirements.in -six==1.16.0 - # via - # astunparse - # google-pasta - # gviz-api - # ml-collections - # promise - # tensorboard-plugin-profile - # tensorflow -sniffio==1.3.1 - # via anyio -starlette==0.27.0 - # via fastapi -tensorboard==2.13.0 - # via tensorflow -tensorboard-data-server==0.7.2 - # via tensorboard -tensorboard-plugin-profile==2.15.1 - # via -r requirements.in -tensorflow==2.13.1 - # via tensorflow-text -tensorflow-estimator==2.13.0 - # via tensorflow -tensorflow-hub==0.16.1 - # via tensorflow-text -tensorflow-io-gcs-filesystem==0.35.0 - # via tensorflow -tensorflow-metadata==1.14.0 - # via tfds-nightly -tensorflow-text==2.13.0 - # via seqio -tensorstore==0.1.52 - # via - # flax - # orbax-checkpoint -termcolor==2.4.0 - # via - # tensorflow - # tfds-nightly -tf-keras==2.15.0 - # via tensorflow-hub -tfds-nightly==4.9.2.dev202308090034 - # via seqio -tiktoken==0.6.0 - # via -r requirements.in -toml==0.10.2 - # via tfds-nightly -tomli==2.0.1 - # via pytest -toolz==0.12.1 - # via chex -tqdm==4.66.3 - # via - # etils - # tfds-nightly -typing-extensions==4.5.0 - # via - # chex - # clu - # etils - # fastapi - # flax - # orbax-checkpoint - # pydantic - # tensorflow - # uvicorn -urllib3==2.2.2 - # via - # blobfile - # requests -uvicorn==0.30.1 - # via -r requirements.in -werkzeug==3.0.1 - # via - # tensorboard - # tensorboard-plugin-profile -wheel==0.42.0 - # via - # astunparse - # tensorboard -wrapt==1.16.0 - # via - # clu - # tensorflow - # tfds-nightly -zipp==3.19.1 - # via etils - -# The following packages are considered to be unsafe in a requirements file: -# setuptools +absl-py +coverage +flax +grpcio +jax +jaxlib +numpy +portpicker +prometheus-client +pytest +seqio +tiktoken +blobfile +parameterized +shortuuid +fastapi +uvicorn +# For profiling +tensorboard-plugin-profile \ No newline at end of file From 1521350101452d9a05159817165f51ea4a0c9ea9 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Sat, 10 Aug 2024 01:59:01 +0000 Subject: [PATCH 16/23] ignore --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 9788b422..6a970569 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ install-deps: install-submodules: git submodule update --init --recursive - cd ./jetstream/engine/implementations/maxtext && chmod +x ./setup.sh && ./setup.sh + - cd ./jetstream/engine/implementations/maxtext && chmod +x ./setup.sh && ./setup.sh $(PIP) jetstream_pt @ git+https://github.com/google/jetstream-pytorch.git@jetstream-v0.2.2#egg=jetstream_pt # Code generation/formatting targets From a695c4b85a2743584a2d0b687a275031b6a09677 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Sat, 10 Aug 2024 02:02:41 +0000 Subject: [PATCH 17/23] fix --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 6a970569..9038384b 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ install-deps: install-submodules: git submodule update --init --recursive - cd ./jetstream/engine/implementations/maxtext && chmod +x ./setup.sh && ./setup.sh - $(PIP) jetstream_pt @ git+https://github.com/google/jetstream-pytorch.git@jetstream-v0.2.2#egg=jetstream_pt + $(PIP) install jetstream_pt @ git+https://github.com/google/jetstream-pytorch.git@jetstream-v0.2.2#egg=jetstream_pt # Code generation/formatting targets generate-protos: generate-and-prepend-preambles format From 84467f74ef593b69991c7a1694629c9bbccfa698 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Sat, 10 Aug 2024 02:07:38 +0000 Subject: [PATCH 18/23] fix pt module --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 9038384b..9c008ffc 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ install-deps: install-submodules: git submodule update --init --recursive - cd ./jetstream/engine/implementations/maxtext && chmod +x ./setup.sh && ./setup.sh - $(PIP) install jetstream_pt @ git+https://github.com/google/jetstream-pytorch.git@jetstream-v0.2.2#egg=jetstream_pt + - $(PIP) install jetstream_pt@git+https://github.com/google/jetstream-pytorch.git@jetstream-v0.2.2#egg=jetstream_pt # Code generation/formatting targets generate-protos: generate-and-prepend-preambles format From b7fdf0748bbd808fb7365e09ef78f5647f0f03fc Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Sat, 10 Aug 2024 02:18:58 +0000 Subject: [PATCH 19/23] cpu jax --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 9c008ffc..bef4c574 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ install-deps: install-submodules: git submodule update --init --recursive - - cd ./jetstream/engine/implementations/maxtext && chmod +x ./setup.sh && ./setup.sh + - cd ./jetstream/engine/implementations/maxtext && $(PIP) install -r requirements.txt - $(PIP) install jetstream_pt@git+https://github.com/google/jetstream-pytorch.git@jetstream-v0.2.2#egg=jetstream_pt # Code generation/formatting targets From 7f83000f8d7e7f412a7d7322b7ce302aaee9be21 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Sat, 10 Aug 2024 02:21:43 +0000 Subject: [PATCH 20/23] fix --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index bef4c574..ea1d6d06 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ install-deps: install-submodules: git submodule update --init --recursive - - cd ./jetstream/engine/implementations/maxtext && $(PIP) install -r requirements.txt + - $(PIP) install -r ./jetstream/engine/implementations/maxtext/requirements.txt - $(PIP) install jetstream_pt@git+https://github.com/google/jetstream-pytorch.git@jetstream-v0.2.2#egg=jetstream_pt # Code generation/formatting targets From deb1d68db00393f18361eb8712b1ec071d0219d2 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Sat, 10 Aug 2024 02:37:52 +0000 Subject: [PATCH 21/23] update maxtext submodule --- .gitignore | 3 --- jetstream/engine/implementations/maxtext | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index a9ebd15f..a13d13c3 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,3 @@ logs/ tmp/ venv/ .vscode/ - -# engine imple submodules -jetstream/engine/implementations/ diff --git a/jetstream/engine/implementations/maxtext b/jetstream/engine/implementations/maxtext index 34412d44..2a6154f2 160000 --- a/jetstream/engine/implementations/maxtext +++ b/jetstream/engine/implementations/maxtext @@ -1 +1 @@ -Subproject commit 34412d44f23b9d2d3bb00e681435e1cfdf678a69 +Subproject commit 2a6154f254bf5dbe67e659360775a83a797ed7f9 From 4745eae443e19af50d2fa6b48215cf5f71974c1e Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Mon, 12 Aug 2024 20:37:20 +0000 Subject: [PATCH 22/23] Fix unit test and mock engine --- jetstream/engine/__init__.py | 4 +++- jetstream/engine/mock_engine.py | 16 +++++++++++----- .../tests/entrypoints/http/test_api_server.py | 5 +++++ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/jetstream/engine/__init__.py b/jetstream/engine/__init__.py index b6b6f1f9..4101fcd1 100644 --- a/jetstream/engine/__init__.py +++ b/jetstream/engine/__init__.py @@ -25,5 +25,7 @@ import os import sys -submodule_path = os.path.join(os.path.dirname(__file__), 'implementations/maxtext/MaxText') +submodule_path = os.path.join( + os.path.dirname(__file__), "implementations/maxtext/MaxText" +) sys.path.append(submodule_path) diff --git a/jetstream/engine/mock_engine.py b/jetstream/engine/mock_engine.py index 0277e9a3..2e1f6a3f 100644 --- a/jetstream/engine/mock_engine.py +++ b/jetstream/engine/mock_engine.py @@ -129,7 +129,7 @@ def prefill( samples_per_slot=self.generate_cache_batch // self.prefill_cache_batch, ) - return (prefill_cache, first_step), first_token + return (prefill_cache.astype(jnp.float32), first_step), first_token @functools.partial(jax.jit, static_argnums=(0,)) def generate( @@ -152,7 +152,7 @@ def generate( # Update generate cache generate_cache = jax.lax.dynamic_update_slice_in_dim( - generate_cache, + generate_cache.astype(jnp.float32), previous_timestep, start_index=generate_cache_index, axis=1, @@ -198,7 +198,7 @@ def generate( ) return DecodeState( prefill_cache=prefill_cache, - generate_cache=generate_cache, + generate_cache=generate_cache.astype(jnp.float32), generate_cache_index=generate_cache_index, generate_lengths=new_lengths, generate_tokens=new_timestep, @@ -225,12 +225,18 @@ def insert( """Adds `prefix` into `decode_state` at `slot`.""" # [B, T], [T,] -> [B, T] prefill_cache, previous_timestep = prefix + print( + f"{decode_state.prefill_cache.dtype=} {prefill_cache.dtype=} {slot.dtype=}" + ) prefill_cache = jax.lax.dynamic_update_slice_in_dim( decode_state.prefill_cache, prefill_cache, slot, axis=0 ) + print( + f"{decode_state.generate_cache.dtype=} {jnp.zeros((1, self.cache_length)).dtype=} {slot.dtype=}" + ) generate_cache = jax.lax.dynamic_update_slice_in_dim( decode_state.generate_cache, - jnp.zeros((1, self.cache_length)), + jnp.zeros((1, self.cache_length), dtype=jnp.float32), slot, axis=0, ) @@ -243,7 +249,7 @@ def insert( ) generate_tokens = jax.lax.dynamic_update_slice_in_dim( decode_state.generate_tokens, - previous_timestep, + previous_timestep.astype(jnp.float32), slot * samples_per_slot, axis=0, ) diff --git a/jetstream/tests/entrypoints/http/test_api_server.py b/jetstream/tests/entrypoints/http/test_api_server.py index e6d42e58..eeebe694 100644 --- a/jetstream/tests/entrypoints/http/test_api_server.py +++ b/jetstream/tests/entrypoints/http/test_api_server.py @@ -14,6 +14,7 @@ """Tests http server end-to-end.""" +import os import subprocess import sys import time @@ -29,6 +30,9 @@ class HTTPServerTest(unittest.IsolatedAsyncioTestCase): def setUpClass(cls): """Sets up a JetStream http server for unit tests.""" cls.base_url = "http://localhost:8080" + my_env = os.environ.copy() # Create a copy of the current environment + my_env["JAX_PLATFORMS"] = "cpu" + my_env["JAX_TRACEBACK_FILTERING"] = "off" cls.server = subprocess.Popen( [ "python", @@ -36,6 +40,7 @@ def setUpClass(cls): "jetstream.entrypoints.http.api_server", "--config=InterleavedCPUTestServer", ], + env=my_env, stdout=sys.stdout, stderr=sys.stderr, ) From 60eb80c855ddeec28903cb507599b427f11b13cd Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Mon, 12 Aug 2024 21:54:09 +0000 Subject: [PATCH 23/23] Fix lint and maxtext server --- jetstream/engine/mock_engine.py | 6 ------ jetstream/entrypoints/http/api_server.py | 4 ++-- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/jetstream/engine/mock_engine.py b/jetstream/engine/mock_engine.py index 2e1f6a3f..10c06e6d 100644 --- a/jetstream/engine/mock_engine.py +++ b/jetstream/engine/mock_engine.py @@ -225,15 +225,9 @@ def insert( """Adds `prefix` into `decode_state` at `slot`.""" # [B, T], [T,] -> [B, T] prefill_cache, previous_timestep = prefix - print( - f"{decode_state.prefill_cache.dtype=} {prefill_cache.dtype=} {slot.dtype=}" - ) prefill_cache = jax.lax.dynamic_update_slice_in_dim( decode_state.prefill_cache, prefill_cache, slot, axis=0 ) - print( - f"{decode_state.generate_cache.dtype=} {jnp.zeros((1, self.cache_length)).dtype=} {slot.dtype=}" - ) generate_cache = jax.lax.dynamic_update_slice_in_dim( decode_state.generate_cache, jnp.zeros((1, self.cache_length), dtype=jnp.float32), diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py index c8dae1f2..3879b435 100644 --- a/jetstream/entrypoints/http/api_server.py +++ b/jetstream/entrypoints/http/api_server.py @@ -94,11 +94,11 @@ def server(argv: Sequence[str]): app.include_router(router) # Init LLMOrchestrator which would be the main handler in the api endpoints. - devices = server_lib.get_devices() - print(f"devices: {devices}") server_config = get_server_config(flags.FLAGS.config, argv) print(f"server_config: {server_config}") del argv + devices = server_lib.get_devices() + print(f"devices: {devices}") metrics_server_config: config_lib.MetricsServerConfig | None = None # Setup Prometheus server