Skip to content

Commit

Permalink
Improve the dashboard
Browse files Browse the repository at this point in the history
  • Loading branch information
pierre.delaunay committed Jun 26, 2024
1 parent 3c3bb8f commit 4ad1b1c
Showing 1 changed file with 55 additions and 13 deletions.
68 changes: 55 additions & 13 deletions milabench/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,12 @@ def __init__(self):
self.console = Console()
self.live = Live(self.panel, refresh_per_second=4, console=self.console)
self.rows = defaultdict(dict)
self.benchcount = defaultdict(int)
self.endtimes = {}
self.early_stop = {}
# Limit the number of rows to avoid too much clutering
# This is a soft limit, it only prunes finished runs
self.max_rows = 2
self.max_rows = 8
self.prune_delay = 60
self.current = 0

Expand All @@ -258,10 +259,21 @@ def _update_global(self, inc):
progress = self._get_global_progress_bar()
progress.update(progress._task, completed=self.current, total=total)

def should_prune(self, tag, elasped):
bench = tag.split(".")[0]
if self.benchcount.get(bench, 0) != 0:
return False

old = (elasped > self.prune_delay)
if self.max_rows:
return old or len(self.rows) > self.max_rows

return old

def prune(self):
now = time.time()
for tag, endtime in list(self.endtimes.items()):
if (now - endtime > self.prune_delay) or len(self.rows) > self.max_rows:
if self.should_prune(tag, now - endtime):
del self.endtimes[tag]
del self.rows[tag]

Expand Down Expand Up @@ -289,7 +301,11 @@ def __call__(self, entry):
if method:
method(entry, data, row)

def on_start(self, entry, data, row):
self.benchcount[entry.tag.split('.')[0]] += 1

def on_stop(self, entry, data, row):
self.benchcount[entry.tag.split('.')[0]] -= 1
self.early_stop[entry.tag] = True

def on_end(self, entry, data, row):
Expand All @@ -310,16 +326,22 @@ def make_table(self):
table.add_column("gpu_temp", style="bold magenta")

for bench, values in self.rows.items():
table.add_row(
bench,
values.get("status", "?"),
values.get("progress", "??%"),
values.get("rate", "?"),
values.get("loss", "?"),
values.get("gpu_load", "?"),
values.get("gpu_mem", "?"),
values.get("gpu_temp", "?"),
)
if bench == "GLOBAL":
table.add_row(
bench,
values.get("progress", "??%"),
)
else:
table.add_row(
bench,
values.get("status", "?"),
values.get("progress", "??%"),
values.get("rate", "?"),
values.get("loss", "?"),
values.get("gpu_load", "?"),
values.get("gpu_mem", "?"),
values.get("gpu_temp", "?"),
)

return table

Expand Down Expand Up @@ -357,6 +379,8 @@ def on_data(self, entry, data, row):
self.refresh()

def on_start(self, entry, data, row):
super().on_start(entry, data, row)

row["status"] = Text("RUNNING", style="bold yellow")
self.refresh()

Expand Down Expand Up @@ -423,5 +447,23 @@ def on_data(self, entry, data, row):
else:
task = data.pop("task", "")
units = data.pop("units", "")
row.update({f"{task} {k}".strip(): f"{v} {units}" for k, v in data.items()})
metric_rows = {
f"{task} {k}".strip(): f"{self.format(v)} {self.unit(k, units)}"
for k, v in data.items()
}
row.update(metric_rows)
self.refresh()

def time(self, t):
return int(t - self.created_time)

def unit(self, k, unit):
if k == "time":
return "s"
return unit

def format(self, value):
try:
return f"{value:0.2f}"
except:
return str(value)

0 comments on commit 4ad1b1c

Please sign in to comment.