Skip to content

Commit

Permalink
use a moving window to predict total runtimes
Browse files Browse the repository at this point in the history
  • Loading branch information
ric-evans committed Oct 16, 2024
1 parent 11b46ad commit d89a67a
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 65 deletions.
1 change: 1 addition & 0 deletions skymap_scanner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class EnvConfig:
#

SKYSCAN_PROGRESS_INTERVAL_SEC: int = 1 * 60
SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_RATIO: float = 0.1
SKYSCAN_RESULT_INTERVAL_SEC: int = 2 * 60

SKYSCAN_KILL_SWITCH_CHECK_INTERVAL: int = 5 * 60
Expand Down
153 changes: 88 additions & 65 deletions skymap_scanner/server/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import math
import statistics
import time
from collections import deque
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

Expand Down Expand Up @@ -38,25 +37,26 @@ class WorkerStats:

def __init__(
self,
worker_runtimes: Optional[List[float]] = None,
roundtrips: Optional[List[float]] = None,
on_server_roundtrip_start: float = float("inf"),
on_server_roundtrip_end: float = float("-inf"),
ends: Optional[List[float]] = None,
on_worker_runtimes: list[float],
on_server_roundtrips: list[float],
on_server_roundtrip_starts: list[float],
on_server_roundtrip_ends: list[float],
) -> None:
self.on_server_roundtrip_start = on_server_roundtrip_start
self.on_server_roundtrip_end = on_server_roundtrip_end

self.on_worker_runtimes: List[float] = (
worker_runtimes if worker_runtimes else []
)
self.on_worker_runtimes = on_worker_runtimes
self.on_worker_runtimes.sort() # speed up stats
#
self.on_server_roundtrips: List[float] = roundtrips if roundtrips else []
self.on_server_roundtrips = on_server_roundtrips
self.on_server_roundtrips.sort() # speed up stats
#
self.ends: List[float] = ends if ends else []
self.ends.sort() # speed up stats
self.on_server_roundtrip_starts: List[float] = on_server_roundtrip_starts
self.on_server_roundtrip_starts.sort() # speed up stats
self.on_server_first_roundtrip_start = lambda: min(
self.on_server_roundtrip_starts
)
#
self.on_server_roundtrip_ends = on_server_roundtrip_ends
self.on_server_roundtrip_ends.sort() # speed up stats
self.on_server_last_roundtrip_end = lambda: max(self.on_server_roundtrip_ends)

self.fastest_worker = lambda: min(self.on_worker_runtimes)
self.fastest_roundtrip = lambda: min(self.on_server_roundtrips)
Expand Down Expand Up @@ -102,14 +102,12 @@ def update(
on_server_roundtrip_end - on_server_roundtrip_start,
)
bisect.insort(
self.ends,
on_server_roundtrip_end,
)
self.on_server_roundtrip_start = min(
self.on_server_roundtrip_start, on_server_roundtrip_start
self.on_server_roundtrip_starts,
on_server_roundtrip_start,
)
self.on_server_roundtrip_end = max(
self.on_server_roundtrip_end, on_server_roundtrip_end
bisect.insort(
self.on_server_roundtrip_ends,
on_server_roundtrip_end,
)
return self

Expand Down Expand Up @@ -147,25 +145,27 @@ def get_summary(self) -> Dict[str, Dict[str, str]]:
},
"wall time": {
"start": str(
dt.datetime.fromtimestamp(int(self.on_server_roundtrip_start))
dt.datetime.fromtimestamp(
int(self.on_server_first_roundtrip_start())
)
),
"end": str(
dt.datetime.fromtimestamp(int(self.on_server_roundtrip_end))
dt.datetime.fromtimestamp(int(self.on_server_last_roundtrip_end()))
),
"runtime": str(
dt.timedelta(
seconds=int(
self.on_server_roundtrip_end
- self.on_server_roundtrip_start
self.on_server_last_roundtrip_end()
- self.on_server_first_roundtrip_start()
)
)
),
"mean reco": str(
dt.timedelta(
seconds=int(
(
self.on_server_roundtrip_end
- self.on_server_roundtrip_start
self.on_server_last_roundtrip_end()
- self.on_server_first_roundtrip_start()
)
/ len(self.on_worker_runtimes)
)
Expand All @@ -181,7 +181,24 @@ class WorkerStatsCollection:
def __init__(self) -> None:
self._worker_stats_by_nside: Dict[int, WorkerStats] = {}
self._aggregate: Optional[WorkerStats] = None
self.recent_runtimes = deque(maxlen=100)

def on_server_recent_reco_per_sec_rate(self, window_size: int) -> float:
    """Return the reco/sec rate, from the server's pov, over a moving window.

    The window covers the `window_size` most recently *started* roundtrips.
    If `window_size` exceeds the number of recorded roundtrips -- or is not
    positive -- fall back to the full history (all recos since the first
    roundtrip start).

    NOTE(review): assumes `self._aggregate` is populated (non-empty lists);
    callers reach this only after at least one reco -- confirm.
    """
    agg = self._aggregate
    # look at a window, so don't use the first start time
    if window_size > 0:
        try:
            # psst, we know that this list is sorted, ascending, so the
            # nth-most-recent start is the nth element from the end
            nth_most_recent_start = agg.on_server_roundtrip_starts[-window_size]
            n_recos = window_size
        except IndexError:
            # fewer recos than the window -- use everything
            nth_most_recent_start = agg.on_server_first_roundtrip_start()
            n_recos = len(agg.on_worker_runtimes)
    else:
        # window_size <= 0: `lst[-0]` would silently alias `lst[0]` with
        # n_recos=0 (a bogus 0.0 rate), so use the full history instead
        nth_most_recent_start = agg.on_server_first_roundtrip_start()
        n_recos = len(agg.on_worker_runtimes)

    return n_recos / (
        agg.on_server_last_roundtrip_end() - nth_most_recent_start
    )

def ct_by_nside(self, nside: int) -> int:
"""Get length per given nside."""
Expand All @@ -199,10 +216,8 @@ def total_ct(self) -> int:

@property
def first_roundtrip_start(self) -> float:
"""O(n), n < 10."""
return min(
w.on_server_roundtrip_start for w in self._worker_stats_by_nside.values()
)
"""Get the first roundtrip start time from server pov."""
return self._aggregate.on_server_first_roundtrip_start()

def update(
self,
Expand All @@ -215,14 +230,20 @@ def update(
self._aggregate = None # clear
try:
worker_stats = self._worker_stats_by_nside[nside]
worker_stats.update(
on_worker_runtime,
on_server_roundtrip_start,
on_server_roundtrip_end,
)
except KeyError:
worker_stats = self._worker_stats_by_nside[nside] = WorkerStats()
worker_stats.update(
on_worker_runtime,
on_server_roundtrip_start,
on_server_roundtrip_end,
)
self.recent_runtimes.append(on_worker_runtime)
worker_stats = self._worker_stats_by_nside[nside] = WorkerStats(
on_worker_runtimes=[on_worker_runtime],
on_server_roundtrips=[
on_server_roundtrip_end - on_server_roundtrip_start
],
on_server_roundtrip_starts=[on_server_roundtrip_start],
on_server_roundtrip_ends=[on_server_roundtrip_end],
)
return len(worker_stats.on_worker_runtimes)

@property
Expand All @@ -231,21 +252,20 @@ def aggregate(self) -> WorkerStats:
if not self._aggregate:
instances = self._worker_stats_by_nside.values()
if not instances:
return WorkerStats()
return WorkerStats([], [], [], [])
self._aggregate = WorkerStats(
worker_runtimes=list(
on_worker_runtimes=list(
itertools.chain(*[i.on_worker_runtimes for i in instances])
),
roundtrips=list(
on_server_roundtrips=list(
itertools.chain(*[i.on_server_roundtrips for i in instances])
),
on_server_roundtrip_start=min(
i.on_server_roundtrip_start for i in instances
on_server_roundtrip_starts=list(
itertools.chain(*[i.on_server_roundtrip_starts for i in instances])
),
on_server_roundtrip_end=max(
i.on_server_roundtrip_end for i in instances
on_server_roundtrip_ends=list(
itertools.chain(*[i.on_server_roundtrip_ends for i in instances])
),
ends=list(itertools.chain(*[i.ends for i in instances])),
)
return self._aggregate

Expand Down Expand Up @@ -508,28 +528,29 @@ def _get_processing_progress(self) -> StrDict:
proc_stats["finished"] = True
else:
# MAKE PREDICTIONS
# NOTE: this is a simple mean, may want to visit more sophisticated methods
secs_predicted = elapsed_reco_server_walltime / (
self.worker_stats_collection.total_ct / self.predicted_total_recos()
n_recos_left = (
self.predicted_total_recos() - self.worker_stats_collection.total_ct
)
proc_stats["predictions"] = {
"time left": str(
dt.timedelta(
seconds=int(secs_predicted - elapsed_reco_server_walltime)
time_left = ( # this uses a moving window average
self.worker_stats_collection.on_server_recent_reco_per_sec_rate(
window_size=int(
self.worker_stats_collection.total_ct
* cfg.ENV.SKYSCAN_PROGRESS_RUNTIME_PREDICTION_WINDOW_RATIO
)
),
)
* n_recos_left
)
proc_stats["predictions"] = {
"time left": str(dt.timedelta(seconds=int(time_left))),
"total runtime at finish": str(
dt.timedelta(seconds=int(secs_predicted + startup_runtime))
),
"total # of reconstructions": self.predicted_total_recos(),
"end": str(
dt.datetime.fromtimestamp(
int(
time.time()
+ (secs_predicted - elapsed_reco_server_walltime)
dt.timedelta(
seconds=int(
startup_runtime + elapsed_reco_server_walltime + time_left
)
)
),
"total # of reconstructions": self.predicted_total_recos(),
"end": str(dt.datetime.fromtimestamp(int(time.time() + time_left))),
}

return proc_stats
Expand Down Expand Up @@ -588,7 +609,9 @@ def pixels_percent_string(nside: int) -> str:
index = math.ceil(predicted_total * i) - 1
name = str(i)
try:
when = self.worker_stats_collection.aggregate.ends[index]
when = self.worker_stats_collection.aggregate.on_server_roundtrip_ends[
index
]
timeline[name] = str(
dt.timedelta(seconds=int(when - self.global_start))
)
Expand Down

0 comments on commit d89a67a

Please sign in to comment.