diff --git a/jetstream/core/metrics/prometheus.py b/jetstream/core/metrics/prometheus.py
index dc8a00e9..9e37dca7 100644
--- a/jetstream/core/metrics/prometheus.py
+++ b/jetstream/core/metrics/prometheus.py
@@ -75,7 +75,12 @@ def __new__(cls):
       documentation="The percentage of decode slots currently being used",
       labelnames=["id", "idx"],
   )
-
+  # Seconds create_driver spent in load_params() across all engines.
+  _model_load_time = Gauge(
+      name="jetstream_model_load_time",
+      documentation="Total time taken to load the model",
+      labelnames=["id"],
+  )
   _server_startup_latency = Gauge(
       name="jetstream_server_startup_latency",
       documentation="Total time taken to start the Jetstream server",
@@ -232,6 +237,10 @@ def get_slots_used_percentage_metric(self, idx: int):
   def get_server_startup_latency_metric(self):
     return self._server_startup_latency.labels(id=self._id)
 
+  def get_model_load_time_metric(self):
+    """Returns the model load time gauge bound to this collector's id."""
+    return self._model_load_time.labels(id=self._id)
+
   def get_time_to_first_token(self):
     return self._time_to_first_token.labels(id=self._id)
 
diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py
index b323286a..92cb8bee 100644
--- a/jetstream/core/server_lib.py
+++ b/jetstream/core/server_lib.py
@@ -113,10 +113,17 @@ def create_driver(
     An orchestrator driver.
   """
   engines = config_lib.get_engines(config, devices=devices)
+  # Monotonic clock: an elapsed-time reading must not jump if the wall clock
+  # is adjusted while the weights are loading.
+  model_load_start_time = time.perf_counter()
   prefill_params = [pe.load_params() for pe in engines.prefill_engines]
   generate_params = [ge.load_params() for ge in engines.generate_engines]
   shared_params = [ie.load_params() for ie in engines.interleaved_engines]
   logging.info("Loaded all weights.")
+  if metrics_collector:
+    metrics_collector.get_model_load_time_metric().set(
+        time.perf_counter() - model_load_start_time
+    )
   interleaved_mode = (
       len(config.prefill_slices) + len(config.generate_slices) == 0
   )