diff --git a/benchmarks/diffusion/main.py b/benchmarks/diffusion/main.py old mode 100644 new mode 100755 index 2b4fe9bfd..09513f606 --- a/benchmarks/diffusion/main.py +++ b/benchmarks/diffusion/main.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python + from dataclasses import dataclass from accelerate import Accelerator diff --git a/config/base.yaml b/config/base.yaml index 438485de9..11d50c018 100644 --- a/config/base.yaml +++ b/config/base.yaml @@ -394,9 +394,8 @@ _diffusion: inherits: _defaults definition: ../benchmarks/diffusion install_group: torch - plan: - method: njobs - n: 1 + tags: + - diffusion argv: --num_epochs: 5 @@ -408,10 +407,13 @@ diffusion-single: inherits: _diffusion num_machines: 1 plan: - method: njobs + method: per_gpu diffusion-gpus: inherits: _diffusion + plan: + method: njobs + n: 1 num_machines: 1 diffusion-nodes: @@ -426,6 +428,8 @@ _lightning: inherits: _defaults definition: ../benchmarks/lightning install_group: torch + tags: + - lightning argv: --epochs: 10 --num-workers: "auto({n_worker}, 8)" @@ -452,6 +456,9 @@ _dinov2: definition: ../benchmarks/dinov2 install_group: torch num_machines: 1 + tags: + - image + - transformer plan: method: njobs n: 1 @@ -505,7 +512,9 @@ _llm: voir: options: stop: 30 - + tags: + - nlp + - llm max_duration: 1200 num_machines: 1 inherits: _defaults @@ -517,7 +526,6 @@ llm-lora-single: inherits: _llm plan: method: per_gpu - argv: "{milabench_code}/recipes/lora_finetune_single_device.py": true --config: "{milabench_code}/configs/llama3_8B_lora_single_device.yaml" @@ -664,6 +672,10 @@ ppo: _geo_gnn: inherits: _defaults + tags: + - graph + # FIXME: torch cluster is lagging behind pytorch + # we are forced to use torch==2.3 instead of torch==2.4 install_group: gnn group: geo_gnn definition: ../benchmarks/geo_gnn @@ -682,6 +694,8 @@ recursiongfn: definition: ../benchmarks/recursiongfn install_group: gnn group: recursiongfn_gnn + tags: + - graph plan: method: per_gpu @@ -698,7 +712,8 @@ torchatari: install_group: torch 
plan: method: per_gpu - + tags: + - rl argv: --num-minibatches: 16 --update-epochs: 4 diff --git a/milabench/_version.py b/milabench/_version.py index 70125045a..0202d13c4 100644 --- a/milabench/_version.py +++ b/milabench/_version.py @@ -3,3 +3,4 @@ __tag__ = "v0.1.0-28-g8069946" __commit__ = "8069946d331fb92090057d7eedd598515249521d" __date__ = "2024-08-01 12:39:13 -0400" + diff --git a/milabench/scripts/vcs.py b/milabench/scripts/vcs.py index 0f895f886..54bc7638d 100644 --- a/milabench/scripts/vcs.py +++ b/milabench/scripts/vcs.py @@ -26,10 +26,16 @@ def retrieve_git_versions(tag="", commit="", date=""): } +def version_file(): + return os.path.join(ROOT, "milabench", "_version.py") + def read_previous(): info = ["", "", ""] - - with open(os.path.join(ROOT, "milabench", "_version.py"), "r") as file: + + if not os.path.exists(version_file()): + return info + + with open(version_file(), "r") as file: for line in file.readlines(): if "tag" in line: _, v = line.split("=") @@ -49,7 +55,7 @@ def read_previous(): def update_version_file(): version_info = retrieve_git_versions(*read_previous()) - with open(os.path.join(ROOT, "milabench", "_version.py"), "w") as file: + with open(version_file(), "w") as file: file.write('"""') file.write("This file is generated, do not modify") file.write('"""\n\n')