diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py
index 24f8506c..22180f09 100644
--- a/jetstream/core/server_lib.py
+++ b/jetstream/core/server_lib.py
@@ -93,40 +93,25 @@ def wait_for_termination(self) -> None:
       self.stop()
 
 
-def run(
-    port: int,
+def create_driver(
     config: Type[config_lib.ServerConfig],
     devices: Any,
-    credentials: Any = grpc.insecure_server_credentials(),
-    threads: int | None = None,
     jax_padding: bool = True,
-    metrics_server_config: config_lib.MetricsServerConfig | None = None,
-    enable_jax_profiler: bool = False,
-    jax_profiler_port: int = 9999,
+    metrics_collector: JetstreamMetricsCollector | None = None,
     enable_model_warmup: bool = False,
-) -> JetStreamServer:
-  """Runs a server with a specified config.
+):
+  """Creates a driver with a specified config.
 
   Args:
-    port: Port on which the server will be made available.
     config: A ServerConfig to config engine, model, device slices, etc.
     devices: Device objects, will be used to get engine with proper slicing.
-    credentials: Should use grpc credentials by default.
-    threads: Number of RPC handlers worker threads. This should be at least
-      equal to the decoding batch size to fully saturate the decoding queue.
     jax_padding: The flag to enable JAX padding during tokenization.
-    metrics_server_config: The config to enable Promethus metric server.
-    enable_jax_profiler: The flag to enable JAX profiler server.
-    jax_profiler_port: The port JAX profiler server (default to 9999).
+    metrics_collector: The JetStream Prometheus metric collector.
     enable_model_warmup: The flag to enable model server warmup with AOT.
 
   Returns:
-    JetStreamServer that wraps the grpc server and orchestrator driver.
+    An orchestrator driver.
   """
-
-  server_start_time = time.time()
-
-  logging.info("Kicking off gRPC server.")
   engines = config_lib.get_engines(config, devices=devices)
   prefill_params = [pe.load_params() for pe in engines.prefill_engines]
   generate_params = [ge.load_params() for ge in engines.generate_engines]
@@ -136,19 +121,6 @@ def run(
       len(config.prefill_slices) + len(config.generate_slices) == 0
   )
 
-  # Setup Prometheus server
-  metrics_collector: JetstreamMetricsCollector = None
-  if metrics_server_config and metrics_server_config.port:
-    logging.info(
-        "Starting Prometheus server on port %d", metrics_server_config.port
-    )
-    start_http_server(metrics_server_config.port)
-    metrics_collector = JetstreamMetricsCollector()
-  else:
-    logging.info(
-        "Not starting Prometheus server: --prometheus_port flag not set"
-    )
-
   prefill_engines = engines.prefill_engines + engines.interleaved_engines
   generate_engines = engines.generate_engines + engines.interleaved_engines
   prefill_params = prefill_params + shared_params
@@ -182,7 +154,7 @@ def run(
       traceback.print_exc()
       os.kill(os.getpid(), signal.SIGKILL)
 
-  driver = orchestrator.Driver(
+  return orchestrator.Driver(
       prefill_engines=prefill_engines,
       generate_engines=generate_engines,
       prefill_params=prefill_params,
@@ -192,6 +164,56 @@ def run(
       metrics_collector=metrics_collector,
       is_ray_backend=config.is_ray_backend,
   )
+
+
+def run(
+    port: int,
+    config: Type[config_lib.ServerConfig],
+    devices: Any,
+    credentials: Any = grpc.insecure_server_credentials(),
+    threads: int | None = None,
+    jax_padding: bool = True,
+    metrics_server_config: config_lib.MetricsServerConfig | None = None,
+    enable_jax_profiler: bool = False,
+    jax_profiler_port: int = 9999,
+    enable_model_warmup: bool = False,
+) -> JetStreamServer:
+  """Runs a server with a specified config.
+
+  Args:
+    port: Port on which the server will be made available.
+    config: A ServerConfig to config engine, model, device slices, etc.
+    devices: Device objects, will be used to get engine with proper slicing.
+    credentials: Should use grpc credentials by default.
+    threads: Number of RPC handlers worker threads. This should be at least
+      equal to the decoding batch size to fully saturate the decoding queue.
+    jax_padding: The flag to enable JAX padding during tokenization.
+    metrics_server_config: The config to enable the Prometheus metric server.
+    enable_jax_profiler: The flag to enable the JAX profiler server.
+    jax_profiler_port: The port of the JAX profiler server (defaults to 9999).
+    enable_model_warmup: The flag to enable model server warmup with AOT.
+
+  Returns:
+    JetStreamServer that wraps the grpc server and orchestrator driver.
+  """
+  server_start_time = time.time()
+  logging.info("Kicking off gRPC server.")
+  # Setup Prometheus server
+  metrics_collector: JetstreamMetricsCollector = None
+  if metrics_server_config and metrics_server_config.port:
+    logging.info(
+        "Starting Prometheus server on port %d", metrics_server_config.port
+    )
+    start_http_server(metrics_server_config.port)
+    metrics_collector = JetstreamMetricsCollector()
+  else:
+    logging.info(
+        "Not starting Prometheus server: --prometheus_port flag not set"
+    )
+
+  driver = create_driver(
+      config, devices, jax_padding, metrics_collector, enable_model_warmup
+  )
   # We default threads to the total number of concurrent allowed decodes,
   # to make sure we can fully saturate the model. Set default minimum to 64.
   threads = threads or max(driver.get_total_concurrent_requests(), 64)
diff --git a/jetstream/entrypoints/__init__.py b/jetstream/entrypoints/__init__.py
new file mode 100644
index 00000000..6d5e14bc
--- /dev/null
+++ b/jetstream/entrypoints/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
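With this refactor, a custom entrypoint builds a driver first and then wraps it in an orchestrator. A minimal sketch, assuming the CPU test config and the LLMOrchestrator wrapper used by the HTTP server below:

    from jetstream.core import config_lib, orchestrator, server_lib

    # Illustrative wiring; mirrors what api_server.py does further down.
    devices = server_lib.get_devices()
    driver = server_lib.create_driver(
        config=config_lib.InterleavedCPUTestServer,
        devices=devices,
    )
    llm_orchestrator = orchestrator.LLMOrchestrator(driver=driver)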
+ +"""Config for JetStream Server (including engine init).""" + +from typing import Type + +from jetstream.core import config_lib + + +def get_server_config( + config_str: str, +) -> config_lib.ServerConfig | Type[config_lib.ServerConfig]: + match config_str: + case "InterleavedCPUTestServer": + server_config = config_lib.InterleavedCPUTestServer + case "CPUTestServer": + server_config = config_lib.CPUTestServer + case _: + raise NotImplementedError + return server_config diff --git a/jetstream/entrypoints/http/__init__.py b/jetstream/entrypoints/http/__init__.py new file mode 100644 index 00000000..6d5e14bc --- /dev/null +++ b/jetstream/entrypoints/http/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py new file mode 100644 index 00000000..e7dabfed --- /dev/null +++ b/jetstream/entrypoints/http/api_server.py @@ -0,0 +1,132 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JetStream Http API server.""" + +import json +import logging +from typing import Sequence +from absl import app as abslapp +from absl import flags +from fastapi import APIRouter, Response +import fastapi +from fastapi.responses import StreamingResponse +from prometheus_client import start_http_server +import uvicorn +from google.protobuf.json_format import Parse + +from jetstream.core import config_lib, orchestrator, server_lib +from jetstream.core.metrics.prometheus import JetstreamMetricsCollector +from jetstream.core.proto import jetstream_pb2 +from jetstream.entrypoints.config import get_server_config +from jetstream.entrypoints.http.protocol import DecodeRequest +from jetstream.entrypoints.http.utils import proto_to_json_generator + +flags.DEFINE_string("host", "0.0.0.0", "server host address") +flags.DEFINE_integer("port", 8080, "http server port") +flags.DEFINE_string( + "config", + "InterleavedCPUTestServer", + "available servers", +) +flags.DEFINE_integer( + "prometheus_port", + 9988, + "prometheus_port", +) + +llm_orchestrator: orchestrator.LLMOrchestrator + +# Define Fast API endpoints (use llm_orchestrator to handle). 
diff --git a/jetstream/entrypoints/http/api_server.py b/jetstream/entrypoints/http/api_server.py
new file mode 100644
index 00000000..e7dabfed
--- /dev/null
+++ b/jetstream/entrypoints/http/api_server.py
@@ -0,0 +1,132 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""JetStream HTTP API server."""
+
+import json
+import logging
+from typing import Sequence
+from absl import app as abslapp
+from absl import flags
+from fastapi import APIRouter, Response
+import fastapi
+from fastapi.responses import StreamingResponse
+from prometheus_client import start_http_server
+import uvicorn
+from google.protobuf.json_format import Parse
+
+from jetstream.core import config_lib, orchestrator, server_lib
+from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
+from jetstream.core.proto import jetstream_pb2
+from jetstream.entrypoints.config import get_server_config
+from jetstream.entrypoints.http.protocol import DecodeRequest
+from jetstream.entrypoints.http.utils import proto_to_json_generator
+
+flags.DEFINE_string("host", "0.0.0.0", "server host address")
+flags.DEFINE_integer("port", 8080, "http server port")
+flags.DEFINE_string(
+    "config",
+    "InterleavedCPUTestServer",
+    "available servers",
+)
+flags.DEFINE_integer(
+    "prometheus_port",
+    9988,
+    "prometheus_port",
+)
+
+llm_orchestrator: orchestrator.LLMOrchestrator
+
+# Define FastAPI endpoints (handled via llm_orchestrator).
+router = APIRouter()
+
+
+@router.get("/")
+def root():
+  """Root path for JetStream HTTP Server."""
+  return Response(
+      content=json.dumps({"message": "JetStream HTTP Server"}, indent=4),
+      media_type="application/json",
+  )
+
+
+@router.post("/v1/generate")
+async def generate(request: DecodeRequest):
+  proto_request = Parse(request.json(), jetstream_pb2.DecodeRequest())
+  generator = llm_orchestrator.Decode(proto_request)
+  return StreamingResponse(
+      content=proto_to_json_generator(generator), media_type="text/event-stream"
+  )
+
+
+@router.get("/v1/health")
+async def health() -> Response:
+  """Health check."""
+  response = await llm_orchestrator.HealthCheck(
+      jetstream_pb2.HealthCheckRequest()
+  )
+  return Response(
+      content=json.dumps({"is_live": str(response.is_live)}, indent=4),
+      media_type="application/json",
+      status_code=200,
+  )
+
+
+def server(argv: Sequence[str]):
+  # Init FastAPI.
+  app = fastapi.FastAPI()
+  app.include_router(router)
+
+  # Init LLMOrchestrator, which is the main handler for the API endpoints.
+  devices = server_lib.get_devices()
+  print(f"devices: {devices}")
+  server_config = get_server_config(flags.FLAGS.config)
+  print(f"server_config: {server_config}")
+  del argv
+
+  metrics_server_config: config_lib.MetricsServerConfig | None = None
+  # Setup Prometheus server
+  metrics_collector: JetstreamMetricsCollector = None
+  if flags.FLAGS.prometheus_port != 0:
+    metrics_server_config = config_lib.MetricsServerConfig(
+        port=flags.FLAGS.prometheus_port
+    )
+    logging.info(
+        "Starting Prometheus server on port %d", metrics_server_config.port
+    )
+    start_http_server(metrics_server_config.port)
+    metrics_collector = JetstreamMetricsCollector()
+  else:
+    logging.info(
+        "Not starting Prometheus server: --prometheus_port flag not set"
+    )
+
+  global llm_orchestrator
+  llm_orchestrator = orchestrator.LLMOrchestrator(
+      driver=server_lib.create_driver(
+          config=server_config,
+          devices=devices,
+          metrics_collector=metrics_collector,
+      )
+  )
+
+  # Start the uvicorn HTTP server.
+  uvicorn.run(
+      app, host=flags.FLAGS.host, port=flags.FLAGS.port, log_level="info"
+  )
+
+
+if __name__ == "__main__":
+  # Run the Abseil app with the flags parser.
+  abslapp.run(server)
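Once the server is up, the endpoints can be exercised over plain HTTP. A sketch using the requests library (the prompt and port are illustrative; each streamed chunk is a JSON-encoded DecodeResponse):

    import requests

    # Liveness probe.
    health = requests.get("http://localhost:8080/v1/health", timeout=5)
    print(health.json())  # e.g. {"is_live": "True"}

    # Streaming generation.
    response = requests.post(
        "http://localhost:8080/v1/generate",
        json={"max_tokens": 10, "text_content": {"text": "hello world"}},
        stream=True,
        timeout=60,
    )
    for chunk in response.iter_content(chunk_size=None):
        print(chunk.decode("utf-8"))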
diff --git a/jetstream/entrypoints/http/protocol.py b/jetstream/entrypoints/http/protocol.py
new file mode 100644
index 00000000..fb003386
--- /dev/null
+++ b/jetstream/entrypoints/http/protocol.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""HTTP API server protocol."""
+
+from pydantic import BaseModel  # type: ignore
+
+
+class TextContent(BaseModel):
+  text: str
+
+
+class TokenContent(BaseModel):
+  token_ids: list[int]
+
+
+class DecodeRequest(BaseModel):
+  max_tokens: int
+  text_content: TextContent | None = None
+  token_content: TokenContent | None = None
+
+  # Config to enforce the oneof behavior at runtime.
+  class Config:
+    extra = "forbid"  # Prevent extra fields.
+    anystr_strip_whitespace = True
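The request model accepts either a text or a token payload; the Config comment above intends oneof behavior, and extra = "forbid" is what rejects unknown fields at runtime. Illustrative payloads (field names come from this file; anything else is hypothetical):

    from jetstream.entrypoints.http.protocol import DecodeRequest

    # Text prompt variant.
    DecodeRequest(max_tokens=10, text_content={"text": "hello world"})

    # Pre-tokenized variant.
    DecodeRequest(max_tokens=10, token_content={"token_ids": [17, 42, 7]})

    # Unknown fields fail validation because extra = "forbid":
    # DecodeRequest(max_tokens=10, prompt="hi")  # raises ValidationError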
diff --git a/jetstream/entrypoints/http/utils.py b/jetstream/entrypoints/http/utils.py
new file mode 100644
index 00000000..7765a785
--- /dev/null
+++ b/jetstream/entrypoints/http/utils.py
@@ -0,0 +1,27 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""HTTP API server utilities."""
+
+from google.protobuf.json_format import MessageToJson
+
+
+async def proto_to_json_generator(proto_generator):
+  """Wraps a generator of Protocol Buffer messages.
+
+  Yields each message serialized as a JSON string.
+  """
+  async for proto_message in proto_generator:
+    json_string = MessageToJson(proto_message)
+    yield json_string
diff --git a/jetstream/tests/entrypoints/__init__.py b/jetstream/tests/entrypoints/__init__.py
new file mode 100644
index 00000000..6d5e14bc
--- /dev/null
+++ b/jetstream/tests/entrypoints/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/jetstream/tests/entrypoints/http/__init__.py b/jetstream/tests/entrypoints/http/__init__.py
new file mode 100644
index 00000000..6d5e14bc
--- /dev/null
+++ b/jetstream/tests/entrypoints/http/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/jetstream/tests/entrypoints/http/test_api_server.py b/jetstream/tests/entrypoints/http/test_api_server.py
new file mode 100644
index 00000000..e6d42e58
--- /dev/null
+++ b/jetstream/tests/entrypoints/http/test_api_server.py
@@ -0,0 +1,84 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests the HTTP server end-to-end."""
+
+import subprocess
+import sys
+import time
+import unittest
+
+
+import requests
+
+
+class HTTPServerTest(unittest.IsolatedAsyncioTestCase):
+
+  @classmethod
+  def setUpClass(cls):
+    """Sets up a JetStream HTTP server for unit tests."""
+    cls.base_url = "http://localhost:8080"
+    cls.server = subprocess.Popen(
+        [
+            "python",
+            "-m",
+            "jetstream.entrypoints.http.api_server",
+            "--config=InterleavedCPUTestServer",
+        ],
+        stdout=sys.stdout,
+        stderr=sys.stderr,
+    )
+    time.sleep(10)
+
+  @classmethod
+  def tearDownClass(cls):
+    """Stops the server gracefully."""
+    cls.server.terminate()
+
+  async def test_root_endpoint(self):
+    response = requests.get(self.base_url + "/", timeout=5)
+    assert response.status_code == 200
+    expected_data = {"message": "JetStream HTTP Server"}
+    assert response.json() == expected_data
+
+  async def test_health_endpoint(self):
+    response = requests.get(self.base_url + "/v1/health", timeout=5)
+    assert response.status_code == 200
+    data = response.json()
+    assert "is_live" in data
+    assert data["is_live"] == "True"
+
+  async def test_generate_endpoint(self):
+    # Prepare a sample request (replace with actual data)
+    sample_request_data = {
+        "max_tokens": 10,
+        "text_content": {"text": "translate this to french: hello world"},
+    }
+
+    response = requests.post(
+        self.base_url + "/v1/generate",
+        json=sample_request_data,
+        stream=True,
+        timeout=5,
+    )
+    assert response.status_code == 200
+    full_response = []
+    for chunk in response.iter_content(
+        chunk_size=None
+    ):  # chunk_size=None for complete lines
+      if chunk:
+        stream_response = chunk.decode("utf-8")
+        print(f"{stream_response=}")
+        full_response.append(stream_response)
+    assert len(full_response) == 11  # 10 tokens + eos token
diff --git a/requirements.in b/requirements.in
index 459749ae..86841a57 100644
--- a/requirements.in
+++ b/requirements.in
@@ -13,5 +13,7 @@ tiktoken
 blobfile
 parameterized
 shortuuid
+fastapi
+uvicorn
 # For profiling
 tensorboard-plugin-profile
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 6bce9a98..67e31fdd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,6 +19,10 @@ absl-py==1.4.0
     #   tensorflow
     #   tensorflow-metadata
     #   tfds-nightly
+anyio==3.7.1
+    # via
+    #   fastapi
+    #   starlette
 array-record==0.5.0
     # via tfds-nightly
 astunparse==1.6.3
@@ -34,7 +38,9 @@ charset-normalizer==3.3.2
 chex==0.1.7
     # via optax
 click==8.1.7
-    # via tfds-nightly
+    # via
+    #   tfds-nightly
+    #   uvicorn
 clu==0.0.10
     # via seqio
 contextlib2==21.6.0
@@ -56,7 +62,11 @@ etils[array-types,enp,epath,epy,etqdm,etree]==1.6.0
     #   orbax-checkpoint
     #   tfds-nightly
 exceptiongroup==1.2.0
-    # via pytest
+    # via
+    #   anyio
+    #   pytest
+fastapi==0.103.2
+    # via -r requirements.in
 filelock==3.14.0
     # via blobfile
 flatbuffers==23.5.26
@@ -86,10 +96,14 @@ grpcio==1.60.1
     #   tensorflow
 gviz-api==1.10.0
     # via tensorboard-plugin-profile
+h11==0.14.0
+    # via uvicorn
 h5py==3.10.0
     # via tensorflow
 idna==3.7
-    # via requests
+    # via
+    #   anyio
+    #   requests
 importlib-resources==6.1.1
     # via etils
 iniconfig==2.0.0
@@ -208,6 +222,8 @@ pyasn1-modules==0.3.0
     # via google-auth
 pycryptodomex==3.20.0
     # via blobfile
+pydantic==1.10.17
+    # via fastapi
 pyglove==0.4.4
     # via seqio
 pygments==2.17.2
@@ -252,6 +268,10 @@ six==1.16.0
     #   promise
     #   tensorboard-plugin-profile
     #   tensorflow
+sniffio==1.3.1
+    # via anyio
+starlette==0.27.0
+    # via fastapi
 tensorboard==2.13.0
     # via tensorflow
 tensorboard-data-server==0.7.2
@@ -299,13 +319,18 @@ typing-extensions==4.5.0
     #   chex
     #   clu
     #   etils
+    #   fastapi
     #   flax
     #   orbax-checkpoint
+    #   pydantic
     #   tensorflow
+    #   uvicorn
 urllib3==2.2.2
     # via
     #   blobfile
     #   requests
+uvicorn==0.30.1
+    # via -r requirements.in
 werkzeug==3.0.1
     # via
     #   tensorboard