From 06d9da4ec84fa21f5cfd0545267eb999716b8494 Mon Sep 17 00:00:00 2001
From: "pierre.delaunay"
Date: Wed, 17 Jul 2024 17:27:00 -0400
Subject: [PATCH] Diffusion benchmark

---
 benchmarks/accelerate_opt/benchfile.py |   2 +-
 benchmarks/diffusion/Makefile          |  14 ++
 benchmarks/diffusion/README.md         |   4 +
 benchmarks/diffusion/benchfile.py      |  38 +++++
 benchmarks/diffusion/dev.yaml          |  32 ++++
 benchmarks/diffusion/main.py           | 226 +++++++++++++++++++++++++
 benchmarks/diffusion/prepare.py        |  16 ++
 benchmarks/diffusion/requirements.in   |   6 +
 benchmarks/lightning/dev.yaml          |   1 -
 benchmate/benchmate/metrics.py         |   3 +-
 benchmate/benchmate/observer.py        |  15 +-
 milabench/cli/run.py                   |  11 +-
 milabench/commands/__init__.py         | 105 +++++++++---
 milabench/commands/executors.py        |   3 +-
 milabench/pack.py                      |   8 +-
 milabench/scripts/activator            |   7 +
 16 files changed, 455 insertions(+), 36 deletions(-)
 create mode 100644 benchmarks/diffusion/Makefile
 create mode 100644 benchmarks/diffusion/README.md
 create mode 100644 benchmarks/diffusion/benchfile.py
 create mode 100644 benchmarks/diffusion/dev.yaml
 create mode 100644 benchmarks/diffusion/main.py
 create mode 100755 benchmarks/diffusion/prepare.py
 create mode 100644 benchmarks/diffusion/requirements.in
 create mode 100755 milabench/scripts/activator

diff --git a/benchmarks/accelerate_opt/benchfile.py b/benchmarks/accelerate_opt/benchfile.py
index 72607ff84..c0c00a381 100644
--- a/benchmarks/accelerate_opt/benchfile.py
+++ b/benchmarks/accelerate_opt/benchfile.py
@@ -35,7 +35,7 @@ def build_prepare_plan(self):
         )
 
     def build_run_plan(self):
-        # FIXME: use ForeachNode
+        # FIXME: or AccelerateAllNodes
         plans = []
 
         max_num = self.config["num_machines"]
diff --git a/benchmarks/diffusion/Makefile b/benchmarks/diffusion/Makefile
new file mode 100644
index 000000000..680a2a265
--- /dev/null
+++ b/benchmarks/diffusion/Makefile
@@ -0,0 +1,14 @@
+
+install:
+	milabench install --config dev.yaml --base base --force
+
+tests:
+	milabench install --config dev.yaml --base base
+	milabench prepare --config dev.yaml --base base
+	milabench run --config dev.yaml --base base
+
+gpus:
+	milabench run --config dev.yaml --base base --select diffusion-gpus
+
+nodes:
+	milabench run --config dev.yaml --base base --select diffusion-nodes
\ No newline at end of file
diff --git a/benchmarks/diffusion/README.md b/benchmarks/diffusion/README.md
new file mode 100644
index 000000000..367fff836
--- /dev/null
+++ b/benchmarks/diffusion/README.md
@@ -0,0 +1,4 @@
+
+# Diffusion
+
+Trains a DDPM UNet (HuggingFace `diffusers`) on the `huggan/smithsonian_butterflies_subset` dataset, driven by `accelerate`, covering single-GPU (`diffusion`), multi-GPU (`diffusion-gpus`) and multi-node (`diffusion-nodes`) configurations.
diff --git a/benchmarks/diffusion/benchfile.py b/benchmarks/diffusion/benchfile.py
new file mode 100644
index 000000000..6940c72c0
--- /dev/null
+++ b/benchmarks/diffusion/benchfile.py
@@ -0,0 +1,38 @@
+from milabench.pack import Package
+
+from milabench.commands import AccelerateAllNodes
+
+
+class Diffusion(Package):
+    # Requirements file installed by install(). It can be empty or absent.
+    base_requirements = "requirements.in"
+
+    # The preparation script called by prepare(). It must be executable,
+    # but it can be any type of script. It can be empty or absent.
+    prepare_script = "prepare.py"
+
+    # The main script called by run(). It must be a Python file. It has to
+    # be present.
+    main_script = "main.py"
+
+    # You can remove the functions below if you don't need to modify them.
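+    # A sketch of the assumed lifecycle: `milabench install` pip-installs
+    # base_requirements, `milabench prepare` runs prepare_script, and
+    # `milabench run` executes main_script through the plan built by
+    # build_run_plan() below.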
+
+    def make_env(self):
+        return {
+            **super().make_env(),
+            "OMP_NUM_THREADS": str(self.config.get("cpus_per_gpu", 8)),
+        }
+
+    async def install(self):
+        await super().install()  # super() call installs the requirements
+
+    async def prepare(self):
+        await super().prepare()  # super() call executes prepare_script
+
+    def build_run_plan(self):
+        plan = super().build_run_plan()
+
+        return AccelerateAllNodes(plan, use_stdout=True)
+
+
+__pack__ = Diffusion
diff --git a/benchmarks/diffusion/dev.yaml b/benchmarks/diffusion/dev.yaml
new file mode 100644
index 000000000..4fb8ab576
--- /dev/null
+++ b/benchmarks/diffusion/dev.yaml
@@ -0,0 +1,32 @@
+
+diffusion:
+  inherits: _defaults
+  definition: .
+  install-variant: unpinned
+  install_group: torch
+  num_machines: 1
+  plan:
+    method: per_gpu
+
+diffusion-gpus:
+  inherits: _defaults
+  definition: .
+  install-variant: unpinned
+  install_group: torch
+  num_machines: 1
+  plan:
+    method: njobs
+    n: 1
+
+diffusion-nodes:
+  inherits: _defaults
+  num_machines: 2
+  definition: .
+  install-variant: unpinned
+  install_group: torch
+  plan:
+    method: njobs
+    n: 1
+
+  requires_capabilities:
+    - "len(nodes) >= ${num_machines}"
diff --git a/benchmarks/diffusion/main.py b/benchmarks/diffusion/main.py
new file mode 100644
index 000000000..85fd4af87
--- /dev/null
+++ b/benchmarks/diffusion/main.py
@@ -0,0 +1,226 @@
+from pathlib import Path
+import os
+from dataclasses import dataclass
+
+import torch
+from torchvision import transforms
+import torch.nn.functional as F
+from diffusers import DDPMScheduler
+
+from diffusers import DDPMPipeline
+from accelerate import Accelerator
+from tqdm.auto import tqdm
+from datasets import load_dataset
+from diffusers import UNet2DModel
+from diffusers.optimization import get_cosine_schedule_with_warmup
+
+# from huggingface_hub import HfFolder, Repository, whoami
+
+@dataclass
+class TrainingConfig:
+    image_size = 128  # the generated image resolution
+    train_batch_size = 16
+    eval_batch_size = 16  # how many images to sample during evaluation
+    num_epochs = 50
+    gradient_accumulation_steps = 1
+    dataset_name: str = "huggan/smithsonian_butterflies_subset"
+    learning_rate = 1e-4
+    lr_warmup_steps = 500
+    save_image_epochs = 10
+    save_model_epochs = 30
+    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
+    output_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hub
+    push_to_hub = False  # whether to upload the saved model to the HF Hub
+    hub_private_repo = False
+    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
+    seed = 0
+
+
+def build_dataset(config):
+    dataset = load_dataset(config.dataset_name, split="train")
+
+    preprocess = transforms.Compose(
+        [
+            transforms.Resize((config.image_size, config.image_size)),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ]
+    )
+
+    def transform(examples):
+        images = [preprocess(image.convert("RGB")) for image in examples["image"]]
+        return {"images": images}
+
+    dataset.set_transform(transform)
+
+    loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=config.train_batch_size,
+        shuffle=True
+    )
+    return loader
+
+
+def build_model(config):
+    model = UNet2DModel(
+        sample_size=config.image_size,  # the target image resolution
+        in_channels=3,  # the number of input channels, 3 for RGB images
+        out_channels=3,  # the number of output channels
+        layers_per_block=2,  # how many ResNet layers to use per UNet block
+        block_out_channels=(128, 128, 256, 256,
+            512, 512),  # the number of output channels for each UNet block
+        down_block_types=(
+            "DownBlock2D",  # a regular ResNet downsampling block
+            "DownBlock2D",
+            "DownBlock2D",
+            "DownBlock2D",
+            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
+            "DownBlock2D",
+        ),
+        up_block_types=(
+            "UpBlock2D",  # a regular ResNet upsampling block
+            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
+            "UpBlock2D",
+            "UpBlock2D",
+            "UpBlock2D",
+            "UpBlock2D",
+        ),
+    )
+
+    return model
+
+
+def get_full_repo_name(model_id: str, organization: str = None, token: str = None):
+    if token is None:
+        token = HfFolder.get_token()
+    if organization is None:
+        username = whoami(token)["name"]
+        return f"{username}/{model_id}"
+    else:
+        return f"{organization}/{model_id}"
+
+
+def build_loss():
+    return F.mse_loss
+
+def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
+    from benchmate.observer import BenchObserver
+
+    def batch_size(x):
+        return x["images"].shape[0]
+
+    observer = BenchObserver(
+        earlystop=65,
+        batch_size_fn=lambda x: batch_size(x),
+        stdout=True,
+        raise_stop_program=True
+    )
+
+    # Initialize accelerator and tensorboard logging
+    accelerator = Accelerator(
+        mixed_precision=config.mixed_precision,
+        gradient_accumulation_steps=config.gradient_accumulation_steps,
+        # log_with="tensorboard",
+        # project_dir=os.path.join(config.output_dir, "logs"),
+    )
+    if accelerator.is_main_process:
+        if False:
+            if config.push_to_hub:
+                repo_name = get_full_repo_name(Path(config.output_dir).name)
+                repo = Repository(config.output_dir, clone_from=repo_name)
+            elif config.output_dir is not None:
+                os.makedirs(config.output_dir, exist_ok=True)
+        accelerator.init_trackers("train_example")
+
+    # Prepare everything
+    # There is no specific order to remember, you just need to unpack the
+    # objects in the same order you gave them to the prepare method.
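+    # accelerate handles device placement and the distributed wrappers
+    # for every object passed through prepare().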
+    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        model, optimizer, train_dataloader, lr_scheduler
+    )
+
+    global_step = 0
+    criterion = build_loss()
+
+    # Now you train the model
+    for epoch in range(config.num_epochs):
+        # progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
+        # progress_bar.set_description(f"Epoch {epoch}")
+
+        for step, batch in enumerate(observer.iterate(train_dataloader)):
+            clean_images = batch["images"].to(model.device)
+            # Sample noise to add to the images
+            noise = torch.randn(clean_images.shape).to(clean_images.device)
+            bs = clean_images.shape[0]
+
+            # Sample a random timestep for each image
+            timesteps = torch.randint(
+                0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
+            ).long()
+
+            # Add noise to the clean images according to the noise magnitude at each timestep
+            # (this is the forward diffusion process)
+            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
+
+            with accelerator.accumulate(model):
+                # Predict the noise residual
+                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
+                loss = criterion(noise_pred, noise)
+                accelerator.backward(loss)
+                observer.record_loss(loss)
+
+                accelerator.clip_grad_norm_(model.parameters(), 1.0)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            # progress_bar.update(1)
+            global_step += 1
+
+        # After each epoch you optionally sample some demo images with evaluate() and save the model
+        if accelerator.is_main_process:
+            if False:
+                pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
+
+                if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
+                    evaluate(config, epoch, pipeline)
+
+                if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
+                    if config.push_to_hub:
+                        repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=True)
+                    else:
+                        pipeline.save_pretrained(config.output_dir)
+
+
+def build_optimizer(config, model):
+    return torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
+
+def main():
+    config = TrainingConfig()
+
+    model = build_model(config)
+    dataset = build_dataset(config)
+    optimizer = build_optimizer(config, model)
+    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
+
+    lr_scheduler = get_cosine_schedule_with_warmup(
+        optimizer=optimizer,
+        num_warmup_steps=config.lr_warmup_steps,
+        num_training_steps=(len(dataset) * config.num_epochs),
+    )
+
+    from benchmate.metrics import StopProgram
+
+    try:
+        train_loop(config, model, noise_scheduler, optimizer, dataset, lr_scheduler)
+
+    except StopProgram:
+        pass
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmarks/diffusion/prepare.py b/benchmarks/diffusion/prepare.py
new file mode 100755
index 000000000..32bd5901d
--- /dev/null
+++ b/benchmarks/diffusion/prepare.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+
+import os
+
+if __name__ == "__main__":
+    # If you need the whole configuration:
+    # config = json.loads(os.environ["MILABENCH_CONFIG"])
+
+    data_directory = os.environ["MILABENCH_DIR_DATA"]
+
+    # Download (or generate) the needed dataset(s). You are responsible
+    # for checking whether it has already been properly downloaded, and
+    # for doing nothing if it has been.
+    print("Hello I am doing some data stuff!")
+
+    # If there is nothing to download or generate, just delete this file.
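+
+    # A minimal sketch (an assumption, not wired in yet): warm the
+    # HuggingFace cache so main.py does not download at run time.
+    #
+    #     os.environ.setdefault("HF_HOME", data_directory)
+    #     from datasets import load_dataset
+    #     load_dataset("huggan/smithsonian_butterflies_subset", split="train")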
diff --git a/benchmarks/diffusion/requirements.in b/benchmarks/diffusion/requirements.in
new file mode 100644
index 000000000..ee696fb44
--- /dev/null
+++ b/benchmarks/diffusion/requirements.in
@@ -0,0 +1,6 @@
+voir>=0.2.9,<0.3
+diffusers[torch]
+accelerate
+datasets
+tqdm
+torchvision
\ No newline at end of file
diff --git a/benchmarks/lightning/dev.yaml b/benchmarks/lightning/dev.yaml
index 962991d65..e59a4ebbb 100644
--- a/benchmarks/lightning/dev.yaml
+++ b/benchmarks/lightning/dev.yaml
@@ -17,7 +17,6 @@ lightning:
   num_machines: 1
   plan:
     method: per_gpu
-    n: 1
 
 lightning-gpus:
   inherits: _lightning
diff --git a/benchmate/benchmate/metrics.py b/benchmate/benchmate/metrics.py
index fc0056bca..360291b19 100644
--- a/benchmate/benchmate/metrics.py
+++ b/benchmate/benchmate/metrics.py
@@ -73,7 +73,8 @@ def materialize(self, *args, **kwargs):
 
     def push(self, pusher):
         """Iterate through data and push metrics."""
         for args, kwargs in self.delayed:
-            pusher(**self.materialize(*args, **kwargs))
+            data = self.materialize(*args, **kwargs)
+            pusher(**data)
         self.delayed = []
 
diff --git a/benchmate/benchmate/observer.py b/benchmate/benchmate/observer.py
index 78cb3af14..92f587896 100644
--- a/benchmate/benchmate/observer.py
+++ b/benchmate/benchmate/observer.py
@@ -42,7 +42,7 @@ def __init__(
         self.task = "train"
         self.rank = rank
         self.losses = LazyLossPusher(self.task)
-
+        self.instance = None
         self.pusher = give_push()
         if self.stdout:
             self.pusher = sumggle_push()
@@ -52,7 +52,7 @@ def on_iterator_stop_iterator(self):
         self.losses.push(self.pusher)
 
     def record_loss(self, loss):
-        if self.rank is None or self.rank == 1:
+        if self.rank is None or self.rank == 0:
             self.losses.record(loss)
         return loss
 
@@ -74,11 +74,15 @@ def iterate(self, iterator):
     def loader(self, loader):
         """Wrap a dataloader or an iterable, which enables accurate
        measuring of time spent in the loop's body
         """
+        if self.instance:
+            return self.instance
+
         self.wrapped = TimedIterator(
             loader, *self.args, rank=self.rank, push=self.pusher, **self.kwargs
         )
         self.wrapped.task = self.task
         self.wrapped.on_iterator_stop_iterator = self.on_iterator_stop_iterator
+        self.instance = self.wrapped
         return self.wrapped
 
     def criterion(self, criterion):
@@ -110,5 +114,12 @@ def new_step(*args, **kwargs):
             original(*args, **kwargs)
             self.optimizer_step_callback()
 
+        # PyTorch's LRScheduler keeps a weak reference to the optimizer's
+        # bound step method:
+        #     instance_ref = weakref.ref(method.__self__)
+        # where method == self.optimizer.step, so the patched step must
+        # expose the same __self__ / __func__ attributes.
+        new_step.__self__ = optimizer.step.__self__
+        new_step.__func__ = optimizer.step.__func__
         optimizer.step = new_step
         return optimizer
diff --git a/milabench/cli/run.py b/milabench/cli/run.py
index c59e7e6df..f5e75b702 100644
--- a/milabench/cli/run.py
+++ b/milabench/cli/run.py
@@ -63,6 +63,15 @@ def arguments():
 
     return Arguments(run_name, repeat, fulltrace, report, dash, noterm, validations)
 
+
+def _fetch_arch(mp):
+    try:
+        arch = next(iter(mp.packs.values())).config["system"]["arch"]
+        return arch
+    except StopIteration:
+        print("no selected bench")
+        return None
+
+
 @tooled
 def cli_run(args=None):
     """Run the benchmarks."""
@@ -78,7 +87,7 @@ def cli_run(args=None):
     }.get(args.dash, None)
 
     mp = get_multipack(run_name=args.run_name)
-    arch = next(iter(mp.packs.values())).config["system"]["arch"]
+    arch = _fetch_arch(mp)
 
     # Initialize the backend here so we can retrieve GPU stats
     init_arch(arch)
diff --git a/milabench/commands/__init__.py b/milabench/commands/__init__.py
index 828de68eb..9acc07bf9 100644
--- a/milabench/commands/__init__.py
+++ b/milabench/commands/__init__.py
@@ -520,10 +520,31 @@ def _argv(self, **kwargs):
 
 
 class ForeachNode(ListCommand):
-    def __init__(self, executor: Command, use_stdout=True, **kwargs) -> None:
+    def __init__(self, executor: Command, **kwargs) -> None:
         super().__init__(None, **kwargs)
+        self.options = kwargs
         self.executor = executor
-        self.use_stdout = use_stdout
+
+    def make_new_node_pack(self, rank, node, base) -> "BasePackage":
+        """Make a new environment/config for the run"""
+        config = base.pack.config
+        tags = [*config["tag"], node["name"]]
+
+        # Workers do not send training data
+        # tag it as such so validation can ignore this pack
+        if rank != 0:
+            tags.append("nolog")
+
+        run = clone_with(config, {"tag": tags})
+        return base.pack.copy(run)
+
+    def make_new_node_executor(self, rank, node, base):
+        """Make a new environment and create a new executor for the node"""
+        pack = self.make_new_node_pack(rank, node, base)
+        return base.copy(pack)
+
+    def single_node(self):
+        return self.executor
 
     @property
     def executors(self):
@@ -538,34 +559,24 @@ def executors(self):
 
         # useless in single node setups
         if len(self.nodes) == 1 or max_num == 1:
-            return [self.executor]
+            return [self.single_node()]
 
-        def new_executor(base, cfg):
-            nonlocal config
-            run = clone_with(config, cfg)
-            new_pack = self.executor.pack.copy(run)
-            return base.copy(new_pack)
-
         for rank, node in enumerate(self.nodes):
             options = dict()
 
-            tags = [*self.executor.pack.config["tag"], node["name"]]
+            # Only rank 0 keeps the caller's options (setsid, stdout capture)
             if rank == 0:
                 options = dict(
                     setsid=True,
-                    use_stdout=self.use_stdout,
+                    **self.options
                 )
-            else:
-                # Workers do not send training data
-                # tag it as such so validation can ignore this pack
-                tags.append("nolog")
 
             worker = SSHCommand(
                 host=node["ip"],
                 user=node["user"],
                 key=key,
                 port=node.get("port", 22),
-                executor=new_executor(self.executor, {"tag": tags}),
+                executor=self.make_new_node_executor(rank, node, self.executor),
                 **options
             )
             executors.append(worker)
@@ -773,6 +784,43 @@ def _argv(self, **_) -> List:
         return [f"{self.pack.dirs.code / 'activator'}", f"{self.pack.dirs.venv}"]
 
 
+
+class AccelerateAllNodes(ForeachNode):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def single_node(self):
+        ngpu = len(self.executor.pack.config.get("devices", []))
+
+        # Multi GPU
+        if ngpu > 1:
+            return AccelerateLaunchCommand(self.executor, rank=0, **self.options)
+
+        # Single GPU
+        return self.executor
+
+    def make_new_node_executor(self, rank, node, base):
+        config = base.pack.config
+
+        pack = self.make_new_node_pack(rank, node, base)
+
+        return DockerRunCommand(
+            AccelerateLaunchCommand(pack, rank=rank),
+            config["system"].get("docker_image"),
+        )
+
+
+def activator_script():
+    """Script that activates the venv just before executing a script.
+
+    Useful for commands that SSH somewhere and need to execute a command
+    inside a particular venv.
+    """
+
+    path = XPath(__file__).parent.parent / "scripts" / "activator"
+    assert path.exists()
+    return str(path)
+
+
 # Accelerate
 class AccelerateLaunchCommand(SingleCmdCommand):
     """Execute a `pack.BasePackage` with Accelerate
@@ -801,28 +849,33 @@ def _argv(self, **_) -> List:
 
         num_machines = max(1, len(nodes) + 1)
 
-        ngpu = len(get_gpu_info()["gpus"].values())
+        # Can't do that: this run may be constrained to a subset of the GPUs
+        # ngpu = len(get_gpu_info()["gpus"].values())
+
+        ngpu = len(self.pack.config["devices"])
         nproc = ngpu * num_machines
         assert nproc > 0, f"nproc: {nproc} num_machines: {num_machines} ngpu: {ngpu}"
 
-        deepspeed_argv = (
-            [
+
+        if self.pack.config.get("use_deepspeed", False):
+            deepspeed_argv = [
                 "--use_deepspeed",
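+                # (ZeRO stage 2 shards optimizer state and gradients across ranks)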
"--deepspeed_multinode_launcher=standard", "--zero_stage=2", ] - if self.pack.config["use_deepspeed"] - else ["--multi_gpu"] - ) - - cpu_per_process = self.pack.resolve_argument('--cpus_per_gpu') + elif ngpu > 1: + deepspeed_argv = ["--multi_gpu"] + else: + deepspeed_argv = [] + + cpu_per_process = self.pack.resolve_argument('--cpus_per_gpu', 4) return [ # -- Run the command in the right venv # This could be inside the SSH Command # but it would need to be repeated for Docker # could be its own Command like VenvCommand that execute code # inside a specifc venv - f"{self.pack.dirs.code / 'activator'}", + activator_script(), f"{self.pack.dirs.venv}", # -- "accelerate", @@ -832,7 +885,7 @@ def _argv(self, **_) -> List: f"--machine_rank={self.rank}", f"--num_machines={num_machines}", *deepspeed_argv, - f"--gradient_accumulation_steps={self.pack.config['gradient_accumulation_steps']}", + f"--gradient_accumulation_steps={self.pack.config.get('gradient_accumulation_steps', 1)}", f"--num_cpu_threads_per_process={cpu_per_process}", f"--main_process_ip={manager['ip']}", f"--main_process_port={manager['port']}", diff --git a/milabench/commands/executors.py b/milabench/commands/executors.py index fe03c33eb..1ac456639 100644 --- a/milabench/commands/executors.py +++ b/milabench/commands/executors.py @@ -10,7 +10,7 @@ from ..syslog import syslog -async def execute(pack, *args, cwd=None, env={}, external=False, **kwargs): +async def execute(pack, *args, cwd=None, env={}, external=False, use_stdout=False, **kwargs): """Run a command in the virtual environment. Unless specified otherwise, the command is run with @@ -35,6 +35,7 @@ async def execute(pack, *args, cwd=None, env={}, external=False, **kwargs): return await run( final_args, **kwargs, + use_stdout=use_stdout, info={"pack": pack}, env=exec_env, constructor=BenchLogEntry, diff --git a/milabench/pack.py b/milabench/pack.py index 5ad45ccff..8ace63628 100644 --- a/milabench/pack.py +++ b/milabench/pack.py @@ -506,10 +506,12 @@ def build_run_plan(self) -> "cmd.Command": pack = cmd.PackCommand(self, *self.argv, lazy=True) return cmd.VoirCommand(pack, cwd=main.parent) - def resolve_argument(self, name): + def resolve_argument(self, name, default): """Resolve as single placeholder argument""" - placeholder = str(self.config["argv"][name]) - return self.resolve_placeholder(placeholder) + placeholder = self.config.get("argv", {}).get(name) + if placeholder: + return self.resolve_placeholder(placeholder) + return default def resolve_placeholder(self, placeholder): """Resolve as single placeholder argument diff --git a/milabench/scripts/activator b/milabench/scripts/activator new file mode 100755 index 000000000..083c28cb1 --- /dev/null +++ b/milabench/scripts/activator @@ -0,0 +1,7 @@ +#!/bin/bash + +venv="$1" +shift + +source "$venv"/bin/activate +exec "$@"