diff --git a/.gitignore b/.gitignore index 778dcc7b7..8fedda861 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,4 @@ test.out output/ workspace/ .pin/tmp-* +dry/ diff --git a/config/base.yaml b/config/base.yaml index b4e01b5e1..1f8bdcd78 100644 --- a/config/base.yaml +++ b/config/base.yaml @@ -22,7 +22,7 @@ _torchvision: --lr: 0.01 --no-stdout: true --epochs: 50 - --num-workers: 8 + --num-workers: "auto({n_worker}, 8)" --loader: pytorch --data: "{milabench_data}/FakeImageNet" @@ -37,7 +37,7 @@ _torchvision_ddp: n: 1 argv: --epochs: 10 - --num-workers: 8 + --num-workers: "auto({n_worker}, 8)" --loader: pytorch --data: "{milabench_data}/FakeImageNet" @@ -82,7 +82,6 @@ llama: argv: --pretrained: true - _hf: inherits: _defaults definition: ../benchmarks/huggingface @@ -90,7 +89,7 @@ _hf: install_group: torch argv: --precision: 'tf32-fp16' - --num-workers: 8 + --num-workers: "auto({n_worker}, 8)" plan: method: per_gpu @@ -111,6 +110,7 @@ _timm: --val-split: '' --data-dir: "{milabench_data}" --dataset: "FakeImageNet" + --workers: "auto({n_worker}, 8)" _sb3: inherits: _defaults @@ -143,7 +143,7 @@ _accelerate_opt: --dataset_rev: "b08601e" --validation_split_percentage: 5 --per_gpu_batch_size: 1 - --cpus_per_gpu: 8 + --cpus_per_gpu: "auto({n_worker}, 8)" # --model_name: "facebook/opt-2.7b" # --model_name: "facebook/opt-1.3b" # --model_name: "facebook/opt-350m" @@ -203,7 +203,7 @@ resnet50: argv: --model: resnet50 --batch-size: 256 - --num-workers: "{cpu_per_gpu}" + --num-workers: "auto({n_worker}, 8)" --loader: pytorch resnet50-noio: @@ -231,7 +231,7 @@ resnet152-ddp: argv: --model: resnet152 --batch-size: 256 - --num-workers: 8 + --num-workers: "auto({n_worker}, 8)" --loader: dali efficientnet_b4: @@ -507,6 +507,7 @@ stargan: --model_save_dir: "{milabench_extra}/models" --sample_dir: "{milabench_extra}/samples" --result_dir: "{milabench_extra}/results" + --num_workers: "auto({n_worker}, 8)" super-slomo: inherits: _defaults @@ -524,7 +525,7 @@ super-slomo: --train_batch_size: 64 --dataset_root: "{milabench_data}/FakeImageNet" --loader: pytorch - --num_workers: 8 + --num_workers: "auto({n_worker}, 8)" ppo: inherits: _sb3 @@ -588,6 +589,7 @@ dlrm: --test-mini-batch-size: 16384 --test-num-workers: 0 --use-gpu: true + --num-workers: "auto({n_worker}, 8)" rwkv: inherits: _defaults @@ -625,6 +627,7 @@ rwkv: --grad_cp: 0 --random_seed: 1234 --enable_progress_bar: "False" + brax: inherits: _defaults tags: diff --git a/config/scaling.yaml b/config/scaling.yaml index f4947c213..520752452 100644 --- a/config/scaling.yaml +++ b/config/scaling.yaml @@ -122,12 +122,18 @@ focalnet: optimized: 128 opt-1_3b: arg: --per_gpu_batch_size + model: + 1: 42126 MiB optimized: 1 opt-1_3b-multinode: arg: --per_gpu_batch_size + model: + 1: 42126 MiB optimized: 1 opt-6_7b-multinode: arg: --per_gpu_batch_size + model: + 1: 55380 MiB optimized: 1 reformer: arg: --batch-size diff --git a/milabench/_version.py b/milabench/_version.py index 3e7d71512..b0896b12d 100644 --- a/milabench/_version.py +++ b/milabench/_version.py @@ -1,5 +1,5 @@ """This file is generated, do not modify""" -__tag__ = "v0.0.10-147-gc6540c3e" -__commit__ = "c6540c3e470222e44b4a841954593185db49b111" -__date__ = "2024-06-12 07:11:39 -0400" +__tag__ = "v0.0.10-147-g1ef648ee" +__commit__ = "1ef648eeb78233e53274058cd9cfcdc539f01bae" +__date__ = "2024-06-12 09:39:51 -0400" diff --git a/milabench/cli/gather.py b/milabench/cli/gather.py new file mode 100644 index 000000000..3669a74df --- /dev/null +++ b/milabench/cli/gather.py @@ -0,0 +1,124 @@ +import argparse +import os +import re +from dataclasses import dataclass, field + +import pandas as pd + +from ..common import _read_reports +from ..report import make_dataframe, pandas_to_string +from ..summary import make_summary + + +def default_tags(): + return [ + "worker=w([a-z0-9]*)", + "multiple=m([0-9]*)", + "power=p([0-9]*)", + "capacity=c([A-Za-z0-9]*(Go)?)", + ] + + +# fmt: off +@dataclass +class Arguments: + runs: str + tags: list = field(default_factory=default_tags) +# fmt: on + + +def arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--runs", + type=str, + help="Run folder", + default="/home/mila/d/delaunap/batch_x_worker/", + ) + parser.add_argument( + "--tags", + type=str, + help="Tags defined in run names", + default=default_tags(), + ) + return parser.parse_args() # Arguments() + + +def get_config(reports): + k = list(reports.keys())[0] + config = None + for line in reports[k]: + if line["event"] == "config": + config = line["data"] + break + return config + + +def extract_tags(name, tags): + for tag, pat in tags.items(): + if m := pat.search(name): + value = m.group(1) + yield tag, value + else: + print(f"{tag} not found in {name}") + yield tag, "NA" + + +def gather_cli(args=None): + """Gather metrics from runs inside a folder in a neat format. + It can extract tags/flags from the runname and create new columns to uniquely identify runs. + + Examples + -------- + + >>> python -m milabench.cli.gather --runs /home/mila/d/delaunap/batch_x_worker/ + bench | fail | n | perf | sem% | std% | peak_memory | score | weight | elapsed | name | worker | multiple | power | capacity + brax | 0 | 1 | 722480.33 | 0.7% | 5.2% | 6448 | 722480.33 | 1.00 | 94 | w16-m8-c4Go | 16 | 8 | NA | 4Go + dlrm | 0 | 1 | 350641.30 | 0.6% | 4.6% | 7624 | 350641.30 | 1.00 | 124 | w16-m8-c4Go | 16 | 8 | NA | 4Go + .... + brax | 0 | 1 | 723867.42 | 0.6% | 4.5% | 6448 | 723867.42 | 1.00 | 94 | w2-m8-c8Go | 2 | 8 | NA | 8Go + dlrm | 0 | 1 | 403113.36 | 0.7% | 5.1% | 7420 | 403113.36 | 1.00 | 258 | w2-m8-c8Go | 2 | 8 | NA | 8Go + bf16 | 0 | 8 | 293.08 | 0.3% | 7.5% | 5688 | 2361.09 | 0.00 | 18 | w2-m8-c8Go | 2 | 8 | NA | 8Go + fp16 | 0 | 8 | 290.58 | 0.2% | 4.9% | 5688 | 2335.63 | 0.00 | 29 | w2-m8-c8Go | 2 | 8 | NA | 8Go + + """ + if args is None: + args = arguments() + + runs = [] + for folder in os.listdir(args.runs): + if folder.startswith("prepare"): + continue + + if folder.startswith("install"): + continue + + path = f"{args.runs}/{folder}" + if os.path.isdir(path): + runs.append(path) + + tags = dict() + for tag in args.tags: + name, regex = tag.split("=") + tags[name] = re.compile(regex) + + query = ("batch_size", "elapsed") + data = [] + for run in runs: + reports = _read_reports(run) + summary = make_summary(reports.values(), query=query) + df = make_dataframe(summary, None, None, query=query) + + name = run.split("/")[-1] + df["name"] = name.split(".", maxsplit=1)[0] + for tag, value in extract_tags(name, tags): + df[tag] = value + + data.append(df) + + gathered = pd.concat(data) + print(pandas_to_string(gathered)) + + +if __name__ == "__main__": + gather_cli() diff --git a/milabench/cli/matrix.py b/milabench/cli/matrix.py index 732db2b2b..b2e78a510 100644 --- a/milabench/cli/matrix.py +++ b/milabench/cli/matrix.py @@ -1,15 +1,18 @@ +import sys from dataclasses import dataclass +import yaml from coleo import Option, tooled from ..common import ( build_config, - build_system_config, deduce_arch, get_base_defaults, is_selected, merge, ) +from ..sizer import resolve_argv, scale_argv +from ..system import build_system_config # fmt: off @@ -79,7 +82,34 @@ def cli_matrix_run(args=None): clean_config(config, args) - for k in config: - print(k) + def resolve_args(conf, argv): + from ..pack import Package - # yaml.dump(config, sys.stdout) + pack = Package(conf) + + args = [] + for k, v in argv.items(): + args.append(k) + args.append(v) + + sized_args = scale_argv(pack, args) + final_args = resolve_argv(pack, sized_args) + + i = 0 + for k, v in argv.items(): + if final_args[i] == k: + argv[k] = final_args[i + 1] + i += 2 + continue + + print(f"Missing resolved argument {k}") + + return argv + + for _, conf in config.items(): + conf["argv"] = resolve_args(conf, conf["argv"]) + + # for k in config: + # print(k) + + yaml.dump(config, sys.stdout) diff --git a/milabench/commands/__init__.py b/milabench/commands/__init__.py index 6d018b4c8..9e2ca1d77 100644 --- a/milabench/commands/__init__.py +++ b/milabench/commands/__init__.py @@ -659,6 +659,14 @@ def _argv(self, **_) -> List: else ["--multi_gpu"] ) + # + # Can this logic be removed? + # + from ..sizer import new_argument_resolver + + resolver = new_argument_resolver(self.pack) + + cpu_per_process = resolver(str(self.pack.config["argv"]["--cpus_per_gpu"])) return [ # -- Run the command in the right venv # This could be inside the SSH Command @@ -676,7 +684,7 @@ def _argv(self, **_) -> List: f"--num_machines={num_machines}", *deepspeed_argv, f"--gradient_accumulation_steps={self.pack.config['gradient_accumulation_steps']}", - f"--num_cpu_threads_per_process={self.pack.config['argv']['--cpus_per_gpu']}", + f"--num_cpu_threads_per_process={cpu_per_process}", f"--main_process_ip={manager['ip']}", f"--main_process_port={manager['port']}", f"--num_processes={nproc}", diff --git a/milabench/common.py b/milabench/common.py index 4babb8bbc..429895ef7 100644 --- a/milabench/common.py +++ b/milabench/common.py @@ -15,13 +15,14 @@ from milabench.alt_async import proceed from milabench.utils import available_layers, blabla, multilogger -from .config import build_config, build_system_config +from .config import build_config from .fs import XPath from .log import TerminalFormatter from .merge import merge from .multi import MultiPackage from .report import make_report from .summary import aggregate, make_summary +from .system import build_system_config def get_pack(defn): diff --git a/milabench/config.py b/milabench/config.py index e082ead77..585dee48f 100644 --- a/milabench/config.py +++ b/milabench/config.py @@ -6,13 +6,11 @@ import psutil import yaml from omegaconf import OmegaConf -from voir.instruments.gpu import get_gpu_info from .fs import XPath from .merge import merge -system_global = contextvars.ContextVar("system") -config_global = contextvars.ContextVar("Config") +config_global = contextvars.ContextVar("config", default=None) def relative_to(pth, cwd): @@ -112,7 +110,6 @@ def build_matrix_bench(all_configs): for name, bench_config in all_configs.items(): for k, v in expand_matrix(name, bench_config): - if k in expanded_config: raise ValueError("Bench name is not unique") @@ -136,151 +133,3 @@ def build_config(*config_files): config_global.set(all_configs) return all_configs - - -def check_node_config(nodes): - mandatory_fields = ["name", "ip", "user"] - - for node in nodes: - name = node.get("name", None) - - for field in mandatory_fields: - assert field in node, f"The `{field}` of the node `{name}` is missing" - - -def get_remote_ip(): - """Get all the ip of all the network interfaces""" - addresses = psutil.net_if_addrs() - stats = psutil.net_if_stats() - - result = [] - - for interface, address_list in addresses.items(): - for address in address_list: - if interface in stats and getattr(stats[interface], "isup"): - result.append(address.address) - - return set(result) - - -def _resolve_ip(ip): - # Resolve the IP - try: - hostname, aliaslist, ipaddrlist = socket.gethostbyaddr(ip) - lazy_raise = None - except socket.gaierror as err: - # Get Addr Info (GAI) Error - # - # When we are connecting to a node through a ssh proxy jump - # the node IPs/Hostnames are not available until we reach - # the first node inside the cluster - # - hostname = ip - aliaslist = [] - ipaddrlist = [] - lazy_raise = err - - return hostname, aliaslist, ipaddrlist, lazy_raise - - -def resolve_addresses(nodes): - # Note: it is possible for self to be none - # if we are running milabench on a node that is not part of the system - # in that case it should still work; the local is then going to - # ssh into the main node which will dispatch the work to the other nodes - self = None - lazy_raise = None - ip_list = get_remote_ip() - - for node in nodes: - hostname, aliaslist, ipaddrlist, lazy_raise = _resolve_ip(node["ip"]) - - node["hostname"] = hostname - node["aliaslist"] = aliaslist - node["ipaddrlist"] = ipaddrlist - - if hostname.endswith(".server.mila.quebec.server.mila.quebec"): - print() - print("Hostname was extra long for no reason") - print(hostname, socket.gethostname()) - print() - - # why is this happening - hostname = hostname[: -len(".server.mila.quebec")] - - is_local = ( - ("127.0.0.1" in ipaddrlist) - or (hostname in ("localhost", socket.gethostname())) - or len(ip_list.intersection(ipaddrlist)) > 0 - ) - node["local"] = is_local - - if is_local: - self = node - node["ipaddrlist"] = list(ip_list) - - # if self is node we might be outisde the cluster - # which explains why we could not resolve the IP of the nodes - if self is not None and lazy_raise: - raise RuntimeError("Could not resolve node ip") from lazy_raise - - return self - - -def get_gpu_capacity(strict=False): - try: - capacity = 0 - - for k, v in get_gpu_info()["gpus"].items(): - capacity = min(v["memory"]["total"], capacity) - - return capacity - except: - print("GPU not available, defaulting to 0 MiB") - if strict: - raise - return 0 - - -def is_autoscale_enabled(): - return ( - os.getenv("MILABENCH_SIZER_AUTO", False) - or os.getenv("MILABENCH_SIZER_MULTIPLE") is not None - ) - - -def build_system_config(config_file, defaults=None, gpu=True): - """Load the system configuration, verify its validity and resolve ip addresses - - Notes - ----- - * node['local'] true when the code is executing on the machine directly - * node["main"] true when the machine is in charge of distributing the workload - """ - - if config_file is None: - config = {"system": {}} - else: - config_file = XPath(config_file).absolute() - with open(config_file) as cf: - config = yaml.safe_load(cf) - - if defaults: - config = merge(defaults, config) - - system = config.get("system", {}) - - # capacity is only required if batch resizer is enabled - if (gpu or is_autoscale_enabled()) and not "gpu" not in system: - system["gpu"] = {"capacity": f"{int(get_gpu_capacity())} MiB"} - - if system.get("sshkey") is not None: - system["sshkey"] = str(XPath(system["sshkey"]).resolve()) - - check_node_config(system["nodes"]) - - self = resolve_addresses(system["nodes"]) - system["self"] = self - - system_global.set(system) - return config diff --git a/milabench/log.py b/milabench/log.py index a6f7388a9..bed8aac3e 100644 --- a/milabench/log.py +++ b/milabench/log.py @@ -147,6 +147,9 @@ def __call__(self, entry): elif event == "config": def _show(k, entry): + if k in ("meta", "system"): + return + if isinstance(entry, dict): for k2, v in entry.items(): _show(f"{k}.{k2}", v) @@ -300,9 +303,9 @@ def on_data(self, entry, data, row): load = int(data.get("load", 0) * 100) currm, totalm = data.get("memory", [0, 0]) temp = int(data.get("temperature", 0)) - row[f"gpu:{gpuid}"] = ( - f"{load}% load | {currm:.0f}/{totalm:.0f} MB | {temp}C" - ) + row[ + f"gpu:{gpuid}" + ] = f"{load}% load | {currm:.0f}/{totalm:.0f} MB | {temp}C" row["gpu_load"] = f"{load}%" row["gpu_mem"] = f"{currm:.0f}/{totalm:.0f} MB" row["gpu_temp"] = f"{temp}C" @@ -376,9 +379,9 @@ def on_data(self, entry, data, row): load = int(data.get("load", 0) * 100) currm, totalm = data.get("memory", [0, 0]) temp = int(data.get("temperature", 0)) - row[f"gpu:{gpuid}"] = ( - f"{load}% load | {currm:.0f}/{totalm:.0f} MB | {temp}C" - ) + row[ + f"gpu:{gpuid}" + ] = f"{load}% load | {currm:.0f}/{totalm:.0f} MB | {temp}C" else: task = data.pop("task", "") units = data.pop("units", "") diff --git a/milabench/report.py b/milabench/report.py index 8488482e4..42a701a16 100644 --- a/milabench/report.py +++ b/milabench/report.py @@ -13,7 +13,7 @@ @error_guard({}) -def _make_row(summary, compare, weights): +def _make_row(summary, compare, weights, query=None): mkey = "train_rate" metric = "mean" row = {} @@ -55,6 +55,12 @@ def _make_row(summary, compare, weights): row["weight"] = weights.get("weight", summary.get("weight", 0)) # ---- + if query is not None: + extra = summary.get("extra", dict()) + for q in query: + if v := extra.get(q): + row[q] = v + return row @@ -185,33 +191,50 @@ def _report_pergpu(entries, measure="50"): } -def make_dataframe(summary, compare=None, weights=None): +def make_dataframe(summary, compare=None, weights=None, query=None): if weights is None: weights = dict() all_keys = list( - sorted( - { - *(summary.keys() if summary else []), - *(compare.keys() if compare else []), - *(weights.keys() if weights else []), - } - ) + { + *(summary.keys() if summary else []), + *(compare.keys() if compare else []), + *(weights.keys() if weights else []), + } ) + def sort_by(key): + """Group similar runs together""" + if weights: + return weights[key]["group"] + + if summary: + return summary[key]["group"] + + return key + + # Sort by name first so bench with similar names are together + # we want bench in the same group with similar names to be close + all_keys = sorted(all_keys) + + # Sort by group so bench are grouped together + # we want flops bench to be close together no matter what their names are + all_keys = sorted(all_keys, key=sort_by) + df = DataFrame( { key: _make_row( summary.get(key, {}), compare and compare.get(key, {}), weights and weights.get(key, {}), + query=query, ) for key in all_keys } ).transpose() # Reorder columns - df = df[sorted(df.columns, key=lambda k: columns_order.get(k, 0))] + df = df[sorted(df.columns, key=lambda k: columns_order.get(k, 2000))] return df @@ -301,37 +324,6 @@ def _score(column): out.finalize() -def pandas_to_string(df, formatters): - """Default stdout printer does not insert a column sep which makes it hard to retranscribe results elsewhere. - to_csv does not align the output. - """ - from collections import defaultdict - - columns = df.columns.tolist() - - sep = " | " - lines = [] - col_size = defaultdict(int) - - for index, row in df.iterrows(): - line = [f"{index:<30}"] - for col, val in zip(columns, row): - fmt = formatters.get(col) - val = fmt(val) - col_size[col] = max(col_size[col], len(val)) - line.append(val) - - lines.append(sep.join(line)) - - def fmtcol(col): - size = col_size[col] - return f"{col:>{size}}" - - header = sep.join([f"{'bench':<30}"] + [fmtcol(col) for col in columns]) - - return "\n".join([header] + lines) - - _formatters = { "fail": "{:4.0f}".format, "n": "{:3.0f}".format, @@ -347,8 +339,10 @@ def fmtcol(col): "sem%": "{:6.1%}".format, "iqr%": "{:6.1%}".format, "score": "{:10.2f}".format, - "weight": "{:5.2f}".format, + "weight": "{:6.2f}".format, "peak_memory": "{:11.0f}".format, + "elapsed": "{:5.0f}".format, + "batch_size": "{:3.0f}".format, 0: "{:.0%}".format, 1: "{:.0%}".format, 2: "{:.0%}".format, @@ -368,6 +362,42 @@ def fmtcol(col): } +def pandas_to_string(df, formatters=_formatters): + """Default stdout printer does not insert a column sep which makes it hard to retranscribe results elsewhere. + to_csv does not align the output. + """ + from collections import defaultdict + + columns = df.columns.tolist() + + sep = " | " + lines = [] + col_size = defaultdict(int) + + for index, row in df.iterrows(): + line = [f"{index:<30}"] + for col, val in zip(columns, row): + fmt = formatters.get(col) + + if fmt is not None: + val = fmt(val) + col_size[col] = max(col_size[col], len(val)) + else: + val = str(val) + + line.append(val) + + lines.append(sep.join(line)) + + def fmtcol(col): + size = col_size[col] + return f"{col:>{size}}" + + header = sep.join([f"{'bench':<30}"] + [fmtcol(col) for col in columns]) + + return "\n".join([header] + lines) + + _table_style = H.style( """ body { diff --git a/milabench/sizer.py b/milabench/sizer.py index 7d2e56194..55ddcdb0d 100644 --- a/milabench/sizer.py +++ b/milabench/sizer.py @@ -2,12 +2,13 @@ import multiprocessing import os from copy import deepcopy -from dataclasses import dataclass import numpy as np import yaml +from voir.instruments.gpu import get_gpu_info -from .config import is_autoscale_enabled, system_global +from .merge import merge +from .system import CPUOptions, SizerOptions, system_global from .validation.validation import ValidationLayer ROOT = os.path.dirname(__file__) @@ -15,24 +16,6 @@ default_scaling_config = os.path.join(ROOT, "..", "config", "scaling.yaml") -def getenv(name, type): - value = os.getenv(name) - - if value is not None: - return type(value) - - return value - - -@dataclass -class SizerOptions: - size: int = getenv("MILABENCH_SIZER_BATCH_SIZE", int) - autoscale: bool = is_autoscale_enabled() - multiple: int = getenv("MILABENCH_SIZER_MULTIPLE", int) - optimized: bool = getenv("MILABENCH_SIZER_OPTIMIZED", int) - capacity: str = getenv("MILABENCH_SIZER_CAPACITY", str) - - metric_prefixes = { "T": (12, 4), "G": (9, 3), @@ -86,7 +69,7 @@ def benchscaling(self, benchmark): # benchmark config if isinstance(benchmark, dict) and "name" in benchmark: - return benchmark + return self.scaling_config.get(benchmark["name"]) # pack return self.scaling_config.get(benchmark.config["name"]) @@ -110,13 +93,16 @@ def auto_size(self, benchmark, capacity): model = config.get("model", None) if model is None: - print(f"Missing batch-size model for {benchmark}") + print(f"Missing batch-size model for {benchmark.config['name']}") return 1 data = list(sorted(config["model"].items(), key=lambda x: x[0])) mem = [to_octet(v[1]) for v in data] size = [float(v[0]) for v in data] + if len(mem) == 1: + print(f"Not enough data for {benchmark.config['name']}") + return 1 # This does not extrapolate # int(np.interp(capacity, mem, size)) @@ -126,13 +112,21 @@ def auto_size(self, benchmark, capacity): newsize_f = model(capacity) newsize_i = int(newsize_f) + if newsize_i <= 0: + return 1 + if (newsize_f - newsize_i) > 0.5: newsize_i += 1 - if self.options.multiple is not None: - newsize_i = (newsize_i // self.options.multiple) * self.options.multiple + final_size = newsize_i - return max(newsize_i, 1) + if self.options.multiple: + final_size = (newsize_i // self.options.multiple) * self.options.multiple + + if self.options.power: + final_size = int(self.options.power) ** int(np.log2(newsize_i)) + + return max(final_size, 1) def size(self, benchmark, capacity): config = self.benchscaling(benchmark) @@ -148,6 +142,27 @@ def size(self, benchmark, capacity): return None + def find_batch_size(self, benchmark, event): + config = self.benchscaling(benchmark) + + if config is None: + return None + + argname = config.get("arg") + if argname is None: + return -1 + + if "event" in event: + event = event["data"] + + argv = event["command"] + + for i, arg in enumerate(argv): + if str(arg).endswith(argname): + return int(argv[i + 1]) + + return -1 + def argv(self, benchmark, capacity, argv): """Find the batch size and override it with a new value""" @@ -177,11 +192,25 @@ def argv(self, benchmark, capacity, argv): return argv -sizer_global = contextvars.ContextVar("sizer_global", default=Sizer()) +sizer_global = contextvars.ContextVar("sizer_global", default=None) -def scale_argv(pack, argv): +def batch_sizer() -> Sizer: sizer = sizer_global.get() + if sizer is None: + sizer_global.set(Sizer()) + return batch_sizer() + return sizer + + +def get_batch_size(config, start_event): + sizer = batch_sizer() + return sizer.find_batch_size(config, start_event) + + +def scale_argv(pack, argv): + sizer = batch_sizer() + system = system_global.get() capacity = system.get("gpu", dict()).get("capacity") @@ -193,9 +222,9 @@ class MemoryUsageExtractor(ValidationLayer): """Extract max memory usage per benchmark to populate the memory model""" def __init__(self): - self.filepath = getenv("MILABENCH_SIZER_SAVE", str) - - self.memory = deepcopy(sizer_global.get().scaling_config) + sizer = batch_sizer() + self.filepath = sizer.options.save + self.memory = deepcopy(sizer.scaling_config) self.scaling = None self.benchname = None self.batch_size = 0 @@ -271,35 +300,69 @@ def on_end(self, entry): def report(self, *args): if self.filepath is not None: + newdata = self.memory + + if os.path.exists(self.filepath): + with open(self.filepath, "r") as fp: + previous_data = yaml.safe_load(fp) + newdata = merge(previous_data, self.memory) + with open(self.filepath, "w") as file: - yaml.dump(self.memory, file) + yaml.dump(newdata, file) -def resolve_argv(pack, argv): - context = system_global.get() +def new_argument_resolver(pack): + context = deepcopy(system_global.get()) arch = context.get("arch", "cpu") - device_count = len(pack.config.get("devices", [0])) + if hasattr(pack, "config"): + device_count = len(pack.config.get("devices", [0])) + else: + device_count = len(get_gpu_info()["gpus"]) ccl = {"hpu": "hccl", "cuda": "nccl", "rocm": "rccl", "xpu": "ccl", "cpu": "gloo"} if device_count <= 0: device_count = 1 + options = CPUOptions() + + def auto(value, default): + if options.enabled: + return value + return default + + def clamp(x, mn=options.cpu_min, mx=options.cpu_max): + return min(max(x, mn), mx) + + total_cpu = multiprocessing.cpu_count() + total_available = total_cpu - options.reserved_cores + + context["cpu_count"] = total_available + context["cpu_per_gpu"] = total_available // device_count + context["n_worker"] = clamp(context["cpu_per_gpu"]) + + if options.n_workers is not None: + context["n_worker"] = options.n_workers + context["arch"] = arch context["ccl"] = ccl.get(arch, "gloo") - context["cpu_count"] = multiprocessing.cpu_count() - context["cpu_per_gpu"] = multiprocessing.cpu_count() // device_count - context["milabench_data"] = pack.config.get("dirs", {}).get("data", None) context["milabench_cache"] = pack.config.get("dirs", {}).get("cache", None) context["milabench_extra"] = pack.config.get("dirs", {}).get("extra", None) - max_worker = 16 - context["n_worker"] = min(context["cpu_per_gpu"], max_worker) + def auto_eval(arg): + newvalue = str(arg).format(**context) + if newvalue.startswith("auto"): + newvalue = str(eval(newvalue, {"auto": auto}, {})) + return newvalue + + return auto_eval + +def resolve_argv(pack, argv): + resolver = new_argument_resolver(pack) argv = list(argv) for i, arg in enumerate(argv): - argv[i] = str(arg).format(**context) - + argv[i] = resolver(arg) return argv diff --git a/milabench/summary.py b/milabench/summary.py index 0bc4a8d4e..946f6e6ef 100644 --- a/milabench/summary.py +++ b/milabench/summary.py @@ -135,8 +135,28 @@ def _metrics(xs): return metrics +@error_guard(dict()) +def augment(group, query=tuple([])): + """Optional augmentation steps that will add additional data. + Usually extracted from the run itself + """ + data = {} + + if "batch_size" in query: + from .sizer import get_batch_size + + data["batch_size"] = get_batch_size(group["config"], group["start"]) + + if "elapsed" in query: + start_time = group["start"]["time"] + end_time = group["end"]["time"] + data["elapsed"] = end_time - start_time + + return data + + @error_guard(None) -def _summarize(group): +def _summarize(group, query): agg = group["data"] gpudata = defaultdict(lambda: defaultdict(list)) @@ -152,8 +172,12 @@ def _summarize(group): per_gpu[device].append(tr) config = group["config"] + + additional = augment(group, query) + return { "name": config["name"], + "group": config["group"], "n": len(agg["success"]), "successes": sum(agg["success"]), "failures": sum(not x for x in agg["success"]), @@ -170,12 +194,13 @@ def _summarize(group): for device, data in gpudata.items() }, "weight": config.get("weight", 0), + "extra": additional, } -def make_summary(runs): +def make_summary(runs, query=None): aggs = [agg for run in runs if (agg := aggregate(run))] classified = _classify(aggs) merged = {name: _merge(runs) for name, runs in classified.items()} - summarized = {name: _summarize(agg) for name, agg in merged.items()} + summarized = {name: _summarize(agg, query) for name, agg in merged.items()} return summarized diff --git a/milabench/system.py b/milabench/system.py new file mode 100644 index 000000000..c470bfbe6 --- /dev/null +++ b/milabench/system.py @@ -0,0 +1,281 @@ +import contextvars +import os +import socket +from dataclasses import dataclass, field + +import psutil +import yaml +from voir.instruments.gpu import get_gpu_info + +from .fs import XPath +from .merge import merge + +system_global = contextvars.ContextVar("system", default=None) + + +def getenv(name, expected_type): + value = os.getenv(name) + + if value is not None: + try: + return expected_type(value) + except TypeError: + print(f"{name}={value} expected type {expected_type} got {type(value)}") + return None + return value + + +def print_once(*args, **kwargs): + printed = 0 + + def _print(): + nonlocal printed + if printed == 0: + print(*args, **kwargs) + printed += 1 + + return _print + + +warn_no_config = print_once("No system config found, using defaults") + + +def option(name, etype, default=None): + options = dict() + system = system_global.get() + + if system: + options = system.get("options", dict()) + else: + warn_no_config() + + frags = name.split(".") + env_name = "MILABENCH_" + "_".join(map(str.upper, frags)) + env_value = getenv(env_name, etype) + + lookup = options + for frag in frags[:-1]: + lookup = lookup.get(frag, dict()) + + system_value = lookup.get(frags[-1], None) + final_value = env_value or system_value or default + + if final_value is None: + return None + + try: + value = etype(final_value) + lookup[frags[-1]] = value + return value + except ValueError: + print(f"{name}={value} expected type {etype} got {type(value)}") + return None + + +def is_autoscale_enabled(): + return option("sizer.auto", int, 0) > 0 + + +def default_save_location(): + from pathlib import Path + + return Path.home() / "new_scaling.yaml" + + +@dataclass +class SizerOptions: + size: int = option("sizer.batch_size", int) + autoscale: bool = option("sizer.auto", int, 0) + multiple: int = option("sizer.multiple", int, 8) + power: int = option("sizer.power", int) + optimized: bool = option("sizer.optimized", int) + capacity: str = option("sizer.capacity", str) + save: str = option("sizer.save", str, None) + + +@dataclass +class CPUOptions: + enabled: bool = option("cpu.auto", bool, False) + + # max number of CPU per GPU + cpu_max: int = option("cpu.max", int, 16) + + # min number of CPU per GPU + cpu_min: int = option("cpu.min", int, 2) + + # reserved CPU cores (i.e not available for the benchmark) + reserved_cores: int = option("cpu.reserved_cores", int, 0) + + # Number of workers (ignores cpu_max and cpu_min) + n_workers: int = option("cpu.n_workers", int) + + +@dataclass +class Options: + sizer: SizerOptions + cpu: CPUOptions + + +@dataclass +class GPUConfig: + capacity: str = None + + +@dataclass +class Nodes: + name: str + ip: str + port: int + main: bool + user: str + + +@dataclass +class SystemConfig: + arch: str = getenv("MILABENCH_GPU_ARCH", str) + sshkey: str = None + docker_image: str = None + nodes: list[Nodes] = field(default_factory=list) + gpu: GPUConfig = None + options: Options = None + + +def check_node_config(nodes): + mandatory_fields = ["name", "ip", "user"] + + for node in nodes: + name = node.get("name", None) + + for field in mandatory_fields: + assert field in node, f"The `{field}` of the node `{name}` is missing" + + +def get_remote_ip(): + """Get all the ip of all the network interfaces""" + addresses = psutil.net_if_addrs() + stats = psutil.net_if_stats() + + result = [] + + for interface, address_list in addresses.items(): + for address in address_list: + if interface in stats and getattr(stats[interface], "isup"): + result.append(address.address) + + return set(result) + + +def _resolve_ip(ip): + # Resolve the IP + try: + hostname, aliaslist, ipaddrlist = socket.gethostbyaddr(ip) + lazy_raise = None + except socket.gaierror as err: + # Get Addr Info (GAI) Error + # + # When we are connecting to a node through a ssh proxy jump + # the node IPs/Hostnames are not available until we reach + # the first node inside the cluster + # + hostname = ip + aliaslist = [] + ipaddrlist = [] + lazy_raise = err + + return hostname, aliaslist, ipaddrlist, lazy_raise + + +def resolve_addresses(nodes): + # Note: it is possible for self to be none + # if we are running milabench on a node that is not part of the system + # in that case it should still work; the local is then going to + # ssh into the main node which will dispatch the work to the other nodes + self = None + lazy_raise = None + ip_list = get_remote_ip() + + for node in nodes: + hostname, aliaslist, ipaddrlist, lazy_raise = _resolve_ip(node["ip"]) + + node["hostname"] = hostname + node["aliaslist"] = aliaslist + node["ipaddrlist"] = ipaddrlist + + if hostname.endswith(".server.mila.quebec.server.mila.quebec"): + print() + print("Hostname was extra long for no reason") + print(hostname, socket.gethostname()) + print() + + # why is this happening + hostname = hostname[: -len(".server.mila.quebec")] + + is_local = ( + ("127.0.0.1" in ipaddrlist) + or (hostname in ("localhost", socket.gethostname())) + or len(ip_list.intersection(ipaddrlist)) > 0 + ) + node["local"] = is_local + + if is_local: + self = node + node["ipaddrlist"] = list(ip_list) + + # if self is node we might be outisde the cluster + # which explains why we could not resolve the IP of the nodes + if self is not None and lazy_raise: + raise RuntimeError("Could not resolve node ip") from lazy_raise + + return self + + +def get_gpu_capacity(strict=False): + try: + capacity = 0 + + for k, v in get_gpu_info()["gpus"].items(): + capacity = min(v["memory"]["total"], capacity) + + return capacity + except: + print("GPU not available, defaulting to 0 MiB") + if strict: + raise + return 0 + + +def build_system_config(config_file, defaults=None, gpu=True): + """Load the system configuration, verify its validity and resolve ip addresses + + Notes + ----- + * node['local'] true when the code is executing on the machine directly + * node["main"] true when the machine is in charge of distributing the workload + """ + + if config_file is None: + config = {"system": {}} + else: + config_file = XPath(config_file).absolute() + with open(config_file) as cf: + config = yaml.safe_load(cf) + + if defaults: + config = merge(defaults, config) + + system = config.get("system", {}) + system_global.set(system) + + # capacity is only required if batch resizer is enabled + if (gpu or is_autoscale_enabled()) and not "gpu" not in system: + system["gpu"] = {"capacity": f"{int(get_gpu_capacity())} MiB"} + + if system.get("sshkey") is not None: + system["sshkey"] = str(XPath(system["sshkey"]).resolve()) + + check_node_config(system["nodes"]) + + self = resolve_addresses(system["nodes"]) + system["self"] = self + + return config diff --git a/scripts/article/run_batch_x_worker.sh b/scripts/article/run_batch_x_worker.sh new file mode 100644 index 000000000..66ab58d9f --- /dev/null +++ b/scripts/article/run_batch_x_worker.sh @@ -0,0 +1,112 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +WORKERS=("2" "4" "8" "16" "32") +MEMORY_CAPACITY=("4Go" "8Go" "16Go" "32Go" "64Go" "80Go") +DRY=0 +FINAL_OUTPUT="$HOME/batch_x_worker" + +export MILABENCH_SIZER_SAVE="$FINAL_OUTPUT/scaling.yaml" +mkdir -p $FINAL_OUTPUT + +# +# Install +# +# if [ "$DRY" -eq 0 ]; then +# export MILABENCH_PREPARE=1 +# source $SCRIPT_DIR/run_cuda.sh + +# # +# # Activate +# # +# source $MILABENCH_WORDIR/env/bin/activate +# fi + +export MILABENCH_GPU_ARCH=cuda +export MILABENCH_WORDIR="$(pwd)/$MILABENCH_GPU_ARCH" +export MILABENCH_BASE="$MILABENCH_WORDIR/results" +export MILABENCH_CONFIG="$MILABENCH_WORDIR/milabench/config/standard.yaml" +export MILABENCH_VENV="$MILABENCH_WORDIR/env" +export BENCHMARK_VENV="$MILABENCH_WORDIR/results/venv/torch" + +source $MILABENCH_WORDIR/env/bin/activate + + +maybe_run() { + local name=$1 + if [ "$DRY" -eq 1 ]; then + mkdir -p dry + echo $name + milabench matrix --base output --config config/standard.yaml > dry/$name.yaml + else + milabench prepare + milabench run --run-name $name + mv $MILABENCH_BASE/runs/* ~/batch_x_worker/ + fi +} + +# +# Default everything +# +#export MILABENCH_CPU_AUTO=0 +#export MILABENCH_SIZER_AUTO=0 +#maybe_run "wdef-cdef.{time}" + +# +# Auto everything +# +export MILABENCH_CPU_AUTO=1 +export MILABENCH_SIZER_AUTO=1 +export MILABENCH_SIZER_MULTIPLE=8 +maybe_run "wauto-m8-cauto.{time}" + +# +# Multiple of 8 +# +for CAPACITY in "${MEMORY_CAPACITY[@]}"; do + for WORKER in "${WORKERS[@]}"; do + export MILABENCH_CPU_AUTO=1 + export MILABENCH_CPU_N_WORKERS="$WORKER" + + export MILABENCH_SIZER_AUTO=1 + export MILABENCH_SIZER_MULTIPLE=8 + export MILABENCH_SIZER_CAPACITY="$CAPACITY" + + maybe_run "w$WORKER-m8-c$CAPACITY.{time}" + done +done + +# +# Multiple of 32 +# +for CAPACITY in "${MEMORY_CAPACITY[@]}"; do + for WORKER in "${WORKERS[@]}"; do + export MILABENCH_CPU_AUTO=1 + export MILABENCH_CPU_N_WORKERS="$WORKER" + + export MILABENCH_SIZER_AUTO=1 + export MILABENCH_SIZER_MULTIPLE=32 + export MILABENCH_SIZER_CAPACITY="$CAPACITY" + + maybe_run "w$WORKER-m32-c$CAPACITY.{time}" + done +done + +# +# Power of 2 +# +for CAPACITY in "${MEMORY_CAPACITY[@]}"; do + for WORKER in "${WORKERS[@]}"; do + export MILABENCH_CPU_AUTO=1 + export MILABENCH_CPU_N_WORKERS="$WORKER" + + export MILABENCH_SIZER_AUTO=1 + export MILABENCH_SIZER_MULTIPLE=0 + export MILABENCH_SIZER_POWER=2 + export MILABENCH_SIZER_CAPACITY="$CAPACITY" + + maybe_run "w$WORKER-p2-c$CAPACITY.{time}" + done +done + diff --git a/scripts/article/run_cuda.sh b/scripts/article/run_cuda.sh index 26e789da8..2f0427630 100644 --- a/scripts/article/run_cuda.sh +++ b/scripts/article/run_cuda.sh @@ -4,12 +4,11 @@ set -ex export MILABENCH_GPU_ARCH=cuda export MILABENCH_WORDIR="$(pwd)/$MILABENCH_GPU_ARCH" - export MILABENCH_BASE="$MILABENCH_WORDIR/results" export MILABENCH_CONFIG="$MILABENCH_WORDIR/milabench/config/standard.yaml" export MILABENCH_VENV="$MILABENCH_WORDIR/env" export BENCHMARK_VENV="$MILABENCH_WORDIR/results/venv/torch" - +export MILABENCH_PREPARE=0 install_prepare() { mkdir -p $MILABENCH_WORDIR @@ -20,7 +19,7 @@ install_prepare() { fi if [ ! -d "$MILABENCH_WORDIR/milabench" ]; then - git clone https://github.com/mila-iqia/milabench.git + git clone https://github.com/mila-iqia/milabench.git -b worker_x_batch fi . $MILABENCH_WORDIR/env/bin/activate @@ -58,12 +57,15 @@ else . $MILABENCH_WORDIR/env/bin/activate fi -cd $MILABENCH_WORDIR -# -# Run the benchmakrs -milabench run "$@" +if [ "$MILABENCH_PREPARE" -eq 0 ]; then + cd $MILABENCH_WORDIR + + # + # Run the benchmakrs + milabench run "$@" -# -# Display report -milabench report --runs $MILABENCH_WORDIR/results/runs + # + # Display report + milabench report --runs $MILABENCH_WORDIR/results/runs +fi \ No newline at end of file diff --git a/scripts/batch/run_cuda.sh b/scripts/batch/run_cuda.sh deleted file mode 100644 index c5a4e2ec4..000000000 --- a/scripts/batch/run_cuda.sh +++ /dev/null @@ -1,14 +0,0 @@ - - - -export MILABENCH_GPU_ARCH=cuda -export MILABENCH_WORDIR="$(pwd)/$MILABENCH_GPU_ARCH" -export MILABENCH_CONFIG="$MILABENCH_WORDIR/milabench/config/resnet50.yaml" - -CUDA_VISIBLE_DEVICES=0 bash $MILABENCH_WORDIR/milabench/scripts/article/run_cuda.sh --config $MILABENCH_CONFIG --select resnet - -CUDA_VISIBLE_DEVICES=0,1 bash $MILABENCH_WORDIR/milabench/scripts/article/run_cuda.sh --config $MILABENCH_CONFIG --select resnet - -CUDA_VISIBLE_DEVICES=0,1,2,3 bash $MILABENCH_WORDIR/milabench/scripts/article/run_cuda.sh --config $MILABENCH_CONFIG --select resnet - -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash $MILABENCH_WORDIR/milabench/scripts/article/run_cuda.sh --config $MILABENCH_CONFIG --select resnet diff --git a/tests/test_scaler.py b/tests/test_scaler.py index 5d8d561b4..7df968e35 100644 --- a/tests/test_scaler.py +++ b/tests/test_scaler.py @@ -74,7 +74,7 @@ def fakeexec(pack): def test_scaler_enabled(multipack, config): - from milabench.config import system_global + from milabench.system import system_global import contextvars ctx = contextvars.copy_context()