Skip to content

Commit

Permalink
Merge refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZijunZhou committed Jul 22, 2024
1 parent 78e8120 commit acc64da
Showing 1 changed file with 16 additions and 20 deletions.
36 changes: 16 additions & 20 deletions jetstream/core/server_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def create_driver(
config: Type[config_lib.ServerConfig],
devices: Any,
jax_padding: bool = True,
metrics_server_config: config_lib.MetricsServerConfig | None = None,
metrics_collector: JetstreamMetricsCollector | None = None,
enable_model_warmup: bool = False,
):
"""Creates a driver with a specified config.
Expand All @@ -106,16 +106,12 @@ def create_driver(
config: A ServerConfig to config engine, model, device slices, etc.
devices: Device objects, will be used to get engine with proper slicing.
jax_padding: The flag to enable JAX padding during tokenization.
metrics_server_config: The config to enable Prometheus metric server.
metrics_collector: The JetStream Prometheus metrics collector.
enable_model_warmup: The flag to enable model server warmup with AOT.
Returns:
An orchestrator driver.
"""

server_start_time = time.time()

logging.info("Kicking off gRPC server.")
engines = config_lib.get_engines(config, devices=devices)
prefill_params = [pe.load_params() for pe in engines.prefill_engines]
generate_params = [ge.load_params() for ge in engines.generate_engines]
Expand All @@ -125,19 +121,6 @@ def create_driver(
len(config.prefill_slices) + len(config.generate_slices) == 0
)

# Setup Prometheus server
metrics_collector: JetstreamMetricsCollector = None
if metrics_server_config and metrics_server_config.port:
logging.info(
"Starting Prometheus server on port %d", metrics_server_config.port
)
start_http_server(metrics_server_config.port)
metrics_collector = JetstreamMetricsCollector()
else:
logging.info(
"Not starting Prometheus server: --prometheus_port flag not set"
)

prefill_engines = engines.prefill_engines + engines.interleaved_engines
generate_engines = engines.generate_engines + engines.interleaved_engines
prefill_params = prefill_params + shared_params
Expand Down Expand Up @@ -213,10 +196,23 @@ def run(
Returns:
JetStreamServer that wraps the grpc server and orchestrator driver.
"""
server_start_time = time.time()
logging.info("Kicking off gRPC server.")
# Setup Prometheus server
metrics_collector: JetstreamMetricsCollector = None
if metrics_server_config and metrics_server_config.port:
logging.info(
"Starting Prometheus server on port %d", metrics_server_config.port
)
start_http_server(metrics_server_config.port)
metrics_collector = JetstreamMetricsCollector()
else:
logging.info(
"Not starting Prometheus server: --prometheus_port flag not set"
)

driver = create_driver(
config, devices, jax_padding, metrics_server_config, enable_model_warmup
config, devices, jax_padding, metrics_collector, enable_model_warmup
)
# We default threads to the total number of concurrent allowed decodes,
# to make sure we can fully saturate the model. Set default minimum to 64.
Expand Down

0 comments on commit acc64da

Please sign in to comment.