From 9a45c32dd24248412e893876e63ad81a02773a04 Mon Sep 17 00:00:00 2001
From: Zijun Zhou
Date: Thu, 18 Jul 2024 00:07:07 +0000
Subject: [PATCH 1/7] Add http server to JetStream

---
 jetstream/core/server_lib.py | 59 ++++++++++----
 jetstream/entrypoints/__init__.py | 14 ++++
 jetstream/entrypoints/config.py | 35 +++++++++
 jetstream/entrypoints/http/__init__.py | 14 ++++
 jetstream/entrypoints/http/api_server.py | 98 ++++++++++++++++++++++++
 requirements.in | 2 +
 6 files changed, 205 insertions(+), 17 deletions(-)
 create mode 100644 jetstream/entrypoints/__init__.py
 create mode 100644 jetstream/entrypoints/config.py
 create mode 100644 jetstream/entrypoints/http/__init__.py
 create mode 100644 jetstream/entrypoints/http/api_server.py

diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py
index 9c1c5986..87b04836 100644
--- a/jetstream/core/server_lib.py
+++ b/jetstream/core/server_lib.py
@@ -92,37 +92,25 @@ def wait_for_termination(self) -> None:
     self.stop()


-def run(
-    port: int,
+def create_driver(
     config: Type[config_lib.ServerConfig],
     devices: Any,
-    credentials: Any = grpc.insecure_server_credentials(),
-    threads: int | None = None,
     jax_padding: bool = True,
     metrics_server_config: config_lib.MetricsServerConfig | None = None,
-    enable_jax_profiler: bool = False,
-    jax_profiler_port: int = 9999,
     enable_model_warmup: bool = False,
-) -> JetStreamServer:
-  """Runs a server with a specified config.
+):
+  """Creates a driver with a specified config.

   Args:
-    port: Port on which the server will be made available.
     config: A ServerConfig to config engine, model, device slices, etc.
     devices: Device objects, will be used to get engine with proper slicing.
-    credentials: Should use grpc credentials by default.
-    threads: Number of RPC handlers worker threads. This should be at least
-      equal to the decoding batch size to fully saturate the decoding queue.
     jax_padding: The flag to enable JAX padding during tokenization.
     metrics_server_config: The config to enable Promethus metric server.
-    enable_jax_profiler: The flag to enable JAX profiler server.
-    jax_profiler_port: The port JAX profiler server (default to 9999).
     enable_model_warmup: The flag to enable model server warmup with AOT.

   Returns:
-    JetStreamServer that wraps the grpc server and orchestrator driver.
+    An orchestrator driver.
   """
-  logging.info("Kicking off gRPC server.")
   engines = config_lib.get_engines(config, devices=devices)
   prefill_params = [pe.load_params() for pe in engines.prefill_engines]
   generate_params = [ge.load_params() for ge in engines.generate_engines]
@@ -178,7 +166,7 @@ def run(
       traceback.print_exc()
       os.kill(os.getpid(), signal.SIGKILL)

-  driver = orchestrator.Driver(
+  return orchestrator.Driver(
       prefill_engines=prefill_engines,
       generate_engines=generate_engines,
       prefill_params=prefill_params,
@@ -188,6 +176,43 @@ def run(
       metrics_collector=metrics_collector,
       is_ray_backend=config.is_ray_backend,
   )
+
+
+def run(
+    port: int,
+    config: Type[config_lib.ServerConfig],
+    devices: Any,
+    credentials: Any = grpc.insecure_server_credentials(),
+    threads: int | None = None,
+    jax_padding: bool = True,
+    metrics_server_config: config_lib.MetricsServerConfig | None = None,
+    enable_jax_profiler: bool = False,
+    jax_profiler_port: int = 9999,
+    enable_model_warmup: bool = False,
+) -> JetStreamServer:
+  """Runs a server with a specified config.
+
+  Args:
+    port: Port on which the server will be made available.
+    config: A ServerConfig to config engine, model, device slices, etc.
+    devices: Device objects, will be used to get engine with proper slicing.
+    credentials: Should use grpc credentials by default.
+    threads: Number of RPC handlers worker threads. This should be at least
+      equal to the decoding batch size to fully saturate the decoding queue.
+    jax_padding: The flag to enable JAX padding during tokenization.
+    metrics_server_config: The config to enable Promethus metric server.
+    enable_jax_profiler: The flag to enable JAX profiler server.
+    jax_profiler_port: The port JAX profiler server (default to 9999).
+    enable_model_warmup: The flag to enable model server warmup with AOT.
+
+  Returns:
+    JetStreamServer that wraps the grpc server and orchestrator driver.
+  """
+  logging.info("Kicking off gRPC server.")
+
+  driver = create_driver(
+      config, devices, jax_padding, metrics_server_config, enable_model_warmup
+  )
   # We default threads to the total number of concurrent allowed decodes,
   # to make sure we can fully saturate the model. Set default minimum to 64.
   threads = threads or max(driver.get_total_concurrent_requests(), 64)
diff --git a/jetstream/entrypoints/__init__.py b/jetstream/entrypoints/__init__.py
new file mode 100644
index 00000000..c38dc3b1
--- /dev/null
+++ b/jetstream/entrypoints/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/jetstream/entrypoints/config.py b/jetstream/entrypoints/config.py
new file mode 100644
index 00000000..c8a3ba84
--- /dev/null
+++ b/jetstream/entrypoints/config.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Config for JetStream Server (including engine init)."""
+
+import functools
+import os
+from typing import Sequence, Type
+
+import jax
+from jetstream.core import config_lib
+
+
+def get_server_config(
+    config_str: str, argv: Sequence[str]
+) -> config_lib.ServerConfig | Type[config_lib.ServerConfig]:
+  match config_str:
+    case "InterleavedCPUTestServer":
+      server_config = config_lib.InterleavedCPUTestServer
+    case "CPUTestServer":
+      server_config = config_lib.CPUTestServer
+    case _:
+      raise NotImplementedError
+  return server_config
diff --git a/jetstream/entrypoints/http/__init__.py b/jetstream/entrypoints/http/__init__.py
new file mode 100644
index 00000000..c38dc3b1
--- /dev/null
+++ b/jetstream/entrypoints/http/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py
new file mode 100644
index 00000000..2fccbc55
--- /dev/null
+++ b/jetstream/entrypoints/http/api_server.py
@@ -0,0 +1,98 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Sequence
+from absl import app
+from absl import flags
+from fastapi import APIRouter, Response
+import fastapi
+from fastapi.responses import JSONResponse, StreamingResponse
+import uvicorn
+
+from jetstream.core import config_lib, orchestrator, server_lib
+from jetstream.entrypoints.config import get_server_config
+
+flags.DEFINE_string("host", "0.0.0.0", "server host address")
+flags.DEFINE_integer("port", 8080, "http server port")
+flags.DEFINE_string(
+    "config",
+    "InterleavedCPUTestServer",
+    "available servers",
+)
+flags.DEFINE_integer(
+    "prometheus_port",
+    9988,
+    "prometheus_port",
+)
+
+driver: orchestrator.Driver
+
+# Define Fast API endpoints (use driver to handle).
+router = APIRouter()
+
+
+@router.get("/")
+def root():
+  """Root path for Jetstream HTTP Server."""
+  return Response(
+      content=json.dumps({"message": "JetStream HTTP Server"}, indent=4),
+      media_type="application/json",
+  )
+
+
+@router.get("/v1/health")
+async def health() -> Response:
+  """Health check."""
+  return Response(
+      content=json.dumps({"is_live": str(driver.live)}, indent=4),
+      media_type="application/json",
+      status_code=200,
+  )
+
+
+def server(argv: Sequence[str]):
+  # Init Fast API.
+  app = fastapi.FastAPI()
+  app.include_router(router)
+
+  # Init driver which would be the main handler in the api endpoints.
+  devices = server_lib.get_devices()
+  print(f"devices: {devices}")
+  server_config = get_server_config(flags.FLAGS.config, argv)
+  print(f"server_config: {server_config}")
+  del argv
+
+  metrics_server_config: config_lib.MetricsServerConfig | None = None
+  if flags.FLAGS.prometheus_port != 0:
+    metrics_server_config = config_lib.MetricsServerConfig(
+        port=flags.FLAGS.prometheus_port
+    )
+
+  global driver
+  driver = server_lib.create_driver(
+      config=server_config,
+      devices=devices,
+      metrics_server_config=metrics_server_config,
+  )
+
+  # Start uvicorn http server.
+  uvicorn.run(
+      app, host=flags.FLAGS.host, port=flags.FLAGS.port, log_level="info"
+  )
+
+
+if __name__ == "__main__":
+  # Run Abseil app w flags parser.
+  app.run(server)
diff --git a/requirements.in b/requirements.in
index 459749ae..86841a57 100644
--- a/requirements.in
+++ b/requirements.in
@@ -13,5 +13,7 @@ tiktoken
 blobfile
 parameterized
 shortuuid
+fastapi
+uvicorn
 # For profiling
 tensorboard-plugin-profile
\ No newline at end of file

From 07fe623b5637c7e3b17f501a1e4b5ebd04a6ae49 Mon Sep 17 00:00:00 2001
From: Zijun Zhou
Date: Thu, 18 Jul 2024 18:32:46 +0000
Subject: [PATCH 2/7] Add generate api and cleanup

---
 jetstream/entrypoints/__init__.py | 1 -
 jetstream/entrypoints/config.py | 7 +---
 jetstream/entrypoints/http/__init__.py | 1 -
 jetstream/entrypoints/http/api_server.py | 48 +++++++++++++++++-------
 jetstream/entrypoints/http/protocol.py | 36 ++++++++++++++++++
 jetstream/entrypoints/http/utils.py | 27 +++++++++++++
 6 files changed, 100 insertions(+), 20 deletions(-)
 create mode 100644 jetstream/entrypoints/http/protocol.py
 create mode 100644 jetstream/entrypoints/http/utils.py

diff --git a/jetstream/entrypoints/__init__.py b/jetstream/entrypoints/__init__.py
index c38dc3b1..6d5e14bc 100644
--- a/jetstream/entrypoints/__init__.py
+++ b/jetstream/entrypoints/__init__.py
@@ -11,4 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
diff --git a/jetstream/entrypoints/config.py b/jetstream/entrypoints/config.py
index c8a3ba84..79f2b012 100644
--- a/jetstream/entrypoints/config.py
+++ b/jetstream/entrypoints/config.py
@@ -14,16 +14,13 @@

 """Config for JetStream Server (including engine init)."""

-import functools
-import os
-from typing import Sequence, Type
+from typing import Type

-import jax
 from jetstream.core import config_lib


 def get_server_config(
-    config_str: str, argv: Sequence[str]
+    config_str: str,
 ) -> config_lib.ServerConfig | Type[config_lib.ServerConfig]:
   match config_str:
     case "InterleavedCPUTestServer":
diff --git a/jetstream/entrypoints/http/__init__.py b/jetstream/entrypoints/http/__init__.py
index c38dc3b1..6d5e14bc 100644
--- a/jetstream/entrypoints/http/__init__.py
+++ b/jetstream/entrypoints/http/__init__.py
@@ -11,4 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py
index 2fccbc55..5f8a8087 100644
--- a/jetstream/entrypoints/http/api_server.py
+++ b/jetstream/entrypoints/http/api_server.py
@@ -12,17 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+"""JetStream Http API server."""
+
 import json
 from typing import Sequence
-from absl import app
+from absl import app as abslapp
 from absl import flags
 from fastapi import APIRouter, Response
 import fastapi
-from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.responses import StreamingResponse
 import uvicorn
+from google.protobuf.json_format import Parse

 from jetstream.core import config_lib, orchestrator, server_lib
+from jetstream.core.proto import jetstream_pb2
 from jetstream.entrypoints.config import get_server_config
+from jetstream.entrypoints.http.protocol import DecodeRequest
+from jetstream.entrypoints.http.utils import proto_to_json_generator

 flags.DEFINE_string("host", "0.0.0.0", "server host address")
 flags.DEFINE_integer("port", 8080, "http server port")
@@ -37,9 +43,9 @@
     "prometheus_port",
 )

-driver: orchestrator.Driver
+llm_orchestrator: orchestrator.LLMOrchestrator

-# Define Fast API endpoints (use driver to handle).
+# Define Fast API endpoints (use llm_orchestrator to handle).
 router = APIRouter()


@@ -52,11 +58,25 @@ def root():
   )


+@router.post("/v1/generate")
+async def generate(request: DecodeRequest):
+  proto_request = Parse(
+      request.model_dump_json(), jetstream_pb2.DecodeRequest()
+  )
+  generator = llm_orchestrator.Decode(proto_request)
+  return StreamingResponse(
+      content=proto_to_json_generator(generator), media_type="text/event-stream"
+  )
+
+
 @router.get("/v1/health")
 async def health() -> Response:
   """Health check."""
+  response = await llm_orchestrator.HealthCheck(
+      jetstream_pb2.HealthCheckRequest()
+  )
   return Response(
-      content=json.dumps({"is_live": str(driver.live)}, indent=4),
+      content=json.dumps({"is_live": str(response.is_live)}, indent=4),
       media_type="application/json",
       status_code=200,
   )
@@ -67,10 +87,10 @@ def server(argv: Sequence[str]):
   app = fastapi.FastAPI()
   app.include_router(router)

-  # Init driver which would be the main handler in the api endpoints.
+  # Init LLMOrchestrator which would be the main handler in the api endpoints.
   devices = server_lib.get_devices()
   print(f"devices: {devices}")
-  server_config = get_server_config(flags.FLAGS.config, argv)
+  server_config = get_server_config(flags.FLAGS.config)
   print(f"server_config: {server_config}")
   del argv

@@ -80,11 +100,13 @@ def server(argv: Sequence[str]):
         port=flags.FLAGS.prometheus_port
     )

-  global driver
-  driver = server_lib.create_driver(
-      config=server_config,
-      devices=devices,
-      metrics_server_config=metrics_server_config,
+  global llm_orchestrator
+  llm_orchestrator = orchestrator.LLMOrchestrator(
+      driver=server_lib.create_driver(
+          config=server_config,
+          devices=devices,
+          metrics_server_config=metrics_server_config,
+      )
   )

   # Start uvicorn http server.
@@ -95,4 +117,4 @@ def server(argv: Sequence[str]):

 if __name__ == "__main__":
   # Run Abseil app w flags parser.
-  app.run(server)
+  abslapp.run(server)
diff --git a/jetstream/entrypoints/http/protocol.py b/jetstream/entrypoints/http/protocol.py
new file mode 100644
index 00000000..4a7437aa
--- /dev/null
+++ b/jetstream/entrypoints/http/protocol.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Http API server protocol."""
+
+from pydantic import BaseModel
+
+
+class TextContent(BaseModel):
+  text: str
+
+
+class TokenContent(BaseModel):
+  token_ids: list[int]
+
+
+class DecodeRequest(BaseModel):
+  max_tokens: int
+  text_content: TextContent | None = None
+  token_content: TokenContent | None = None
+
+  # Config to enforce the oneof behavior at runtime.
+  class Config:
+    extra = "forbid"  # Prevent extra fields.
+    anystr_strip_whitespace = True
diff --git a/jetstream/entrypoints/http/utils.py b/jetstream/entrypoints/http/utils.py
new file mode 100644
index 00000000..7765a785
--- /dev/null
+++ b/jetstream/entrypoints/http/utils.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Http API server utilities."""
+
+from google.protobuf.json_format import MessageToJson
+
+
+async def proto_to_json_generator(proto_generator):
+  """Wraps a generator yielding Protocol Buffer messages into a generator
+
+  yielding JSON messages.
+  """
+  async for proto_message in proto_generator:
+    json_string = MessageToJson(proto_message)
+    yield json_string

From a8f25631dfb0fd3b7f95ee2be61489bb40292262 Mon Sep 17 00:00:00 2001
From: Zijun Zhou
Date: Thu, 18 Jul 2024 20:45:11 +0000
Subject: [PATCH 3/7] Add unit tests

---
 jetstream/tests/entrypoints/__init__.py | 14 ++++
 jetstream/tests/entrypoints/http/__init__.py | 14 ++++
 .../tests/entrypoints/http/test_api_server.py | 84 +++++++++++++++++++
 3 files changed, 112 insertions(+)
 create mode 100644 jetstream/tests/entrypoints/__init__.py
 create mode 100644 jetstream/tests/entrypoints/http/__init__.py
 create mode 100644 jetstream/tests/entrypoints/http/test_api_server.py

diff --git a/jetstream/tests/entrypoints/__init__.py b/jetstream/tests/entrypoints/__init__.py
new file mode 100644
index 00000000..c38dc3b1
--- /dev/null
+++ b/jetstream/tests/entrypoints/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/jetstream/tests/entrypoints/http/__init__.py b/jetstream/tests/entrypoints/http/__init__.py
new file mode 100644
index 00000000..c38dc3b1
--- /dev/null
+++ b/jetstream/tests/entrypoints/http/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/jetstream/tests/entrypoints/http/test_api_server.py b/jetstream/tests/entrypoints/http/test_api_server.py
new file mode 100644
index 00000000..e6d42e58
--- /dev/null
+++ b/jetstream/tests/entrypoints/http/test_api_server.py
@@ -0,0 +1,84 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests http server end-to-end."""
+
+import subprocess
+import sys
+import time
+import unittest
+
+
+import requests
+
+
+class HTTPServerTest(unittest.IsolatedAsyncioTestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    """Sets up a JetStream http server for unit tests."""
+    cls.base_url = "http://localhost:8080"
+    cls.server = subprocess.Popen(
+        [
+            "python",
+            "-m",
+            "jetstream.entrypoints.http.api_server",
+            "--config=InterleavedCPUTestServer",
+        ],
+        stdout=sys.stdout,
+        stderr=sys.stderr,
+    )
+    time.sleep(10)
+
+  @classmethod
+  def tearDownClass(cls):
+    """Stop the server gracefully."""
+    cls.server.terminate()
+
+  async def test_root_endpoint(self):
+    response = requests.get(self.base_url + "/", timeout=5)
+    assert response.status_code == 200
+    expected_data = {"message": "JetStream HTTP Server"}
+    assert response.json() == expected_data
+
+  async def test_health_endpoint(self):
+    response = requests.get(self.base_url + "/v1/health", timeout=5)
+    assert response.status_code == 200
+    data = response.json()
+    assert "is_live" in data
+    assert data["is_live"] == "True"
+
+  async def test_generate_endpoint(self):
+    # Prepare a sample request (replace with actual data)
+    sample_request_data = {
+        "max_tokens": 10,
+        "text_content": {"text": "translate this to french: hello world"},
+    }
+
+    response = requests.post(
+        self.base_url + "/v1/generate",
+        json=sample_request_data,
+        stream=True,
+        timeout=5,
+    )
+    assert response.status_code == 200
+    full_response = []
+    for chunk in response.iter_content(
+        chunk_size=None
+    ):  # chunk_size=None for complete lines
+      if chunk:
+        stream_response = chunk.decode("utf-8")
+        print(f"{stream_response=}")
+        full_response.append(stream_response)
+    assert len(full_response) == 11  # 10 tokens + eos token

From cc7316c18e4a411270eedc62a39ccada8bc4e542 Mon Sep 17 00:00:00 2001
From: Zijun Zhou
Date: Thu, 18 Jul 2024 20:52:18 +0000
Subject: [PATCH 4/7] format & deps

---
 jetstream/tests/entrypoints/__init__.py | 1 -
 jetstream/tests/entrypoints/http/__init__.py | 1 -
 requirements.txt | 31 ++++++++++++++++++--
 3 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/jetstream/tests/entrypoints/__init__.py b/jetstream/tests/entrypoints/__init__.py
index c38dc3b1..6d5e14bc 100644
--- a/jetstream/tests/entrypoints/__init__.py
+++ b/jetstream/tests/entrypoints/__init__.py
@@ -11,4 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
diff --git a/jetstream/tests/entrypoints/http/__init__.py b/jetstream/tests/entrypoints/http/__init__.py
index c38dc3b1..6d5e14bc 100644
--- a/jetstream/tests/entrypoints/http/__init__.py
+++ b/jetstream/tests/entrypoints/http/__init__.py
@@ -11,4 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
diff --git a/requirements.txt b/requirements.txt
index 6bce9a98..67e31fdd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,6 +19,10 @@ absl-py==1.4.0
     #   tensorflow
     #   tensorflow-metadata
     #   tfds-nightly
+anyio==3.7.1
+    # via
+    #   fastapi
+    #   starlette
 array-record==0.5.0
     # via tfds-nightly
 astunparse==1.6.3
@@ -34,7 +38,9 @@ charset-normalizer==3.3.2
 chex==0.1.7
     # via optax
 click==8.1.7
-    # via tfds-nightly
+    # via
+    #   tfds-nightly
+    #   uvicorn
 clu==0.0.10
     # via seqio
 contextlib2==21.6.0
@@ -56,7 +62,11 @@ etils[array-types,enp,epath,epy,etqdm,etree]==1.6.0
     #   orbax-checkpoint
     #   tfds-nightly
 exceptiongroup==1.2.0
-    # via pytest
+    # via
+    #   anyio
+    #   pytest
+fastapi==0.103.2
+    # via -r requirements.in
 filelock==3.14.0
     # via blobfile
 flatbuffers==23.5.26
@@ -86,10 +96,14 @@ grpcio==1.60.1
     #   tensorflow
 gviz-api==1.10.0
     # via tensorboard-plugin-profile
+h11==0.14.0
+    # via uvicorn
 h5py==3.10.0
     # via tensorflow
 idna==3.7
-    # via requests
+    # via
+    #   anyio
+    #   requests
 importlib-resources==6.1.1
     # via etils
 iniconfig==2.0.0
@@ -208,6 +222,8 @@ pyasn1-modules==0.3.0
     # via google-auth
 pycryptodomex==3.20.0
     # via blobfile
+pydantic==1.10.17
+    # via fastapi
 pyglove==0.4.4
     # via seqio
 pygments==2.17.2
@@ -252,6 +268,10 @@ six==1.16.0
     #   promise
     #   tensorboard-plugin-profile
     #   tensorflow
+sniffio==1.3.1
+    # via anyio
+starlette==0.27.0
+    # via fastapi
 tensorboard==2.13.0
     # via tensorflow
 tensorboard-data-server==0.7.2
@@ -299,13 +319,18 @@ typing-extensions==4.5.0
     #   chex
     #   clu
     #   etils
+    #   fastapi
     #   flax
    #   orbax-checkpoint
+    #   pydantic
     #   tensorflow
+    #   uvicorn
 urllib3==2.2.2
     # via
     #   blobfile
     #   requests
+uvicorn==0.30.1
+    # via -r requirements.in
 werkzeug==3.0.1
     # via
     #   tensorboard

From e485bc60e94cf7c86fb6ebc58f3b5cde4b9ed13d Mon Sep 17 00:00:00 2001
From: Zijun Zhou
Date: Thu, 18 Jul 2024 21:27:43 +0000
Subject: [PATCH 5/7] type & lint

---
 jetstream/entrypoints/http/api_server.py | 4 +---
 jetstream/entrypoints/http/protocol.py | 2 +-
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py
index 5f8a8087..f876e0ab 100644
--- a/jetstream/entrypoints/http/api_server.py
+++ b/jetstream/entrypoints/http/api_server.py
@@ -60,9 +60,7 @@ def root():

 @router.post("/v1/generate")
 async def generate(request: DecodeRequest):
-  proto_request = Parse(
-      request.model_dump_json(), jetstream_pb2.DecodeRequest()
-  )
+  proto_request = Parse(request.json(), jetstream_pb2.DecodeRequest())
   generator = llm_orchestrator.Decode(proto_request)
   return StreamingResponse(
       content=proto_to_json_generator(generator), media_type="text/event-stream"
diff --git a/jetstream/entrypoints/http/protocol.py b/jetstream/entrypoints/http/protocol.py
index 4a7437aa..fb003386 100644
--- a/jetstream/entrypoints/http/protocol.py
+++ b/jetstream/entrypoints/http/protocol.py
@@ -14,7 +14,7 @@

 """Http API server protocol."""

-from pydantic import BaseModel
+from pydantic import BaseModel  # type: ignore


 class TextContent(BaseModel):

From acc64da8a11928f298e63cca3a6215db63f5d7f6 Mon Sep 17 00:00:00 2001
From: Zijun Zhou
Date: Mon, 22 Jul 2024 23:01:31 +0000
Subject: [PATCH 6/7] Merge refactor

---
 jetstream/core/server_lib.py | 36 ++++++++++++++++--------------------
 1 file changed, 16 insertions(+), 20 deletions(-)

diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py
index 1bccb109..22180f09 100644
--- a/jetstream/core/server_lib.py
+++ b/jetstream/core/server_lib.py
@@ -97,7 +97,7 @@ def create_driver(
     config: Type[config_lib.ServerConfig],
     devices: Any,
     jax_padding: bool = True,
-    metrics_server_config: config_lib.MetricsServerConfig | None = None,
+    metrics_collector: JetstreamMetricsCollector | None = None,
     enable_model_warmup: bool = False,
 ):
   """Creates a driver with a specified config.
@@ -106,16 +106,12 @@ def create_driver(
     config: A ServerConfig to config engine, model, device slices, etc.
     devices: Device objects, will be used to get engine with proper slicing.
     jax_padding: The flag to enable JAX padding during tokenization.
-    metrics_server_config: The config to enable Promethus metric server.
+    metrics_collector: The JetStream Promethus metric collector.
     enable_model_warmup: The flag to enable model server warmup with AOT.

   Returns:
     An orchestrator driver.
   """
-
-  server_start_time = time.time()
-
-  logging.info("Kicking off gRPC server.")
   engines = config_lib.get_engines(config, devices=devices)
   prefill_params = [pe.load_params() for pe in engines.prefill_engines]
   generate_params = [ge.load_params() for ge in engines.generate_engines]
@@ -125,19 +121,6 @@ def create_driver(
       len(config.prefill_slices) + len(config.generate_slices) == 0
   )

-  # Setup Prometheus server
-  metrics_collector: JetstreamMetricsCollector = None
-  if metrics_server_config and metrics_server_config.port:
-    logging.info(
-        "Starting Prometheus server on port %d", metrics_server_config.port
-    )
-    start_http_server(metrics_server_config.port)
-    metrics_collector = JetstreamMetricsCollector()
-  else:
-    logging.info(
-        "Not starting Prometheus server: --prometheus_port flag not set"
-    )
-
 prefill_engines = engines.prefill_engines + engines.interleaved_engines
 generate_engines = engines.generate_engines + engines.interleaved_engines
 prefill_params = prefill_params + shared_params
@@ -196,10 +179,23 @@ def run(
   Returns:
     JetStreamServer that wraps the grpc server and orchestrator driver.
   """
+  server_start_time = time.time()
   logging.info("Kicking off gRPC server.")
+  # Setup Prometheus server
+  metrics_collector: JetstreamMetricsCollector = None
+  if metrics_server_config and metrics_server_config.port:
+    logging.info(
+        "Starting Prometheus server on port %d", metrics_server_config.port
+    )
+    start_http_server(metrics_server_config.port)
+    metrics_collector = JetstreamMetricsCollector()
+  else:
+    logging.info(
+        "Not starting Prometheus server: --prometheus_port flag not set"
+    )

   driver = create_driver(
-      config, devices, jax_padding, metrics_server_config, enable_model_warmup
+      config, devices, jax_padding, metrics_collector, enable_model_warmup
   )
   # We default threads to the total number of concurrent allowed decodes,
   # to make sure we can fully saturate the model. Set default minimum to 64.

From d253afdfa1644ae7d76a1d00a5486e327d835f0c Mon Sep 17 00:00:00 2001
From: Zijun Zhou
Date: Mon, 22 Jul 2024 23:38:40 +0000
Subject: [PATCH 7/7] fix refactor

---
 jetstream/entrypoints/http/api_server.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py
index f876e0ab..e7dabfed 100644
--- a/jetstream/entrypoints/http/api_server.py
+++ b/jetstream/entrypoints/http/api_server.py
@@ -15,16 +15,19 @@
 """JetStream Http API server."""

 import json
+import logging
 from typing import Sequence
 from absl import app as abslapp
 from absl import flags
 from fastapi import APIRouter, Response
 import fastapi
 from fastapi.responses import StreamingResponse
+from prometheus_client import start_http_server
 import uvicorn
 from google.protobuf.json_format import Parse

 from jetstream.core import config_lib, orchestrator, server_lib
+from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
 from jetstream.core.proto import jetstream_pb2
 from jetstream.entrypoints.config import get_server_config
 from jetstream.entrypoints.http.protocol import DecodeRequest
@@ -93,17 +96,28 @@ def server(argv: Sequence[str]):
   del argv

   metrics_server_config: config_lib.MetricsServerConfig | None = None
+  # Setup Prometheus server
+  metrics_collector: JetstreamMetricsCollector = None
   if flags.FLAGS.prometheus_port != 0:
     metrics_server_config = config_lib.MetricsServerConfig(
         port=flags.FLAGS.prometheus_port
     )
+    logging.info(
+        "Starting Prometheus server on port %d", metrics_server_config.port
+    )
+    start_http_server(metrics_server_config.port)
+    metrics_collector = JetstreamMetricsCollector()
+  else:
+    logging.info(
+        "Not starting Prometheus server: --prometheus_port flag not set"
+    )

   global llm_orchestrator
   llm_orchestrator = orchestrator.LLMOrchestrator(
       driver=server_lib.create_driver(
           config=server_config,
           devices=devices,
-          metrics_server_config=metrics_server_config,
+          metrics_collector=metrics_collector,
       )
   )

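Usage note (illustrative, not part of the patch series): once the series is applied, the new HTTP entrypoint can be smoke-tested the same way the bundled unit test does it. Start a server with "python -m jetstream.entrypoints.http.api_server --config=InterleavedCPUTestServer", then call the endpoints from Python as sketched below. The request shape follows DecodeRequest in jetstream/entrypoints/http/protocol.py (max_tokens plus either text_content or token_content); the host, port, prompt text, and timeouts are assumptions taken from the defaults in api_server.py and test_api_server.py, not values mandated by the patches.

# Illustrative smoke test; assumes a server started as above on localhost:8080.
import requests

base_url = "http://localhost:8080"

# Liveness probe served by /v1/health.
print(requests.get(base_url + "/v1/health", timeout=5).json())

# Streaming generation; each streamed chunk carries a JSON-serialized response message.
with requests.post(
    base_url + "/v1/generate",
    json={"max_tokens": 10, "text_content": {"text": "hello world"}},
    stream=True,
    timeout=60,
) as response:
  for chunk in response.iter_content(chunk_size=None):
    if chunk:
      print(chunk.decode("utf-8"))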