diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 1bccb109..22180f09 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -97,7 +97,7 @@ def create_driver( config: Type[config_lib.ServerConfig], devices: Any, jax_padding: bool = True, - metrics_server_config: config_lib.MetricsServerConfig | None = None, + metrics_collector: JetstreamMetricsCollector | None = None, enable_model_warmup: bool = False, ): """Creates a driver with a specified config. @@ -106,16 +106,12 @@ def create_driver( config: A ServerConfig to config engine, model, device slices, etc. devices: Device objects, will be used to get engine with proper slicing. jax_padding: The flag to enable JAX padding during tokenization. - metrics_server_config: The config to enable Promethus metric server. + metrics_collector: The JetStream Promethus metric collector. enable_model_warmup: The flag to enable model server warmup with AOT. Returns: An orchestrator driver. """ - - server_start_time = time.time() - - logging.info("Kicking off gRPC server.") engines = config_lib.get_engines(config, devices=devices) prefill_params = [pe.load_params() for pe in engines.prefill_engines] generate_params = [ge.load_params() for ge in engines.generate_engines] @@ -125,19 +121,6 @@ def create_driver( len(config.prefill_slices) + len(config.generate_slices) == 0 ) - # Setup Prometheus server - metrics_collector: JetstreamMetricsCollector = None - if metrics_server_config and metrics_server_config.port: - logging.info( - "Starting Prometheus server on port %d", metrics_server_config.port - ) - start_http_server(metrics_server_config.port) - metrics_collector = JetstreamMetricsCollector() - else: - logging.info( - "Not starting Prometheus server: --prometheus_port flag not set" - ) - prefill_engines = engines.prefill_engines + engines.interleaved_engines generate_engines = engines.generate_engines + engines.interleaved_engines prefill_params = prefill_params + shared_params @@ -213,10 +196,23 @@ def run( Returns: JetStreamServer that wraps the grpc server and orchestrator driver. """ + server_start_time = time.time() logging.info("Kicking off gRPC server.") + # Setup Prometheus server + metrics_collector: JetstreamMetricsCollector = None + if metrics_server_config and metrics_server_config.port: + logging.info( + "Starting Prometheus server on port %d", metrics_server_config.port + ) + start_http_server(metrics_server_config.port) + metrics_collector = JetstreamMetricsCollector() + else: + logging.info( + "Not starting Prometheus server: --prometheus_port flag not set" + ) driver = create_driver( - config, devices, jax_padding, metrics_server_config, enable_model_warmup + config, devices, jax_padding, metrics_collector, enable_model_warmup ) # We default threads to the total number of concurrent allowed decodes, # to make sure we can fully saturate the model. Set default minimum to 64.