diff --git a/README.md b/README.md index 5531a1253..526398938 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ It will include all of the necessary data docker run -it --rm --ipc=host --gpus=all \ -v $(pwd)/results:/milabench/envs/runs \ $MILABENCH_IMAGE \ - milabench run + bash -c "milabench prepare && milabench run" ================= Benchmark results diff --git a/benchmarks/geo_gnn/requirements-pre.in b/benchmarks/geo_gnn/requirements-pre.in index 07d44c05a..08ed5eeb4 100644 --- a/benchmarks/geo_gnn/requirements-pre.in +++ b/benchmarks/geo_gnn/requirements-pre.in @@ -1 +1 @@ -torch<2.4 \ No newline at end of file +torch \ No newline at end of file diff --git a/benchmarks/geo_gnn/requirements.in b/benchmarks/geo_gnn/requirements.in index 4887d68ae..6fbdd7dea 100644 --- a/benchmarks/geo_gnn/requirements.in +++ b/benchmarks/geo_gnn/requirements.in @@ -1,6 +1,3 @@ - ---find-links https://data.pyg.org/whl/torch-2.3.0+cu121.html - voir>=0.2.17,<0.3 torch-geometric torch-cluster diff --git a/benchmarks/llm/dev.yaml b/benchmarks/llm/dev.yaml index 44386f209..e965769b1 100644 --- a/benchmarks/llm/dev.yaml +++ b/benchmarks/llm/dev.yaml @@ -13,6 +13,27 @@ _llm: method: per_gpu +llm-rlhf-single: + inherits: _llm + definition: . + install-variant: unpinned + plan: + method: per_gpu + + argv: + "{milabench_code}/recipes/lora_finetune_single_device.py": true + --config: "{milabench_code}/configs/llama3_8B_lora_single_device.yaml" + epochs=1: true + output_dir={milabench_extra}/output: true + tokenizer.path={milabench_data}/llama3_8B/original/tokenizer.model: true + checkpointer.checkpoint_dir={milabench_data}/llama3_8B/original: true + checkpointer.output_dir={milabench_data}/llama3_8B/: true + metric_logger.log_dir={milabench_extra}/metrics: true + repo_id="meta-llama/Meta-Llama-3.1-8B": true + batch_size=8: true + gradient_accumulation_steps=8: true + + llm-lora-single: inherits: _llm definition: . diff --git a/benchmarks/llm/recipes/ppo_full_finetune_single_device.py b/benchmarks/llm/recipes/ppo_full_finetune_single_device.py new file mode 100644 index 000000000..8ee77c06a --- /dev/null +++ b/benchmarks/llm/recipes/ppo_full_finetune_single_device.py @@ -0,0 +1,1084 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +import os +import sys +from functools import partial +from itertools import chain +from typing import Any, Dict, List, Optional, Tuple +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig +from torch import nn +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, utils +from torchtune.datasets import ConcatDataset +from torchtune.modules import rlhf +from torchtune.modules.rlhf import PPOStats, Trajectory +from torchtune.recipe_interfaces import FTRecipeInterface +from tqdm import tqdm + + +log = utils.get_logger("DEBUG") + + +class PPOFullFinetuneRecipeSingleDevice(FTRecipeInterface): + """ + Full finetuning recipe for RLHF with PPO for dense transformer-based LLMs such as LLama2. This recipe is optimized + for single GPU training. Training on CPU is not supported. 
+
+    This implementation is based on `Learning to summarize from human feedback <https://arxiv.org/abs/2009.01325>`_.
+    """
+
+    def __init__(self, cfg: DictConfig) -> None:
+
+        self._device = utils.get_device(device=cfg.device)
+        self._dtype = utils.get_dtype(cfg.dtype, device=self._device)
+
+        # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor
+        # enabled necessary features such as gradient scaling.
+        if self._dtype == torch.float16:
+            raise RuntimeError(
+                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
+            )
+
+        # logging attributes
+        self._output_dir = cfg.output_dir
+        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
+        self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
+
+        # These are public properties which are updated by the checkpoint loader
+        # when ``resume_from_checkpoint`` is `True` or validated in tests
+        self.seed = utils.set_seed(seed=cfg.seed)
+        # manually setting up a generator for the recipe
+        self._rng = torch.Generator(self._device).manual_seed(self.seed)
+        self._total_steps = 0
+        self._steps_run = 0
+        self._total_epochs = 0
+        self._epochs_run = 0
+        self.global_step = 0
+
+        # Training cfg
+        self._resume_from_checkpoint = cfg.resume_from_checkpoint
+        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
+
+    def setup(self, cfg: DictConfig) -> None:
+        """
+        Sets up the recipe state correctly. This includes setting recipe attributes based
+        on the ``resume_from_checkpoint`` flag.
+        """
+        self._metric_logger = config.instantiate(cfg.metric_logger)
+
+        # log config with parameter override
+        self._metric_logger.log_config(cfg)
+
+        # setup checkpointers
+        (
+            self._policy_checkpointer,
+            ref_policy_checkpointer,
+            self._value_checkpointer,
+            reward_checkpointer,
+        ) = self._setup_checkpointers(
+            cfg.checkpointer,
+            cfg.ref_policy_checkpointer,
+            cfg.value_checkpointer,
+            cfg.reward_checkpointer,
+        )
+
+        # load policy checkpoints
+        policy_model_checkpoint_dict = self._policy_checkpointer.load_checkpoint()
+        ref_policy_state_dict = ref_policy_checkpointer.load_checkpoint()
+
+        # load reward and value model checkpoints
+        value_model_checkpoint_dict = self._value_checkpointer.load_checkpoint()
+        reward_model_state_dict = reward_checkpointer.load_checkpoint()
+
+        # update recipe state
+        # ``_setup_model`` handles initialization and loading the state dict. This method
+        # should be called before ``_setup_optimizer`` since transforming the optimizer
+        # state dict requires the model
+        self._model_compile = cfg.compile
+        self._optimizer_in_bwd = cfg.optimizer_in_bwd
+        (
+            self._policy_model,
+            self._value_model,
+            self._reward_model,
+            self._ref_policy_model,
+        ) = self._setup_model(
+            cfg_model=cfg.policy_model,
+            cfg_reward_value_model=cfg.reward_and_value_model,
+            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
+            compile_model=self._model_compile,
+            policy_state_dict=policy_model_checkpoint_dict[utils.MODEL_KEY],
+            ref_policy_state_dict=ref_policy_state_dict[utils.MODEL_KEY],
+            value_model_state_dict=value_model_checkpoint_dict[utils.MODEL_KEY],
+            reward_model_state_dict=reward_model_state_dict[utils.MODEL_KEY],
+        )
+
+        # setup tokenizer
+        self._tokenizer = config.instantiate(cfg.tokenizer)
+        log.info("Tokenizer is initialized from file.")
+
+        # _setup_optimizer should take in ckpt_dict only if training is resumed from
+        # checkpoint. Transforming the opt state dict is handled by this method
+        self._optimizer = self._setup_optimizer(
+            cfg_optimizer=cfg.optimizer,
+            optimizer_in_bwd=cfg.optimizer_in_bwd,
+            opt_state_dict=(
+                policy_model_checkpoint_dict[utils.OPT_KEY]
+                if self._resume_from_checkpoint
+                else None
+            ),
+        )
+
+        self._loss_fn = config.instantiate(cfg.loss)
+        log.info("Loss is initialized.")
+
+        # sampler and dataloader depend on the tokenizer and should be
+        # set up after it is initialized
+        self._sampler, self._dataloader = self._setup_data(
+            cfg_dataset=cfg.dataset,
+            shuffle=cfg.shuffle,
+            batch_size=cfg.batch_size,
+        )
+
+        self._setup_training_parameters(cfg)
+        self._setup_training_hyperparameters(cfg)
+
+        if self._resume_from_checkpoint:
+            self._update_recipe_state(policy_model_checkpoint_dict)
+
+        # one "step" is a single gradient update over a minibatch of trajectories
+        self.global_step = (
+            self._steps_run
+            * self._ppo_epochs
+            * (self.batch_size // self._ppo_batch_size)
+        )
+
+    def _setup_training_hyperparameters(self, cfg) -> None:
+        """
+        Sets up the training hyperparameters for the recipe. This includes the GAE hyperparameters,
+        generation hyperparameters, reward masking hyperparameters, and stop token ids.
+        """
+
+        self._kl_coeff = cfg.kl_coeff
+        # GAE hyperparameters
+        self._gamma = cfg.gamma
+        self._lmbda = cfg.lmbda
+        self._whiten_rewards = cfg.whiten_rewards
+
+        # trajectory generation args
+        self._temperature = cfg.temperature
+        self._top_k = cfg.top_k
+        self._max_generated_tokens = cfg.max_generated_tokens
+
+        # reward masking args
+        self._min_response_length = cfg.min_response_length
+        self._penalise_no_eos = cfg.penalise_no_eos
+        self._reward_penalty = cfg.reward_penalty
+
+        # lots of hand holding for stop tokens
+        if cfg.get("stop_token_ids", False):
+            stop_token_ids = cfg.stop_token_ids
+            if self._tokenizer.eos_id not in stop_token_ids:
+                warn(
+                    f"tokenizer eos_id ({self._tokenizer.eos_id}) is not in stop_token_ids ({stop_token_ids})."
+                    "This may lead to unexpected behaviour."
+                )
+        else:
+            if not hasattr(self._tokenizer, "stop_tokens"):
+                warn(
+                    "No stop tokens defined in tokenizer, and no stop_token_ids provided. This may lead to unexpected behaviour."
+                )
+                stop_token_ids = []
+            else:
+                stop_token_ids = self._tokenizer.stop_tokens
+        self._stop_token_ids = torch.tensor(stop_token_ids, device=self._device)
+
+    def _setup_training_parameters(self, cfg: DictConfig) -> None:
+        """
+        Validates and sets up parameters used during training and for tracking training state,
+        batch sizes for model forward passes during trajectory generation, PPO minibatches, and
+        PPO microbatches for gradient accumulation.
+
+        Raises:
+            - ValueError if:
+                - batch_size is not divisible by forward_batch_size
+                - batch_size is not divisible by ppo_batch_size
+                - ppo_batch_size is not divisible by gradient_accumulation_steps
+                - num_steps is less than batch_size
+                - gradient_accumulation_steps > 1 and optimizer_in_bwd is True
+        """
+        self.batch_size = cfg.batch_size
+        self._forward_batch_size = cfg.forward_batch_size
+        self._ppo_epochs = cfg.ppo_epochs
+        self._ppo_batch_size = cfg.ppo_batch_size
+        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
+        self._ppo_backward_batch_size = (
+            cfg.ppo_batch_size // self._gradient_accumulation_steps
+        )
+
+        if self.batch_size % self._forward_batch_size != 0:
+            raise ValueError(
+                f"batch_size ({self.batch_size}) must be exactly divisible by "
+                f"forward_batch_size ({self._forward_batch_size})."
+ ) + if self.batch_size % self._ppo_batch_size != 0: + raise ValueError( + f"batch_size ({self.batch_size}) must be exactly divisible by " + f"ppo_batch_size ({self._ppo_batch_size})." + ) + if self._ppo_batch_size % self._gradient_accumulation_steps != 0: + raise ValueError( + f"ppo_batch_size ({self._ppo_batch_size}) must be exactly divisible " + f"by gradient_accumulation_steps ({self._gradient_accumulation_steps})." + ) + + if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd: + raise RuntimeError( + "Gradient accumulation is not supported with optimizer in bwd." + "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." + ) + + self._total_steps = cfg.num_steps // self.batch_size + batches_per_epoch = max( + 1, len(self._dataloader) + ) # when we only have a single batch in the dataset + + self._total_epochs = math.ceil(self._total_steps / batches_per_epoch) + if self._total_steps == 0: + raise ValueError( + f"num_steps {cfg.num_steps} must be greater than the batch size {self.batch_size}." + ) + if self._total_steps < len(self._dataloader): + warn( + f"There are fewer total steps ({self._total_steps}, (num_steps//batch_size) " + f"than there are batches ({len(self._dataloader)}) in the dataset. " + f"Training will stop after ({self._total_steps}) steps without saving intermediate checkpoints" + ) + if (self._total_steps > batches_per_epoch) and ( + self._total_steps % batches_per_epoch != 0 + ): + warn( + f"num_steps ({cfg.num_steps}) is not exactly divisible by " + f"the number of batches in the dataset ({batches_per_epoch}). " + f"Intermediate checkpoints will only be saved every {batches_per_epoch} steps." + ) + log.info( + f"Total steps to run: {self._total_steps}, Total epochs to run: {self._total_epochs}" + ) + + def _setup_checkpointers( + self, + policy_cfg: DictConfig, + ref_policy_cfg: DictConfig, + value_cfg: DictConfig, + reward_cfg: DictConfig, + ) -> Tuple[ + utils.Checkpointer, utils.Checkpointer, utils.Checkpointer, utils.Checkpointer + ]: + """ + Sets up checkpointers for policy, reference policy, value, and reward models. + Only the policy checkpoint handles recipe state for resuming from checkpoints. + """ + + if not self._resume_from_checkpoint: + assert policy_cfg.checkpoint_dir == ref_policy_cfg.checkpoint_dir, ( + "Policy and reference policy should be loaded from the same checkpoint directories" + f"at the start of training. Found: {policy_cfg.checkpoint_dir} and" + f"{ref_policy_cfg.checkpoint_dir}" + ) + assert policy_cfg.checkpoint_files == ref_policy_cfg.checkpoint_files, ( + "Policy and reference policy should be loaded from the same checkpoint files" + f"at the start of training. 
Found: {policy_cfg.checkpoint_files} and" + f"{ref_policy_cfg.checkpoint_files}" + ) + + policy_checkpointer = config.instantiate( + policy_cfg, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + + ref_policy_checkpointer = config.instantiate( + ref_policy_cfg, + resume_from_checkpoint=False, + ) + + value_checkpointer = config.instantiate( + value_cfg, + resume_from_checkpoint=False, + ) + + reward_checkpointer = config.instantiate( + reward_cfg, + resume_from_checkpoint=False, + ) + + return ( + policy_checkpointer, + ref_policy_checkpointer, + value_checkpointer, + reward_checkpointer, + ) + + def _setup_model( + self, + cfg_model: DictConfig, + cfg_reward_value_model: DictConfig, + enable_activation_checkpointing: bool, + compile_model: bool, + policy_state_dict: Dict[str, Any], + ref_policy_state_dict: Dict[str, Any], + value_model_state_dict: Dict[str, Any], + reward_model_state_dict: Dict[str, Any], + ) -> Tuple[nn.Module, nn.Module, nn.Module]: + """ + Sets up the policy model, reference policy model, reward model, and value model. + """ + + with utils.set_default_dtype(self._dtype), self._device: + policy_model = config.instantiate(cfg_model) + ref_policy_model = config.instantiate(cfg_model) + reward_model = config.instantiate(cfg_reward_value_model) + value_model = config.instantiate(cfg_reward_value_model) + + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + policy_model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + utils.set_activation_checkpointing( + value_model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + + policy_model.load_state_dict(policy_state_dict) + ref_policy_model.load_state_dict(ref_policy_state_dict) + + reward_missing, reward_unexpected = reward_model.load_state_dict( + reward_model_state_dict, strict=False + ) + value_missing, value_unexpected = value_model.load_state_dict( + value_model_state_dict, strict=False + ) + + # some extra validation for HF classifier checkpoints with a `score.bias` present + assert ( + reward_missing == value_missing == [] + ), f"Missing keys in reward ({reward_missing}) and value model ({value_missing}) state dicts." + + if reward_unexpected or value_unexpected: + # the only unexpected keys should be when pre-trained HF models were saved with + # bias=True in final classification layers. This happens when training a reward model with TRL. + assert ( + reward_unexpected == value_unexpected == ["output.bias"] + ), f"Unexpected keys in reward ({reward_unexpected}) and value model ({value_unexpected}) state dicts." + + # Validate models were loaded in with the expected dtype. + utils.validate_expected_param_dtype( + value_model.named_parameters(), dtype=self._dtype + ) + utils.validate_expected_param_dtype( + reward_model.named_parameters(), dtype=self._dtype + ) + utils.validate_expected_param_dtype( + value_model.named_parameters(), dtype=self._dtype + ) + utils.validate_expected_param_dtype( + ref_policy_model.named_parameters(), dtype=self._dtype + ) + + log.info(f"Models are initialized with precision {self._dtype}.") + + # disabling dropout if found - non-determinism leads to issues in e.g. comparing logprobs + # between ref policy and current policy + for module in policy_model.modules(): + if isinstance(module, torch.nn.Dropout): + warn( + f"Dropout found in {module}. This is likely to cause issues during training. Disabling." + ) + module.p = 0 + for module in value_model.modules(): + if isinstance(module, torch.nn.Dropout): + warn( + f"Dropout found in {module}. 
This is likely to cause issues during training. Disabling." + ) + module.p = 0 + + # disabling grad and dropout in reward and reference policy models + reward_model.eval() + ref_policy_model.eval() + + for p in reward_model.parameters(): + p.requires_grad = False + + for p in ref_policy_model.parameters(): + p.requires_grad = False + + # Compile model, if enabled. + if compile_model: + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + log.info("Compiling models with torch.compile...") + + policy_model.compile(backend=backend) + reward_model.compile(backend=backend) + ref_policy_model.compile(backend=backend) + value_model.compile(backend=backend) + + if self._device.type == "cuda": + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) + + return policy_model, value_model, reward_model, ref_policy_model + + def _setup_optimizer( + self, + cfg_optimizer: DictConfig, + optimizer_in_bwd: bool = False, + opt_state_dict: Optional[Dict[str, Any]] = None, + ) -> Optimizer: + + if optimizer_in_bwd: + # Maintain a dict of optims for every parameter. + optim_dict = { + p: config.instantiate(cfg_optimizer, [p]) + for p in chain( + self._policy_model.parameters(), self._value_model.parameters() + ) + } + # Register optimizer step hooks on the models to run optimizer in backward. + utils.register_optim_in_bwd_hooks( + model=self._policy_model, optim_dict=optim_dict + ) + utils.register_optim_in_bwd_hooks( + model=self._value_model, optim_dict=optim_dict + ) + # Create a wrapper for checkpoint save/load of optimizer states when running in backward. + self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper( + model=self._policy_model, optim_dict=optim_dict + ) + self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper( + model=self._value_model, optim_dict=optim_dict + ) + # Load optimizer states. If optimizer states are being restored in an optimizer in backward + # run, these need to have been saved with the same setting. Cannot restore from runs that did not + # use optimizer in backward. + if opt_state_dict is not None: + try: + self._optim_ckpt_wrapper.load_state_dict(opt_state_dict) + except BaseException as e: + raise RuntimeError( + "Failed loading in-backward optimizer checkpoints." + "Please make sure run being restored from was using in-backward optimizer." + ) from e + log.info("In-backward optimizers are set up.") + return None + else: + optimizer = config.instantiate( + cfg_optimizer, + chain(self._policy_model.parameters(), self._value_model.parameters()), + ) + if opt_state_dict: + optimizer.load_state_dict(opt_state_dict) + + log.info("Optimizer is initialized.") + return optimizer + + def _setup_data( + self, cfg_dataset: DictConfig, shuffle: bool, batch_size: int + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. 
+ """ + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + else: + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + + sampler = DistributedSampler( + ds, + num_replicas=1, + rank=0, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + sampler=sampler, + batch_size=batch_size, + collate_fn=partial( + rlhf.left_padded_collate, + padding_idx=self._tokenizer.pad_id, + ), + drop_last=True, + ) + + return sampler, dataloader + + def save_checkpoint( + self, epoch: int, is_intermediate_checkpoint: bool = False + ) -> None: + """ + Save state dict to file. The recipe save_checkpoint method is responsible for + correctly creating the checkpoint dict and passing to the checkpointer. + """ + policy_ckpt_dict = {utils.MODEL_KEY: self._policy_model.state_dict()} + value_ckpt_dict = {utils.MODEL_KEY: self._value_model.state_dict()} + + # if training is in-progress, checkpoint the optimizer state and rng state as well + if is_intermediate_checkpoint: + policy_ckpt_dict.update( + { + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self._epochs_run, + utils.TOTAL_EPOCHS_KEY: self._total_epochs, + utils.MAX_STEPS_KEY: self._total_steps, + utils.STEPS_KEY: self._steps_run, + utils.RNG_KEY: self._rng.get_state(), + } + ) + if not self._optimizer_in_bwd: + policy_ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict() + else: + policy_ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict() + + self._policy_checkpointer.save_checkpoint( + policy_ckpt_dict, + epoch=epoch, + intermediate_checkpoint=is_intermediate_checkpoint, + ) + + self._value_checkpointer.save_checkpoint( + value_ckpt_dict, + epoch=epoch, + intermediate_checkpoint=False, + ) + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + # If seed or total_steps, or total_epochs don't match, + # warn the user and overwrite. + try: + if ( + self.seed != ckpt_dict[utils.SEED_KEY] + or self._total_steps != ckpt_dict[utils.MAX_STEPS_KEY] + or self._total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY] + ): + warn( + message="""Configured value for seed, total_steps, or total_epochs + does not match the value stored in checkpoint.""" + ) + self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY]) + self._rng.set_state(ckpt_dict[utils.RNG_KEY]) + self._steps_run = ckpt_dict[utils.STEPS_KEY] + self._total_steps = ckpt_dict[utils.MAX_STEPS_KEY] + self._total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY] + self._epochs_run = ckpt_dict[utils.EPOCHS_KEY] + + except KeyError as e: + raise KeyError from e( + "Checkpoint does not contain the required keys needed for updating recipe state." + "Are you sure you passed in the right recipe checkpoint?" + ) + + def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: + """ + Generates a trajectory given the current policy and value models, the reference policy model, the reward model, + and batch of inputs. This is done over the following steps: + + 1: Generate responses, and logits corresponding to the responses using the current policy, + generating (query, response) pairs. + 2. Estimate logprobs of the generated responses using the current policy. + 3. Estimate values from the generated responses using the current value function. + 4. 
Replace any tokens in the response after the first stop token (usually EOS token) with padding, + producting truncated responses. + 5. Run the reward model on the (query, truncated-response) pairs. + 6. Mask out all the invalid values in the trajectory due to padding tokens. + + Args: + input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length] + + Returns: + Trajectory: An instance of :class:`~torchtune.modules.rlhf.Trajectory` comprising + the current trajectory. + """ + batch_size, context_length = input_ids.shape + + # step 1: generate responses, and logits corresponding to the responses using the current policy + query_responses, logits = rlhf.generate_with_logits( + model=self._policy_model, + prompt=input_ids, + max_generated_tokens=self._max_generated_tokens, + temperature=self._temperature, + top_k=self._top_k, + pad_id=self._tokenizer.pad_id, + rng=self._rng, + ) + + responses = query_responses[:, context_length:].clone() + query_response_padding_masks = query_responses == self._tokenizer.pad_id + + # step 1.1 create attention masks and position IDs for any padding tokens in inputs, used for future forward passes + masks = rlhf.get_causal_mask(~(query_response_padding_masks)) + position_ids = (~query_response_padding_masks).cumsum(-1) - ( + ~query_response_padding_masks + ).long() + position_ids = position_ids.type(torch.int) + + del query_response_padding_masks + + # step 2. estimate logprobs of the responses using the current policy + logits = logits[:, context_length - 1 :] + logprobs = rlhf.logits_to_logprobs(logits, responses, self._temperature) + + del logits + + # step 2.1 estimate logprobs of the responses using the reference policy + ref_logits = self._ref_policy_model( + query_responses, input_pos=position_ids, mask=masks + ) + ref_logits = rlhf.truncate_sequence_for_logprobs(ref_logits, context_length) + ref_logprobs = rlhf.logits_to_logprobs(ref_logits, responses, self._temperature) + + del ref_logits + + # step 3. estimate values from the responses using the value function + values = self._value_model(query_responses, input_pos=position_ids, mask=masks) + values = rlhf.truncate_sequence_for_logprobs(values, context_length).squeeze(-1) + + # step 4. replace any tokens in the responses after the first stop token (usually EOS token) with padding + # resulting in truncated responses + response_padding_masks, responses = rlhf.truncate_sequence_at_first_stop_token( + responses, self._stop_token_ids, self._tokenizer.pad_id + ) + + # step 5. run the reward model on the (query, truncated-response) pairs + scores = self._reward_model( + torch.cat([input_ids, responses], dim=1), + input_pos=position_ids, + mask=masks, + ) + + del responses + + # step 5.1 the scores from the reward model are the logits for the last non-padding token in + # each (query, truncated-response) pair + seq_lens = utils.get_unmasked_sequence_lengths(response_padding_masks) + scores = scores[torch.arange(batch_size), seq_lens + context_length].squeeze(-1) + + # step 5.2 if configured, apply any penalties for sequences without EOS tokens + # or shorter than a certain length + if self._penalise_no_eos or self._min_response_length: + reward_penalty_mask = rlhf.get_reward_penalty_mask( + response_padding_masks, + seq_lens, + self._penalise_no_eos, + self._min_response_length, + ) + scores[reward_penalty_mask] = self._reward_penalty + + # step 6. 
mask out all the invalid values in the trajectory due to padding tokens + logprobs[response_padding_masks] = 1.0 + ref_logprobs[response_padding_masks] = 1.0 + + # step 6.1 values are masked out *after* the last valid token in the response + value_seq_idxs = torch.where( + (seq_lens > 0) & (seq_lens < self._max_generated_tokens - 1), + seq_lens + 1, + seq_lens, + ) + value_padding_masks = response_padding_masks.clone() + value_padding_masks[ + torch.arange(batch_size, device=value_padding_masks.device), + value_seq_idxs, + ] = False + + values[value_padding_masks] = 0.0 + + return Trajectory( + query_responses=query_responses, + logprobs=logprobs, + ref_logprobs=ref_logprobs, + values=values, + masks=masks, + position_ids=position_ids, + response_padding_masks=response_padding_masks, + value_padding_masks=value_padding_masks, + value_seq_idxs=value_seq_idxs, + scores=scores, + seq_lens=seq_lens, + ) + + def generate_trajectory_batched(self, input_ids: torch.Tensor) -> Trajectory: + """ + Generates a ``self.batch_size`` batch of trajectories using `self._forward_batch_size` batch sizes. + See ``generate_trajectory`` for more details. + + Args: + input_ids (torch.Tensor): tensor of input token IDs with shape [b, seq_length] + + Returns: + Trajectory: An instance of :class:`~torchtune.modules.rlhf.Trajectory`, comprising + the current trajectory. + """ + trajectories: List[Trajectory] = [] + with torch.no_grad(): + for batch_start in range(0, self.batch_size, self._forward_batch_size): + batch_input_ids = input_ids[ + batch_start : batch_start + self._forward_batch_size + ] + trajectories.append(self.generate_trajectory(batch_input_ids)) + return Trajectory(*map(torch.cat, zip(*trajectories))) + + def train(self) -> None: + """ + The core training loop.""" + + if self._model_compile: + log.info( + "NOTE: torch.compile is enabled and model is compiled in first forward." + "Expect a relatively slow first iteration." + ) + # zero out the gradients before starting training + if not self._optimizer_in_bwd: + self._optimizer.zero_grad() + + training_completed = False + pbar = tqdm(total=self._total_steps, initial=self._steps_run) + for curr_epoch in range(self._epochs_run, self._total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + for _, batch in enumerate(self._dataloader): + batch = batch.to(self._device) + _, context_length = batch.shape + + # step 1. generate the trajectory using: + # - the current policy (pi_theta) + # - the current value function (V_phi) + # - the reference frozen policy model (pi_theta_0) + trajectory = self.generate_trajectory_batched(batch) + + # step 2. get the rewards for the current trajectory. these are based on: + # - the divergence between the current policy and the reference policy + # - the scores from the reward model + rewards, kl, kl_rewards = rlhf.get_rewards_ppo( + trajectory.scores, + trajectory.logprobs, + trajectory.ref_logprobs, + self._kl_coeff, + trajectory.value_seq_idxs, + ) + + # step 3. estimate the advantages using Generalized Advantage Estimation (GAE) + advantages, returns = rlhf.estimate_advantages( + trajectory.values, + rewards, + self._gamma, + self._lmbda, + masks=~trajectory.response_padding_masks, + ) + + # step 4. 
optimise using the PPO objective over multiple epochs + ppo_stats: List[PPOStats] = [] + for _ in range(self._ppo_epochs): + batch_idxs = torch.randperm(self.batch_size, device=self._device) + for i in range(0, self.batch_size, self._ppo_batch_size): + mini_batch_idxs = batch_idxs[i : i + self._ppo_batch_size] + + batch_ppo_stats: List[PPOStats] = [] + for j in range( + 0, self._ppo_batch_size, self._ppo_backward_batch_size + ): + backward_batch_idxs = mini_batch_idxs[ + j : j + self._ppo_backward_batch_size + ] + + batch_trajectory = Trajectory( + *map( + partial( + torch.index_select, + dim=0, + index=backward_batch_idxs, + ), + trajectory, + ) + ) + batch_ppo_stats.append( + self._ppo_step( + batch_trajectory, + advantages[backward_batch_idxs], + returns[backward_batch_idxs], + context_length, + ) + ) + del batch_trajectory + + ppo_stats.append(PPOStats(*map(sum, zip(*batch_ppo_stats)))) + + if not self._optimizer_in_bwd: + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + self.global_step += 1 + + # step 5. profit + self._steps_run += 1 + if self._steps_run % self._log_every_n_steps == 0: + self.log_metrics( + trajectory, + PPOStats(*map(torch.stack, zip(*ppo_stats))), + kl, + kl_rewards, + ) + self.cleanup_after_step( + trajectory, ppo_stats, advantages, returns, kl, kl_rewards + ) + pbar.update(1) + if self._steps_run == self._total_steps: + training_completed = True + break + + # save checkpoint at current epoch + self._epochs_run += 1 + + self.save_checkpoint( + curr_epoch, is_intermediate_checkpoint=not training_completed + ) + if training_completed: + return + + def _ppo_step( + self, + trajectory: Trajectory, + advantages: torch.Tensor, + returns: torch.Tensor, + context_length: int, + ) -> PPOStats: + """ + Perform a single PPO optimisation step over a batch of trajectories and corresponding advantages and returns. + + Args: + trajectory (Trajectory): a batch of trajectories + advantages (torch.Tensor): advantages corresponding to the trajectories + returns (torch.Tensor): returns corresponding the trajectories + context_length (int): input ids sequence length + + Returns: + PPOStats: An instance of :class:`~torchtune.modules.rlhf.PPOStats`, a NamedTuple containing: + - loss (torch.Tensor): The total PPO loss. + - policy_loss (torch.Tensor): The policy function loss. + - value_loss (torch.Tensor): The value function loss. + - ratios (torch.Tensor): The ratio between the current and old policy probabilities. + - clipfrac (torch.Tensor): The fraction of ratios that were clipped. + - approx_policy_kls: Average estimated KL divergence between the policy before and after the optimisation step. 
+ + """ + # estimate logprobs from the policy at the current optimisation step + pi_logits = self._policy_model( + trajectory.query_responses, + input_pos=trajectory.position_ids, + mask=trajectory.masks, + ) + pi_logits = rlhf.truncate_sequence_for_logprobs(pi_logits, context_length) + pi_logprobs = rlhf.logits_to_logprobs( + pi_logits, trajectory.query_responses[:, context_length:], self._temperature + ) + pi_logprobs[trajectory.response_padding_masks] = 1.0 + + del pi_logits + + # estimate the values from the value function at the current optimisation step + phi_values = self._value_model( + trajectory.query_responses, + input_pos=trajectory.position_ids, + mask=trajectory.masks, + ) + + phi_values = rlhf.truncate_sequence_for_logprobs( + phi_values, context_length + ).squeeze(-1) + phi_values[trajectory.value_padding_masks] = 0.0 + + # calculate ppo loss + loss, policy_loss, value_loss, ratios, clipfrac = self._loss_fn( + trajectory.logprobs, + pi_logprobs, + advantages, + trajectory.values, + phi_values, + returns, + padding_masks=~trajectory.response_padding_masks, + value_padding_masks=~trajectory.value_padding_masks, + ) + + loss /= self._gradient_accumulation_steps + loss.backward() + + with torch.no_grad(): + approx_policy_kls = ( + 0.5 * (pi_logprobs - trajectory.logprobs).pow(2) + ).mean() + + return PPOStats( + loss, + policy_loss / self._gradient_accumulation_steps, + value_loss / self._gradient_accumulation_steps, + ratios / self._gradient_accumulation_steps, + clipfrac / self._gradient_accumulation_steps, + approx_policy_kls / self._gradient_accumulation_steps, + ) + + def log_metrics( + self, + trajectory: Trajectory, + ppo_stats: PPOStats, + kl: torch.Tensor, + kl_rewards: torch.Tensor, + ) -> None: + """ + Log metrics and statistics for the current step to the metric logger. + """ + log_dict = { + "scores": trajectory.scores.mean(), + "num_stop_tokens": trajectory.response_padding_masks.any(-1).sum(), + "rlhf_reward": trajectory.scores.mean() + kl_rewards.sum(1).mean(), + "kl": kl.sum(1).mean(), + "kl_reward": kl_rewards.sum(1).mean(), + "loss": ppo_stats.loss.mean(), + "policy_loss": ppo_stats.policy_loss.mean(), + "value_loss": ppo_stats.value_loss.mean(), + "clipfrac": ppo_stats.clipfrac.mean(), + "ratios": ppo_stats.ratios.mean(), + "approx_policy_kl": ppo_stats.approx_policy_kls.mean(), + "response_lengths": trajectory.seq_lens.float().mean(), + } + if self._device.type == "cuda" and self._log_peak_memory_stats: + log_dict.update(utils.get_memory_stats(device=self._device)) + + self._metric_logger.log_dict(log_dict, step=self.global_step) + + def cleanup_after_step( + self, + trajectory: Trajectory, + ppo_stats: PPOStats, + advantages: torch.Tensor, + returns: torch.Tensor, + kl: torch.Tensor, + kl_rewards: torch.Tensor, + ) -> None: + """ + Cleanup tensors after each PPO step to free up memory. + """ + # there shouldn't be any floating references to the individual tensors at the this point, so gc can do its thing + for v in trajectory: + del v + del trajectory + for v in ppo_stats: + del v + del ppo_stats + del advantages + del returns + del kl + del kl_rewards + + def cleanup(self, **kwargs) -> None: + self._metric_logger.close() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. 
+ + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + config.log_config(recipe_name="PPOFullFinetuneRecipeSingleDevice", cfg=cfg) + recipe = PPOFullFinetuneRecipeSingleDevice(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/benchmarks/stargan/README.md b/benchmarks/retired/stargan/README.md similarity index 100% rename from benchmarks/stargan/README.md rename to benchmarks/retired/stargan/README.md diff --git a/benchmarks/stargan/benchfile.py b/benchmarks/retired/stargan/benchfile.py similarity index 100% rename from benchmarks/stargan/benchfile.py rename to benchmarks/retired/stargan/benchfile.py diff --git a/benchmarks/stargan/prepare.py b/benchmarks/retired/stargan/prepare.py similarity index 100% rename from benchmarks/stargan/prepare.py rename to benchmarks/retired/stargan/prepare.py diff --git a/benchmarks/stargan/requirements.cuda.txt b/benchmarks/retired/stargan/requirements.cuda.txt similarity index 100% rename from benchmarks/stargan/requirements.cuda.txt rename to benchmarks/retired/stargan/requirements.cuda.txt diff --git a/benchmarks/stargan/requirements.hpu.txt b/benchmarks/retired/stargan/requirements.hpu.txt similarity index 100% rename from benchmarks/stargan/requirements.hpu.txt rename to benchmarks/retired/stargan/requirements.hpu.txt diff --git a/benchmarks/stargan/requirements.in b/benchmarks/retired/stargan/requirements.in similarity index 100% rename from benchmarks/stargan/requirements.in rename to benchmarks/retired/stargan/requirements.in diff --git a/benchmarks/stargan/requirements.rocm.txt b/benchmarks/retired/stargan/requirements.rocm.txt similarity index 100% rename from benchmarks/stargan/requirements.rocm.txt rename to benchmarks/retired/stargan/requirements.rocm.txt diff --git a/benchmarks/stargan/requirements.xpu.txt b/benchmarks/retired/stargan/requirements.xpu.txt similarity index 100% rename from benchmarks/stargan/requirements.xpu.txt rename to benchmarks/retired/stargan/requirements.xpu.txt diff --git a/benchmarks/stargan/stargan/LICENSE b/benchmarks/retired/stargan/stargan/LICENSE similarity index 100% rename from benchmarks/stargan/stargan/LICENSE rename to benchmarks/retired/stargan/stargan/LICENSE diff --git a/benchmarks/stargan/stargan/ORIGIN.md b/benchmarks/retired/stargan/stargan/ORIGIN.md similarity index 100% rename from benchmarks/stargan/stargan/ORIGIN.md rename to benchmarks/retired/stargan/stargan/ORIGIN.md diff --git a/benchmarks/stargan/stargan/README.md b/benchmarks/retired/stargan/stargan/README.md similarity index 100% rename from benchmarks/stargan/stargan/README.md rename to benchmarks/retired/stargan/stargan/README.md diff --git a/benchmarks/stargan/stargan/data_loader.py b/benchmarks/retired/stargan/stargan/data_loader.py similarity index 100% rename from benchmarks/stargan/stargan/data_loader.py rename to benchmarks/retired/stargan/stargan/data_loader.py diff --git a/benchmarks/stargan/stargan/download.sh b/benchmarks/retired/stargan/stargan/download.sh similarity index 100% rename from benchmarks/stargan/stargan/download.sh rename to benchmarks/retired/stargan/stargan/download.sh diff --git a/benchmarks/stargan/stargan/logger.py b/benchmarks/retired/stargan/stargan/logger.py similarity index 100% rename from benchmarks/stargan/stargan/logger.py rename to 
benchmarks/retired/stargan/stargan/logger.py diff --git a/benchmarks/stargan/stargan/main.py b/benchmarks/retired/stargan/stargan/main.py similarity index 100% rename from benchmarks/stargan/stargan/main.py rename to benchmarks/retired/stargan/stargan/main.py diff --git a/benchmarks/stargan/stargan/model.py b/benchmarks/retired/stargan/stargan/model.py similarity index 100% rename from benchmarks/stargan/stargan/model.py rename to benchmarks/retired/stargan/stargan/model.py diff --git a/benchmarks/stargan/stargan/solver.py b/benchmarks/retired/stargan/stargan/solver.py similarity index 100% rename from benchmarks/stargan/stargan/solver.py rename to benchmarks/retired/stargan/stargan/solver.py diff --git a/benchmarks/stargan/stargan/synth.py b/benchmarks/retired/stargan/stargan/synth.py similarity index 100% rename from benchmarks/stargan/stargan/synth.py rename to benchmarks/retired/stargan/stargan/synth.py diff --git a/benchmarks/stargan/voirfile.py b/benchmarks/retired/stargan/voirfile.py similarity index 100% rename from benchmarks/stargan/voirfile.py rename to benchmarks/retired/stargan/voirfile.py diff --git a/benchmarks/rlhf/Makefile b/benchmarks/rlhf/Makefile deleted file mode 100644 index 97b871cdc..000000000 --- a/benchmarks/rlhf/Makefile +++ /dev/null @@ -1,31 +0,0 @@ -# Use global base if possible -ifndef MILABENCH_BASE - MILABENCH_BASE="base" -endif - -export MILABENCH_BASE - -BENCH_NAME=rlhf -MILABENCH_CONFIG=dev.yaml -MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE) - -all: - install prepare single gpus nodes - -install: - milabench install $(MILABENCH_ARGS) --force - -prepare: - milabench prepare $(MILABENCH_ARGS) - -tests: - CUDA_VISIBLE_DEVICES=0 milabench run $(MILABENCH_ARGS) - -single: - milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME) - -gpus: - milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus - -nodes: - milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-nodes diff --git a/benchmarks/rlhf/README.md b/benchmarks/rlhf/README.md deleted file mode 100644 index 9c22d45ca..000000000 --- a/benchmarks/rlhf/README.md +++ /dev/null @@ -1,4 +0,0 @@ - -# Rlhf - -Rewrite this README to explain what the benchmark is! diff --git a/benchmarks/rlhf/dev.yaml b/benchmarks/rlhf/dev.yaml deleted file mode 100644 index 99ab9b21e..000000000 --- a/benchmarks/rlhf/dev.yaml +++ /dev/null @@ -1,53 +0,0 @@ - -rlhf: - inherits: _defaults - definition: . 
- install-variant: unpinned - install_group: torch - plan: - method: per_gpu - - argv: - --output_dir: models/minimal/ppo - --per_device_train_batch_size: 1 - --gradient_accumulation_steps: 1 - --total_episodes: 30000 - --model_name_or_path: meta-llama/Llama-2-7b-chat-hf - --sft_model_path: meta-llama/Llama-2-7b-chat-hf - --reward_model_path: cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr - --non_eos_penalty: true - --stop_token: eos - --response_length: 53 - --sanity_check: true - - - -# """ -# python examples/scripts/ppo/ppo_tldr.py \ -# --learning_rate 3e-6 \ -# --output_dir models/minimal/ppo \ -# --per_device_train_batch_size 1 \ -# --gradient_accumulation_steps 64 \ -# --total_episodes 30000 \ -# --model_name_or_path EleutherAI/pythia-1b-deduped \ -# --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ -# --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ -# --non_eos_penalty \ -# --stop_token eos \ -# --response_length 53 \ -# --sanity_check - -# accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ -# examples/scripts/ppo/ppo_tldr.py \ -# --output_dir models/minimal/ppo_tldr \ -# --learning_rate 3e-6 \ -# --per_device_train_batch_size 16 \ -# --gradient_accumulation_steps 4 \ -# --total_episodes 1000000 \ -# --model_name_or_path EleutherAI/pythia-1b-deduped \ -# --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ -# --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ -# --local_rollout_forward_batch_size 16 \ -# --non_eos_penalty \ -# --stop_token eos -# """ \ No newline at end of file diff --git a/benchmarks/rlhf/main.py b/benchmarks/rlhf/main.py deleted file mode 100644 index 7ab48e1d7..000000000 --- a/benchmarks/rlhf/main.py +++ /dev/null @@ -1,126 +0,0 @@ -import multiprocessing -import shutil - -from datasets import load_dataset -from transformers import ( - AutoModelForCausalLM, - AutoModelForSequenceClassification, - AutoTokenizer, - HfArgumentParser, -) - -from trl import ModelConfig -from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer -from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE - - -def main(): - parser = HfArgumentParser((PPOv2Config, ModelConfig)) - config, model_config = parser.parse_args_into_dataclasses() - # remove output_dir if exists - shutil.rmtree(config.output_dir, ignore_errors=True) - - ################ - # Model & Tokenizer - ################ - tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, - padding_side="left", - trust_remote_code=model_config.trust_remote_code, - ) - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - if tokenizer.chat_template is None: - tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE - value_model = AutoModelForSequenceClassification.from_pretrained( - config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 - ) - reward_model = AutoModelForSequenceClassification.from_pretrained( - config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 - ) - import torch - ref_policy = AutoModelForCausalLM.from_pretrained( - config.sft_model_path, - trust_remote_code=model_config.trust_remote_code, - low_cpu_mem_usage=True, - ) - policy = AutoModelForCausalLM.from_pretrained( - config.sft_model_path, - trust_remote_code=model_config.trust_remote_code, - low_cpu_mem_usage=True, - ) - - from peft import prepare_model_for_kbit_training - from peft import LoraConfig - from peft import get_peft_model - - ref_policy = 
prepare_model_for_kbit_training(ref_policy) - policy = prepare_model_for_kbit_training(policy) - - lora_config = LoraConfig( - r=16, - lora_alpha=32, - lora_dropout=0.05, - bias="none", - task_type="CAUSAL_LM", - ) - ref_policy = get_peft_model(ref_policy, lora_config) - policy = get_peft_model(policy, lora_config) - - ################ - # Dataset - ################ - raw_datasets = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style") - if config.sanity_check: - for key in raw_datasets: - raw_datasets[key] = raw_datasets[key].select(range(1000)) - train_dataset = raw_datasets["train"] - eval_dataset = raw_datasets["validation"] - - def prepare_dataset(dataset, tokenizer): - """pre-tokenize the dataset before training; only collate during training""" - - def tokenize(element): - input_ids = tokenizer.apply_chat_template( - element["messages"][:1], - padding=False, - add_generation_prompt=True, - ) - return {"input_ids": input_ids, "lengths": len(input_ids)} - - return dataset.map( - tokenize, - remove_columns=dataset.column_names, - num_proc=1 if config.sanity_check else multiprocessing.cpu_count(), - load_from_cache_file=not config.sanity_check, - ) - - train_dataset = prepare_dataset(train_dataset, tokenizer) - eval_dataset = prepare_dataset(eval_dataset, tokenizer) - # filtering - train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512) - eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512) - - assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token" - ################ - # Training - ################ - print("DONE") - trainer = PPOv2Trainer( - config=config, - tokenizer=tokenizer, - policy=policy, - ref_policy=ref_policy, - reward_model=reward_model, - value_model=value_model, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - ) - trainer.train() - - # trainer.save_model(config.output_dir) - # trainer.generate_completions() - - -if __name__ == "__main__": - main() diff --git a/benchmarks/rlhf/requirements.in b/benchmarks/rlhf/requirements.in deleted file mode 100644 index d0faef03e..000000000 --- a/benchmarks/rlhf/requirements.in +++ /dev/null @@ -1,5 +0,0 @@ -voir>=0.2.17,<0.3 -torch -trl -bitsandbytes -peft \ No newline at end of file diff --git a/benchmarks/rlhf/voirfile.py b/benchmarks/rlhf/voirfile.py deleted file mode 100644 index d93f886cd..000000000 --- a/benchmarks/rlhf/voirfile.py +++ /dev/null @@ -1,38 +0,0 @@ -from dataclasses import dataclass - -from voir import configurable -from voir.instruments import dash, early_stop, log, rate -from benchmate.monitor import monitor_monogpu - -@dataclass -class Config: - """voir configuration""" - - # Whether to display the dash or not - dash: bool = False - - # How often to log the rates - interval: str = "1s" - - # Number of rates to skip before logging - skip: int = 5 - - # Number of rates to log before stopping - stop: int = 20 - - # Number of seconds between each gpu poll - gpu_poll: int = 3 - - -@configurable -def instrument_main(ov, options: Config): - yield ov.phases.init - - if options.dash: - ov.require(dash) - - ov.require( - log("value", "progress", "rate", "units", "loss", "gpudata", context="task"), - early_stop(n=options.stop, key="rate", task="train"), - monitor_monogpu(poll_interval=options.gpu_poll), - ) diff --git a/benchmarks/torch_ppo_atari_envpool/benchfile.py b/benchmarks/torch_ppo_atari_envpool/benchfile.py deleted file mode 100644 index 5625f7ed9..000000000 --- a/benchmarks/torch_ppo_atari_envpool/benchfile.py 
+++ /dev/null @@ -1,31 +0,0 @@ -from milabench.pack import Package - - -class Torch_ppo_atari_envpool(Package): - # Requirements file installed by install(). It can be empty or absent. - base_requirements = "requirements.in" - - # The preparation script called by prepare(). It must be executable, - # but it can be any type of script. It can be empty or absent. - prepare_script = "prepare.py" - - # The main script called by run(). It must be a Python file. It has to - # be present. - main_script = "main.py" - - # You can remove the functions below if you don't need to modify them. - - def make_env(self): - # Return a dict of environment variables for prepare_script and - # main_script. - return super().make_env() - - async def install(self): - await super().install() # super() call installs the requirements - - async def prepare(self): - await super().prepare() # super() call executes prepare_script - - - -__pack__ = Torch_ppo_atari_envpool diff --git a/benchmarks/torch_ppo_atari_envpool/prepare.py b/benchmarks/torch_ppo_atari_envpool/prepare.py deleted file mode 100755 index 32bd5901d..000000000 --- a/benchmarks/torch_ppo_atari_envpool/prepare.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python - -import os - -if __name__ == "__main__": - # If you need the whole configuration: - # config = json.loads(os.environ["MILABENCH_CONFIG"]) - - data_directory = os.environ["MILABENCH_DIR_DATA"] - - # Download (or generate) the needed dataset(s). You are responsible - # to check if it has already been properly downloaded or not, and to - # do nothing if it has been. - print("Hello I am doing some data stuff!") - - # If there is nothing to download or generate, just delete this file. diff --git a/benchmarks/torch_ppo_atari_envpool/Makefile b/benchmarks/torchatari/Makefile similarity index 94% rename from benchmarks/torch_ppo_atari_envpool/Makefile rename to benchmarks/torchatari/Makefile index 81443ce2b..9eb0a30c5 100644 --- a/benchmarks/torch_ppo_atari_envpool/Makefile +++ b/benchmarks/torchatari/Makefile @@ -5,7 +5,7 @@ endif export MILABENCH_BASE -BENCH_NAME=torch_ppo_atari_envpool +BENCH_NAME=torchatari MILABENCH_CONFIG=dev.yaml MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE) diff --git a/benchmarks/torch_ppo_atari_envpool/README.md b/benchmarks/torchatari/README.md similarity index 100% rename from benchmarks/torch_ppo_atari_envpool/README.md rename to benchmarks/torchatari/README.md diff --git a/benchmarks/rlhf/benchfile.py b/benchmarks/torchatari/benchfile.py similarity index 94% rename from benchmarks/rlhf/benchfile.py rename to benchmarks/torchatari/benchfile.py index a568f6690..1bf4ee785 100644 --- a/benchmarks/rlhf/benchfile.py +++ b/benchmarks/torchatari/benchfile.py @@ -1,7 +1,7 @@ from milabench.pack import Package -class Rlhf(Package): +class Torchatari(Package): # Requirements file installed by install(). It can be empty or absent. base_requirements = "requirements.in" @@ -28,4 +28,4 @@ async def prepare(self): -__pack__ = Rlhf +__pack__ = Torchatari diff --git a/benchmarks/torch_ppo_atari_envpool/dev.yaml b/benchmarks/torchatari/dev.yaml similarity index 92% rename from benchmarks/torch_ppo_atari_envpool/dev.yaml rename to benchmarks/torchatari/dev.yaml index 338bed075..d0df0df1a 100644 --- a/benchmarks/torch_ppo_atari_envpool/dev.yaml +++ b/benchmarks/torchatari/dev.yaml @@ -1,5 +1,5 @@ -torch_ppo_atari_envpool: +torchatari: max_duration: 600 inherits: _defaults definition: . 
diff --git a/benchmarks/torch_ppo_atari_envpool/main.py b/benchmarks/torchatari/main.py
similarity index 100%
rename from benchmarks/torch_ppo_atari_envpool/main.py
rename to benchmarks/torchatari/main.py
diff --git a/benchmarks/rlhf/prepare.py b/benchmarks/torchatari/prepare.py
similarity index 100%
rename from benchmarks/rlhf/prepare.py
rename to benchmarks/torchatari/prepare.py
diff --git a/benchmarks/torch_ppo_atari_envpool/requirements.in b/benchmarks/torchatari/requirements.in
similarity index 100%
rename from benchmarks/torch_ppo_atari_envpool/requirements.in
rename to benchmarks/torchatari/requirements.in
diff --git a/benchmarks/torch_ppo_atari_envpool/voirfile.py b/benchmarks/torchatari/voirfile.py
similarity index 100%
rename from benchmarks/torch_ppo_atari_envpool/voirfile.py
rename to benchmarks/torchatari/voirfile.py
diff --git a/config/base.yaml b/config/base.yaml
index f5d5920d0..c095e7cb7 100644
--- a/config/base.yaml
+++ b/config/base.yaml
@@ -648,3 +648,52 @@ llm-full-mp-nodes:
    - "len(nodes) >= ${num_machines}"
+
+_geo_gnn:
+  inherits: _defaults
+  definition: .
+  # FIXME: torch cluster is lagging behind pytorch
+  # we are forced to use torch==2.3 instead of torch==2.4
+  install_group: gnn
+  definition: ../benchmarks/geo_gnn
+  plan:
+    method: per_gpu
+
+dimenet:
+  inherits: _geo_gnn
+  argv:
+    --model: 'DimeNet'
+    --num-samples: 10000
+    --use3d: True
+
+
+recursiongfn:
+  inherits: _defaults
+  definition: ../benchmarks/recursiongfn
+  install_group: torch
+  plan:
+    method: per_gpu
+
+  argv:
+    --batch_size: 128
+    --num_workers: 8
+    --num_steps: 100
+    --layer_width: 128
+    --num_layers: 4
+
+
+torchatari:
+  max_duration: 600
+  inherits: _defaults
+  definition: ../benchmarks/torchatari
+  install-variant: unpinned
+  install_group: torch
+  plan:
+    method: per_gpu
+
+  argv:
+    --num-minibatches: 16
+    --update-epochs: 4
+    --num-steps: 128
+    --num-envs: auto({cpu_per_gpu}, 128)
+    --total-timesteps: 1000000
+    --env-id: Breakout-v5
\ No newline at end of file
diff --git a/constraints/extra/gnn.cuda.txt b/constraints/extra/gnn.cuda.txt
new file mode 100644
index 000000000..96a943ad3
--- /dev/null
+++ b/constraints/extra/gnn.cuda.txt
@@ -0,0 +1,6 @@
+
+# FIXME: this is cuda specific
+--find-links https://data.pyg.org/whl/torch-2.3.0+cu121.html
+
+
+torch>2.2,<2.4
diff --git a/constraints/extra/gnn.hpu.txt b/constraints/extra/gnn.hpu.txt
new file mode 100644
index 000000000..e69de29bb
diff --git a/constraints/extra/gnn.rocm.txt b/constraints/extra/gnn.rocm.txt
new file mode 100644
index 000000000..e69de29bb
diff --git a/constraints/extra/gnn.xpu.txt b/constraints/extra/gnn.xpu.txt
new file mode 100644
index 000000000..e69de29bb
diff --git a/constraints/extra/torch.rocm.txt b/constraints/extra/torch.rocm.txt
index 493c77672..870d923a2 100644
--- a/constraints/extra/torch.rocm.txt
+++ b/constraints/extra/torch.rocm.txt
@@ -1 +1 @@
-# No jax
\ No newline at end of file
+# No jax only a container for it