From 08ab8972ef4fe0c8a253b1f28385587e72dbdb94 Mon Sep 17 00:00:00 2001 From: Setepenre Date: Fri, 7 Jun 2024 13:45:15 -0400 Subject: [PATCH 1/2] Update log.py --- milabench/log.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/milabench/log.py b/milabench/log.py index a6f7388a9..fd2cddabf 100644 --- a/milabench/log.py +++ b/milabench/log.py @@ -225,7 +225,7 @@ def __init__(self): def prune(self): now = time.time() for tag, endtime in list(self.endtimes.items()): - if now - endtime > 60: + if now - endtime > 10: del self.endtimes[tag] del self.rows[tag] From 3ed4225e7382c3d3a0b325cfdaffce1384bfa8b7 Mon Sep 17 00:00:00 2001 From: Setepenre Date: Fri, 7 Jun 2024 18:52:40 +0000 Subject: [PATCH 2/2] Tweak dash tto show a global progress and add live report dash --- milabench/cli/run.py | 4 +- milabench/common.py | 2 +- milabench/config.py | 11 +- milabench/dashboard/__init__.py | 386 +++++++++++++++++++++++++++++ milabench/dashboard/live_report.py | 55 ++++ milabench/{ => dashboard}/log.py | 27 +- milabench/dashboard/rawoutput.py | 0 milabench/multi.py | 17 ++ milabench/report.py | 9 +- milabench/summary.py | 116 +++++---- 10 files changed, 572 insertions(+), 55 deletions(-) rename milabench/{ => dashboard}/log.py (93%) delete mode 100644 milabench/dashboard/rawoutput.py diff --git a/milabench/cli/run.py b/milabench/cli/run.py index 6d42a11f7..7e2eb4263 100644 --- a/milabench/cli/run.py +++ b/milabench/cli/run.py @@ -14,7 +14,8 @@ run_with_loggers, validation_names, ) -from ..log import ( +from ..dashbaord.live_report import LiveReportFormatter +from ..dashboard.log import ( DataReporter, LongDashFormatter, ShortDashFormatter, @@ -75,6 +76,7 @@ def cli_run(args=None): dash_class = { "short": ShortDashFormatter, "long": LongDashFormatter, + "live": LiveReportFormatter "no": None, }.get(args.dash, None) diff --git a/milabench/common.py b/milabench/common.py index 35f9cf125..fa66a9801 100644 --- a/milabench/common.py +++ b/milabench/common.py @@ -17,7 +17,7 @@ from .config import build_config, build_system_config from .fs import XPath -from .log import TerminalFormatter +from .dashboard.log import TerminalFormatter from .merge import merge from .multi import MultiPackage from .report import make_report diff --git a/milabench/config.py b/milabench/config.py index bfee806e7..c5c7a67f8 100644 --- a/milabench/config.py +++ b/milabench/config.py @@ -11,7 +11,16 @@ from .merge import merge system_global = contextvars.ContextVar("system") -config_global = contextvars.ContextVar("Config") +config_global = contextvars.ContextVar("config") +execution_count = contextvars.ContextVar("count") + + +def set_run_count(total): + execution_count.set(total) + + +def get_run_count(): + return execution_count.get() def relative_to(pth, cwd): diff --git a/milabench/dashboard/__init__.py b/milabench/dashboard/__init__.py index e69de29bb..fd2cddabf 100644 --- a/milabench/dashboard/__init__.py +++ b/milabench/dashboard/__init__.py @@ -0,0 +1,386 @@ +import json +import os +import pprint +import shlex +import time +from collections import defaultdict +from datetime import datetime + +from blessed import Terminal +from rich.console import Console +from rich.live import Live +from rich.panel import Panel +from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn +from rich.table import Table +from rich.text import Text + +from .fs import XPath + +T = Terminal() +color_wheel = [T.cyan, T.magenta, T.yellow, T.red, T.green, T.blue] + + +class BaseLogger: + def start(self): + pass + + def end(self): + pass + + def __enter__(self): + self.start() + return self + + def __exit__(self, *args, **kwargs): + self.end() + + +class TagConsole(BaseLogger): + def __init__(self, tag, i): + self.header = color_wheel[i % len(color_wheel)](T.bold(tag)) + + def _ensure_line(self, x): + if not x.endswith("\n"): + x += "\n" + return x + + def sprint(self, *parts): + parts = [self.header, *parts] + return self._ensure_line(" ".join(map(str, parts))) + + def spretty(self, *parts): + *parts, obj = parts + parts = [ + self.header, + *parts, + obj if isinstance(obj, str) else pprint.pformat(obj, width=120), + ] + return self._ensure_line(" ".join(map(str, parts))) + + def print(self, *parts): + print(self.sprint(*parts), end="") + + def pretty(self, *parts): + print(self.spretty(*parts), end="") + + def close(self): + pass + + +class TerminalFormatter(BaseLogger): + def __init__(self): + self.consoles = {} + self.error_happened = set() + self.early_stop = False + + def console(self, tag): + if tag not in self.consoles: + self.consoles[tag] = TagConsole(tag, len(self.consoles)) + + return self.consoles[tag] + + def __call__(self, entry): + event = entry.event + data = entry.data + pipe = entry.pipe + tag = entry.tag + + console = self.console(tag) + + if event == "line": + data = (data or "").rstrip() + if pipe == "stderr": + console.print(T.bold_yellow("[stderr]"), T.yellow(data)) + else: + console.print(T.bold("[stdout]"), data) + + elif event == "data": + data = dict(data) + if "progress" in data: + return + console.pretty(T.bold_magenta("[data]"), data) + + elif event == "start": + console.print( + T.bold_green("[start]"), + T.bold_green(shlex.join(data.get("command", []))), + T.gray(f'[at {datetime.fromtimestamp(data["time"])}]'), + ) + + elif event == "stop": + self.early_stop = True + + elif event == "error": + self.error_happened.add(tag) + if data["message"]: + console.print( + T.bold_red(data["type"] + ":"), + T.red(data["message"]), + ) + else: + console.print(T.bold_red(data["type"])) + + elif event == "end": + rc = data.get( + "return_code", + ) + wrong = not self.early_stop and ( + (tag in self.error_happened) or data["return_code"] != 0 + ) + if wrong: + rc = data["return_code"] or "ERROR" + console.print( + T.bold_red(f"[{event} ({rc})]"), + T.bold_red(shlex.join(data.get("command", []))), + T.gray(f'[at {datetime.fromtimestamp(data["time"])}]'), + ) + else: + console.print( + T.bold_green(f"[{event}]"), + T.bold_green(shlex.join(data.get("command", []))), + T.gray(f'[at {datetime.fromtimestamp(data["time"])}]'), + ) + + elif event == "phase": + pass + + elif event == "config": + + def _show(k, entry): + if isinstance(entry, dict): + for k2, v in entry.items(): + _show(f"{k}.{k2}", v) + else: + console.pretty(T.bold(f"[{k}]"), entry) + + _show("config", data) + + elif event == "message": + console.pretty(T.bold(f"[{event}]"), data["message"]) + + else: + console.pretty(T.bold(f"[{event}]"), data) + + +class BaseReporter(BaseLogger): + def __init__(self, pipe): + self.pipe = pipe + self.files = {} + + def file(self, entry): + if entry.tag not in self.files: + file = entry.pack.logfile(self.pipe) + os.makedirs(XPath(file).parent, exist_ok=True) + self.files[entry.tag] = open(file, "w").__enter__() + return self.files[entry.tag] + + def log(self, entry): + pass + + def cleanup(self, entry): + if entry.event == "end": + if entry.tag in self.files: + self.files[entry.tag].__exit__(None, None, None) + del self.files[entry.tag] + + def __call__(self, entry): + self.log(entry) + self.cleanup(entry) + + +class TextReporter(BaseReporter): + def log(self, entry): + if entry.event == "line" and entry.pipe == self.pipe: + self.file(entry).write(entry.data) + + def close(self): + assert not self.files + for open_file in self.files.values(): + open_file.__exit__(None, None, None) + + +class DataReporter(BaseReporter): + def __init__(self): + super().__init__(pipe="data") + + def log(self, entry): + d = entry.dict() + d.pop("pack") + try: + j = json.dumps(d) + except TypeError: + j = {"#unrepresentable": str(d)} + self.file(entry).write(f"{j}\n") + + +class DashFormatter(BaseLogger): + def __init__(self): + self.panel = Panel("") + self.console = Console() + self.live = Live(self.panel, refresh_per_second=4, console=self.console) + self.rows = defaultdict(dict) + self.endtimes = {} + self.early_stop = {} + + def prune(self): + now = time.time() + for tag, endtime in list(self.endtimes.items()): + if now - endtime > 10: + del self.endtimes[tag] + del self.rows[tag] + + def refresh(self): + self.prune() + self.live.update(self.make_table()) + + def start(self): + self.live.__enter__() + + def end(self): + self.live.__exit__(None, None, None) + + def __call__(self, entry): + event = entry.event + data = entry.data + tag = entry.tag + row = self.rows[tag] + + method = getattr(self, f"on_{event}", None) + if method: + method(entry, data, row) + + def on_stop(self, entry, data, row): + self.early_stop[entry.tag] = True + + def on_end(self, entry, data, row): + self.endtimes[entry.tag] = time.time() + + +class ShortDashFormatter(DashFormatter): + def make_table(self): + table = Table(padding=(0, 3, 0, 0)) + table.add_column("bench", style="bold white") + table.add_column("status") + table.add_column("progress", style="bold white") + table.add_column("rate", style="bold green") + table.add_column("loss", style="bold cyan") + table.add_column("gpu_load", style="bold magenta") + table.add_column("gpu_mem", style="bold magenta") + table.add_column("gpu_temp", style="bold magenta") + + for bench, values in self.rows.items(): + table.add_row( + bench, + values.get("status", "?"), + values.get("progress", "??%"), + values.get("rate", "?"), + values.get("loss", "?"), + values.get("gpu_load", "?"), + values.get("gpu_mem", "?"), + values.get("gpu_temp", "?"), + ) + + return table + + def on_data(self, entry, data, row): + data = dict(data) + task = data.get("task", None) + if prog := data.get("progress", None): + if task == "early_stop": + current, total = prog + if total > 0: + perc = int(100 * (current / total)) + if perc >= 100: + perc = "DONE" + else: + perc = f"{perc}%" + row["progress"] = perc + elif gpudata := data.get("gpudata", None): + for gpuid, data in gpudata.items(): + load = int(data.get("load", 0) * 100) + currm, totalm = data.get("memory", [0, 0]) + temp = int(data.get("temperature", 0)) + row[f"gpu:{gpuid}"] = ( + f"{load}% load | {currm:.0f}/{totalm:.0f} MB | {temp}C" + ) + row["gpu_load"] = f"{load}%" + row["gpu_mem"] = f"{currm:.0f}/{totalm:.0f} MB" + row["gpu_temp"] = f"{temp}C" + break + elif (rate := data.get("rate", None)) is not None: + if task == "train": + row["rate"] = f"{rate:.2f}" + elif (loss := data.get("loss", None)) is not None: + if task == "train": + row["loss"] = f"{loss:.2f}" + self.refresh() + + def on_start(self, entry, data, row): + row["status"] = Text("RUNNING", style="bold yellow") + self.refresh() + + def on_error(self, entry, data, row): + row["status"] = Text("ERROR", style="bold red") + self.refresh() + + def on_end(self, entry, data, row): + super().on_end(entry, data, row) + rc = data["return_code"] + if rc == 0 or self.early_stop.get(entry.tag, False): + row["status"] = Text("COMPLETED", style="bold green") + else: + row["status"] = Text(f"FAIL:{rc}", style="bold red") + self.refresh() + + +class LongDashFormatter(DashFormatter): + def make_table(self): + table = Table.grid(padding=(0, 3, 0, 0)) + table.add_column("bench", style="bold yellow") + table.add_column("key", style="bold green") + table.add_column("value") + + for bench, values in self.rows.items(): + values = dict(values) + progress = values.pop("progress", None) + if progress is not None: + table.add_row(bench, "", progress) + bench = "" + for key, value in values.items(): + table.add_row(bench, key, value) + bench = "" # Avoid displaying the bench for the other rows + + return Panel(table) + + def on_data(self, entry, data, row): + data = dict(data) + if prog := data.get("progress", None): + task = data.get("task", None) + if task == "early_stop": + current, total = prog + if "progress" not in row: + progress_bar = Progress( + BarColumn(), + TimeRemainingColumn(), + TextColumn("({task.completed}/{task.total})"), + ) + progress_bar._task = progress_bar.add_task("progress") + row["progress"] = progress_bar + else: + progress_bar = row["progress"] + progress_bar.update( + progress_bar._task, completed=current, total=total + ) + elif gpudata := data.get("gpudata", None): + for gpuid, data in gpudata.items(): + load = int(data.get("load", 0) * 100) + currm, totalm = data.get("memory", [0, 0]) + temp = int(data.get("temperature", 0)) + row[f"gpu:{gpuid}"] = ( + f"{load}% load | {currm:.0f}/{totalm:.0f} MB | {temp}C" + ) + else: + task = data.pop("task", "") + units = data.pop("units", "") + row.update({f"{task} {k}".strip(): f"{v} {units}" for k, v in data.items()}) + self.refresh() diff --git a/milabench/dashboard/live_report.py b/milabench/dashboard/live_report.py index e69de29bb..93a60f1e6 100644 --- a/milabench/dashboard/live_report.py +++ b/milabench/dashboard/live_report.py @@ -0,0 +1,55 @@ +import json +import os +import pprint +import shlex +import time +from collections import defaultdict +from datetime import datetime + +from blessed import Terminal +from rich.console import Console +from rich.live import Live +from rich.panel import Panel +from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn +from rich.table import Table +from rich.text import Text + +from .fs import XPath +from ..summary import make_summary_from_aggregates, Aggregator + +T = Terminal() +color_wheel = [T.cyan, T.magenta, T.yellow, T.red, T.green, T.blue] + + +class LiveReportFormatter(DashFormatter): + def __init__(self): + super().__init__() + self.running_set = defaultdict(int) + self.aggregator_set = defaultdict(lambda: defaultdict(Aggregator)) + self.prune_delay = 0 + + def __call__(self, entry): + benchname = entry.tag.split('.')[0] + + agg = self.aggregator_set[benchname][entry.tag] + agg.event_aggregator(entry) + + super().__call__(entry) + + def on_start(self, entry, data, row): + self.running_set[benchname] += 1 + + def on_end(self, entry, data, row): + super().on_end(entry, data, row) + + benchname = entry.tag.split('.')[0] + self.running_set[benchname] -= 1 + + if self.running_set[benchname] == 0: + self.produce_report_line(benchname) + + def produce_report_line(self, benchname): + aggregators = self.aggregator_set.pop(benchname) + summary = make_summary_from_aggregates([agg.group_by() for agg in aggregators]) + df = make_dataframe(summary, None, None) + print(x.to_string(formatters=_formatters)) \ No newline at end of file diff --git a/milabench/log.py b/milabench/dashboard/log.py similarity index 93% rename from milabench/log.py rename to milabench/dashboard/log.py index fd2cddabf..f76cb35bf 100644 --- a/milabench/log.py +++ b/milabench/dashboard/log.py @@ -15,6 +15,7 @@ from rich.text import Text from .fs import XPath +from .config import get_run_count T = Terminal() color_wheel = [T.cyan, T.magenta, T.yellow, T.red, T.green, T.blue] @@ -221,11 +222,34 @@ def __init__(self): self.rows = defaultdict(dict) self.endtimes = {} self.early_stop = {} + self.prune_delay = 30 + + self.current = 0 + self.total = get_run_count() + self.rows["GLOBAL"] = { + "progress": self._make_global_progress_bar() + } + + def _make_global_progress_bar(self): + progress_bar = Progress( + BarColumn(), + TimeRemainingColumn(), + TextColumn("({task.completed}/{task.total})"), + ) + progress_bar._task = progress_bar.add_task("progress") + self._update_global(0, self.total) + return progress_bar + + def _update_global(self, inc): + self.current += inc + self.rows["GLOBAL"].update( + progress_bar._task, completed=self.current, total=self.total + ) def prune(self): now = time.time() for tag, endtime in list(self.endtimes.items()): - if now - endtime > 10: + if now - endtime > self.prune_delay: del self.endtimes[tag] del self.rows[tag] @@ -253,6 +277,7 @@ def on_stop(self, entry, data, row): self.early_stop[entry.tag] = True def on_end(self, entry, data, row): + self._update_global(1) self.endtimes[entry.tag] = time.time() diff --git a/milabench/dashboard/rawoutput.py b/milabench/dashboard/rawoutput.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/milabench/multi.py b/milabench/multi.py index 9946a3642..cc0d96b79 100644 --- a/milabench/multi.py +++ b/milabench/multi.py @@ -17,6 +17,7 @@ milabench_remote_prepare, milabench_remote_run, ) +from .config import set_run_count from .utils import make_constraints_file here = XPath(__file__).parent @@ -154,8 +155,24 @@ async def do_prepare(self): await self.do_phase("prepare", remote_task, "prepare") + def count_runs(self): + acc = 0 + for index in range(repeat): + for pack in self.packs.values(): + if not await is_system_capable(pack): + continue + + exec_plan = make_execution_plan(pack, index, repeat) + + if isinstance(exec_plan, PerGPU): + acc += len(exec_plan.gpus) + else: + acc += 1 + return acc + async def do_run(self, repeat=1): setup = self.setup_pack() + set_run_count(self.count_runs()) if is_remote(setup): # if we are not on the main node right now diff --git a/milabench/report.py b/milabench/report.py index 5d7d32057..94d7fcc97 100644 --- a/milabench/report.py +++ b/milabench/report.py @@ -200,7 +200,7 @@ def make_dataframe(summary, compare=None, weights=None): ) ) - return DataFrame( + df = DataFrame( { key: _make_row( summary.get(key, {}), @@ -210,6 +210,10 @@ def make_dataframe(summary, compare=None, weights=None): for key in all_keys } ).transpose() + + df = df[sorted(df.columns, key=lambda k: columns_order.get(k, 0))] + + return df @error_guard({}) @@ -230,9 +234,6 @@ def make_report( df = make_dataframe(summary, compare, weights) - # Reorder columns - df = df[sorted(df.columns, key=lambda k: columns_order.get(k, 0))] - out = Outputter(stdout=stream, html=html) if sources: diff --git a/milabench/summary.py b/milabench/summary.py index 0bc4a8d4e..b102c4541 100644 --- a/milabench/summary.py +++ b/milabench/summary.py @@ -6,16 +6,16 @@ from .utils import error_guard -@error_guard(None) -def aggregate(run_data): - """Group all the data inside a dictionary of lists""" - omnibus = defaultdict(list) - config = None - start = None - end = None - early_stop = False - for entry in run_data: +class Aggregator: + def __int__(self): + self.omnibus = defaultdict(list) + self.config = None + self.start = None + self.end = None + self.early_stop = False + + def event_aggregator(self, entry): event = entry["event"] if event == "config": @@ -27,63 +27,81 @@ def aggregate(run_data): for k, v in data.items(): if task is not None and k == "rate": k = f"{task}_{k}" - omnibus[k].append(v) + self.omnibus[k].append(v) elif event == "line": - omnibus[entry["pipe"]].append(entry["data"]) + self.omnibus[entry["pipe"]].append(entry["data"]) elif event == "stop": - early_stop = True + self.early_stop = True elif event == "start": - assert start is None - start = entry["data"] + assert self.start is None + self.start = entry["data"] elif event == "end": - assert end is None - end = entry["data"] + assert self.end is None + self.end = entry["data"] - if not config: - # This is not a run - return None + def aggregate(self, run_data): + for entry in run_data: + self.event_aggregator(entry) - assert config and start and end + def group_by(self): + if not self.config: + # This is not a run + return None - device = config.get("device", None) - omnibus["gpudata"] = [ - {str(device): entry[str(device)]} if device is not None else entry - for entry in omnibus.get("gpudata", []) - if device is None or str(device) in entry - ] + assert self.config and start and end, "Missing metrics" - if device is not None: - omnibus["per_gpu"] = [(device, tr) for tr in omnibus["train_rate"]] + device = self.config.get("device", None) - if "loss" in omnibus: - fl, ll = omnibus["loss"][0], omnibus["loss"][-1] - omnibus["loss_gain"] = [ll - fl] + newdata = [] + for entry in self.omnibus.get("gpudata", []) + if device is None or str(device) in entry + if device is not None: + newdata.append({str(device): entry[str(device)]}) + else: + newdata.append(entry) + self.omnibus["gpudata"] = newdata - omnibus["walltime"] = [end["time"] - start["time"]] + if device is not None: + self.omnibus["per_gpu"] = [(device, tr) for tr in self.omnibus["train_rate"]] - success = early_stop or ( - end["return_code"] == 0 - and not any(isnan(loss) for loss in omnibus.get("loss", [])) - and bool(omnibus.get("train_rate", [])) - ) + if "loss" in self.omnibus: + fl, ll = self.omnibus["loss"][0], self.omnibus["loss"][-1] + self.omnibus["loss_gain"] = [ll - fl] - if "nolog" in config["tag"]: - success = True + self.omnibus["walltime"] = [self.end["time"] - self.start["time"]] - omnibus["success"] = [success] + success = self.early_stop or ( + end["return_code"] == 0 + and not any(isnan(loss) for loss in self.omnibus.get("loss", [])) + and bool(self.omnibus.get("train_rate", [])) + ) - return { - "config": config, - "start": start, - "end": end, - "data": omnibus, - } + if "nolog" in self.config["tag"]: + success = True + + self.omnibus["success"] = [success] + + return { + "config": self.config, + "start": self.start, + "end": self.end, + "data": self.omnibus, + } + +@error_guard(None) +def aggregate(run_data): + """Group all the data inside a dictionary of lists""" + + agg = Aggregator() + agg.aggregate(run_data) + return agg.group_by() + def _classify(all_aggregates): """Group data by benchmark names""" classified = defaultdict(list) @@ -175,7 +193,11 @@ def _summarize(group): def make_summary(runs): aggs = [agg for run in runs if (agg := aggregate(run))] + return make_summary_from_aggregates(aggs) + + +def make_summary_from_aggregates(aggs): classified = _classify(aggs) merged = {name: _merge(runs) for name, runs in classified.items()} summarized = {name: _summarize(agg) for name, agg in merged.items()} - return summarized + return summarized \ No newline at end of file