From 43333c7e0b3c67d45a8c5e92788cab713fc750c8 Mon Sep 17 00:00:00 2001 From: Xuechen Li <12689993+lxuechen@users.noreply.github.com> Date: Thu, 23 Nov 2023 16:42:29 -0800 Subject: [PATCH] DPO (#81) * chore: scaffold for dpo * chore: scaffolding for data * chore: more scaffolding for data * chore: fix scaffolding for data * chore: done dpo collator * chore: done make data module * chore: fix trainer * fix: specify label names * fix: eval_size * chore: add trainer * chore: doc efficient impl * chore: doc efficient impl clarify * fix: ref model sharing * fix: simplify * fix: add shebang * fix: add dpo launcher * fix: tweak hparams * fix: bug in eval --- examples/dpo.py | 138 +++++++++++++++++++++++++++ examples/scripts/dpo.sh | 34 +++++++ examples/scripts/expiter.sh | 2 + examples/scripts/reward_modeling.sh | 2 + examples/scripts/rlhf_ppo.sh | 2 + examples/scripts/rlhf_quark.sh | 2 + examples/scripts/sft.sh | 2 + src/alpaca_farm/data_preprocessor.py | 79 ++++++++++++++- src/alpaca_farm/data_utils.py | 32 ++++++- src/alpaca_farm/rl/dpo_trainer.py | 51 ++++++++++ 10 files changed, 340 insertions(+), 4 deletions(-) create mode 100644 examples/dpo.py create mode 100644 examples/scripts/dpo.sh create mode 100644 src/alpaca_farm/rl/dpo_trainer.py diff --git a/examples/dpo.py b/examples/dpo.py new file mode 100644 index 00000000..c3d19444 --- /dev/null +++ b/examples/dpo.py @@ -0,0 +1,138 @@ +import os +import pathlib +from dataclasses import dataclass, field +from typing import List, Literal + +import pandas as pd +import transformers + +from alpaca_farm import common, constants, data_utils, logging, utils +from alpaca_farm.rl.dpo_trainer import Trainer + +pd.set_option("display.max_rows", None) +pd.set_option("display.max_columns", None) +pd.set_option("display.width", None) +pd.set_option("display.max_colwidth", None) + +logger = logging.get_logger(__name__) + + +@dataclass +class ModelArguments: + model_name_or_path: str = field( + default=None, metadata={"help": "Name to a huggingface native pretrained model or path to a model on disk."} + ) + + +@dataclass +class DataArguments: + dataset_path: str = field(default="tatsu-lab/alpaca_farm") + dataset_name: Literal["alpaca_human_preference", "alpaca_gpt4_preference", "alpaca_noisy_multi_preference"] = field( + default="alpaca_noisy_multi_preference", + metadata={"help": "Name of the dataset. Fetches the human or GPT-4 preference data."}, + ) + eval_size: int = field( + default=500, + metadata={"help": "Number of examples to split out from training to use for evaluation."}, + ) + prompt_dict_path: str = field( + default=pathlib.Path(__file__).parent / "prompts" / "v0_inputs_noinputs.json", + metadata={"help": "Path to the dictionary for the prompt to format examples."}, + ) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + pad_token: str = field(default=constants.DEFAULT_PAD_TOKEN) + cache_dir: str = field(default=constants.DEFAULT_CACHE_DIR) + wandb_project: str = field(default=constants.WANDB_PROJECT) + flash_attn: bool = field(default=False) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=512, + metadata={ + "help": "Maximum sequence length. Sequences will be right padded to this length (and possibly truncated)." + "Enforcing a consistent max length ensures memory usage is constant and predictable." + }, + ) + padding: Literal["max_length", "longest"] = field( + default="longest", + metadata={ + "help": "Padding strategy. 
If 'max_length', pads to `model_max_length` always; this might lead to some " + "redundant compute. If 'longest', pads to the longest sequence in the batch, capped by `model_max_length`." + }, + ) + resume_from_checkpoint: bool = field(default=False, metadata={"help": "If True, loads from last check point."}) + use_fast_tokenizer: bool = field( + default=False, + metadata={ + "help": "Use fast tokenizer if True. " + "Fast LLaMA tokenizer forces protobuf downgrade to 3.20.3. " + "Use fast tokenizer only if you can live with that." + }, + ) + beta: float = field(default=1e-1, metadata={"help": "Beta for the KL divergence."}) + + +def main(): + parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + os.environ["WANDB_PROJECT"] = training_args.wandb_project + + # Load model on CPU to prevent upfront OOM. + model: transformers.PreTrainedModel = common.make_generative_lm( + model_name_or_path=model_args.model_name_or_path, + flash_attn=training_args.flash_attn, + fp16=training_args.fp16, + bf16=training_args.bf16, + config=transformers.AutoConfig.from_pretrained(model_args.model_name_or_path), + cache_dir=training_args.cache_dir, + low_cpu_mem_usage=True, + device_map=None, + ) + common.let_model_save_mem_when_zero_grad(model) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", # Ensures properly masking out the source tokens. + use_fast=training_args.use_fast_tokenizer, + ) + tokenizer.padding = training_args.padding + + # Collect special tokens. Only add if non-existent. + special_tokens_dict = dict(additional_special_tokens=[]) + if tokenizer.pad_token is None: + special_tokens_dict["pad_token"] = training_args.pad_token + if tokenizer.eos_token is None: + special_tokens_dict["eos_token"] = constants.DEFAULT_EOS_TOKEN + if tokenizer.bos_token is None: + special_tokens_dict["bos_token"] = constants.DEFAULT_BOS_TOKEN + if tokenizer.unk_token is None: + special_tokens_dict["unk_token"] = constants.DEFAULT_UNK_TOKEN + utils.stable_resize_token_embeddings_and_tokenizer(model, tokenizer, special_tokens_dict) + + data_module: dict = data_utils.make_dpo_data_module( + tokenizer=tokenizer, + data_args=data_args, + training_args=training_args, + ) + + trainer = Trainer( + model=model, + tokenizer=tokenizer, + args=training_args, + **data_module, + ) + + trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + logger.warning("hooray! training finished successfully! now on to model saving.", main_process_only=True) + + trainer.save_state() + common.safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) + logger.warning("hooray again! 
model saving worked.", main_process_only=True) + + +if __name__ == "__main__": + main() diff --git a/examples/scripts/dpo.sh b/examples/scripts/dpo.sh new file mode 100644 index 00000000..a207f2c9 --- /dev/null +++ b/examples/scripts/dpo.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +output_dir=$1 +run_name=$2 +model_name_or_path=$3 + +torchrun --nproc_per_node=8 --master_port=1234 examples/dpo.py \ + --model_name_or_path "${model_name_or_path}" \ + --fp16 False \ + --bf16 True \ + --seed 42 \ + --output_dir "${output_dir}" \ + --num_train_epochs 2 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 16 \ + --eval_steps 100 \ + --save_strategy "steps" \ + --save_steps 1000000000 \ + --save_total_limit 1 \ + --learning_rate 1e-6 \ + --weight_decay 0.0 \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --evaluation_strategy "steps" \ + --logging_steps 10 \ + --wandb_project "alpaca_farm" \ + --run_name "${run_name}" \ + --tf32 True \ + --flash_attn True \ + --model_max_length 512 \ + --ddp_timeout 1800 \ + --fsdp "full_shard auto_wrap" \ + --fsdp_transformer_layer_cls_to_wrap "LlamaDecoderLayer" diff --git a/examples/scripts/expiter.sh b/examples/scripts/expiter.sh index 6862bd9f..e8c68884 100644 --- a/examples/scripts/expiter.sh +++ b/examples/scripts/expiter.sh @@ -1,3 +1,5 @@ +#!/bin/bash + output_dir=$1 run_name=$2 model_name_or_path=$3 diff --git a/examples/scripts/reward_modeling.sh b/examples/scripts/reward_modeling.sh index 383d0134..f05af545 100644 --- a/examples/scripts/reward_modeling.sh +++ b/examples/scripts/reward_modeling.sh @@ -1,3 +1,5 @@ +#!/bin/bash + output_dir=$1 run_name=$2 model_name_or_path=$3 diff --git a/examples/scripts/rlhf_ppo.sh b/examples/scripts/rlhf_ppo.sh index 05c0f2fc..1fcd27bb 100644 --- a/examples/scripts/rlhf_ppo.sh +++ b/examples/scripts/rlhf_ppo.sh @@ -1,3 +1,5 @@ +#!/bin/bash + output_dir=$1 run_name=$2 reward_model_name_or_path=$3 diff --git a/examples/scripts/rlhf_quark.sh b/examples/scripts/rlhf_quark.sh index 830175d3..7c0f1423 100644 --- a/examples/scripts/rlhf_quark.sh +++ b/examples/scripts/rlhf_quark.sh @@ -1,3 +1,5 @@ +#!/bin/bash + output_dir=$1 run_name=$2 reward_model_name_or_path=$3 diff --git a/examples/scripts/sft.sh b/examples/scripts/sft.sh index 00c0d7e8..06cf8e48 100644 --- a/examples/scripts/sft.sh +++ b/examples/scripts/sft.sh @@ -1,3 +1,5 @@ +#!/bin/bash + output_dir=$1 run_name=$2 model_name_or_path=$3 diff --git a/src/alpaca_farm/data_preprocessor.py b/src/alpaca_farm/data_preprocessor.py index f3405c7d..7d3bcbb9 100644 --- a/src/alpaca_farm/data_preprocessor.py +++ b/src/alpaca_farm/data_preprocessor.py @@ -14,9 +14,10 @@ import copy import dataclasses -from typing import Callable, Dict, Optional, Sequence, Union +from typing import Callable, Dict, List, Optional, Sequence, Union import einops +import numpy as np import pandas as pd import torch import transformers @@ -270,6 +271,32 @@ def _merge_tokenization_metadata(metadata_list: Sequence[dict]) -> dict: return packaged_data +def preprocess_for_dpo( + df: pd.DataFrame, + prompt_dict: dict, + tokenizer: transformers.PreTrainedTokenizer, + df_postprocessor=None, + verbose=True, +) -> dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]]: + output_1, output_2, preference = df["output_1"], df["output_2"], df["preference"] + + df_w = df.assign(output=np.where(preference == 1, output_1, output_2))[["instruction", "input", "output"]] + df_l = df.assign(output=np.where(preference == 2, output_1, output_2))[["instruction", 
"input", "output"]] + + tensors_w = preprocess_for_sft( + df=df_w, prompt_dict=prompt_dict, tokenizer=tokenizer, df_postprocessor=df_postprocessor, verbose=verbose + ) + tensors_l = preprocess_for_sft( + df=df_l, prompt_dict=prompt_dict, tokenizer=tokenizer, df_postprocessor=df_postprocessor, verbose=verbose + ) + return dict( + input_ids_w=tensors_w["input_ids"], + labels_w=tensors_w["labels"], + input_ids_l=tensors_l["input_ids"], + labels_l=tensors_l["labels"], + ) + + def _get_generator(seed: int) -> torch.Generator: rng = torch.Generator() rng.manual_seed(seed) @@ -526,3 +553,53 @@ def __len__(self): class DataCollatorForStackableDataset(object): def __call__(self, instances: Sequence[Dict]) -> Dict[str, Tensor]: return {key: torch.stack([instance[key] for instance in instances]) for key in instances[0].keys()} + + +class DPODataset(Dataset): + def __init__( + self, + df: pd.DataFrame, + prompt_dict: dict, + tokenizer: transformers.PreTrainedTokenizer, + df_postprocessor: Optional[Callable] = None, + ): + super(DPODataset, self).__init__() + self.tensors = preprocess_for_dpo( + df=df, prompt_dict=prompt_dict, tokenizer=tokenizer, df_postprocessor=df_postprocessor + ) + + def __len__(self): + return len(next(iter(self.tensors.values()))) + + def __getitem__(self, i) -> Dict[str, Tensor]: + return {key: value[i] for key, value in self.tensors.items()} + + +@dataclasses.dataclass +class DataCollatorForDPODataset(object): + tokenizer: transformers.PreTrainedTokenizer + + def _pad_input_ids_and_labels(self, input_ids: List[Tensor], labels: List[Tensor]) -> tuple[Tensor, Tensor, Tensor]: + # This is the same things as done in `DataCollatorForSFTDataset`; repeat for better readability. + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id + ) + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=constants.IGNORE_INDEX) + # When sequences are right padded, `attention_mask` is only useful for T5 training. 
+ attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long() + return input_ids, labels, attention_mask + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, Tensor]: + input_ids_w, labels_w, input_ids_l, labels_l = tuple( + [instance[key] for instance in instances] for key in ("input_ids_w", "labels_w", "input_ids_l", "labels_l") + ) + input_ids_w, labels_w, attention_mask_w = self._pad_input_ids_and_labels(input_ids_w, labels_w) + input_ids_l, labels_l, attention_mask_l = self._pad_input_ids_and_labels(input_ids_l, labels_l) + return dict( + input_ids_w=input_ids_w, + labels_w=labels_w, + attention_mask_w=attention_mask_w, + input_ids_l=input_ids_l, + labels_l=labels_l, + attention_mask_l=attention_mask_l, + ) diff --git a/src/alpaca_farm/data_utils.py b/src/alpaca_farm/data_utils.py index 6b85750a..e511cbbb 100644 --- a/src/alpaca_farm/data_utils.py +++ b/src/alpaca_farm/data_utils.py @@ -21,8 +21,10 @@ from .data_preprocessor import ( BinaryRewardModelingDataset, DataCollatorForBinaryRewardModelingDataset, + DataCollatorForDPODataset, DataCollatorForSFTDataset, DataCollatorForStackableDataset, + DPODataset, QueryDataset, SFTDataset, split_train_into_train_and_eval, @@ -35,7 +37,7 @@ def make_supervised_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args, training_args, -): +) -> dict: prompt_dict = utils.jload(data_args.prompt_dict_path) alpaca_instructions = datasets.load_dataset(data_args.dataset_path, data_args.dataset_name) @@ -72,7 +74,7 @@ def make_binary_reward_modeling_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args, training_args, -): +) -> dict: prompt_dict = utils.jload(data_args.prompt_dict_path) alpaca_human_preference = datasets.load_dataset(data_args.dataset_path, data_args.dataset_name) @@ -97,7 +99,7 @@ def make_rl_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args, training_args, -): +) -> dict: prompt_dict = utils.jload(data_args.prompt_dict_path) alpaca_instructions = datasets.load_dataset(data_args.dataset_path, data_args.dataset_name) @@ -126,3 +128,27 @@ def make_rl_data_module( prompt_postprocessor=prompt_postprocessor, ) return dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=DataCollatorForStackableDataset()) + + +def make_dpo_data_module( + tokenizer: transformers.PreTrainedTokenizer, + data_args, + training_args, +) -> dict: + prompt_dict = utils.jload(data_args.prompt_dict_path) + + alpaca_human_preference = datasets.load_dataset(data_args.dataset_path, data_args.dataset_name) + train_df = pd.DataFrame(alpaca_human_preference["preference"]) + + train_dataset = DPODataset( + df=train_df, + prompt_dict=prompt_dict, + tokenizer=tokenizer, + ) + train_dataset, eval_dataset = split_train_into_train_and_eval( + train_dataset=train_dataset, + eval_size=data_args.eval_size, + seed=training_args.seed, + ) + data_collator = DataCollatorForDPODataset(tokenizer=tokenizer) + return dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator) diff --git a/src/alpaca_farm/rl/dpo_trainer.py b/src/alpaca_farm/rl/dpo_trainer.py new file mode 100644 index 00000000..eb18d85a --- /dev/null +++ b/src/alpaca_farm/rl/dpo_trainer.py @@ -0,0 +1,51 @@ +import copy + +import torch +import torch.nn.functional as F +import transformers + +from .. 
import common
+
+LABEL_NAMES = ["input_ids_w", "labels_w", "attention_mask_w", "input_ids_l", "labels_l", "attention_mask_l"]
+
+
+class Trainer(transformers.Trainer):
+    def __init__(self, model, args, *argv, **kwargs):
+        args.label_names = LABEL_NAMES
+        super().__init__(model, args, *argv, **kwargs)
+        self.ref_model = self._wrap_model(copy.deepcopy(model)).eval()
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        # This implementation is simple and readable, but it's not efficient.
+        # Since the instruction+input is shared by the winning and losing sequences, one can in principle
+        # just do a single forward pass on this shared part for model and ref_model, instead of doing the full forward
+        # twice (once for the winning and once for the losing sequence) for model and ref_model.
+        # So here's the efficient implementation:
+        # 1. Do a single forward pass on the instruction+input for model. Retain the kv cache.
+        # 2. Do a forward pass on the winning response for model, using the kv cache.
+        # 3. Do a forward pass on the losing response for model, using the kv cache.
+        # 4. Follow a similar procedure for ref_model, except don't retain activations for backprop
+        #    (but do temporarily retain the kv cache).
+        # 5. Compute the loss.
+        # There's an explicit speed/memory tradeoff here -- retaining the kv cache saves FLOPs but uses more memory.
+        # If memory is not a concern, then the winning and losing sequences should be batched together so the logits
+        # can be computed in a single forward call.
+        input_ids_w, labels_w, attention_mask_w, input_ids_l, labels_l, attention_mask_l = common.unpack_dict(
+            inputs, LABEL_NAMES
+        )
+        labels_w, labels_l = labels_w[..., 1:], labels_l[..., 1:]
+
+        with torch.no_grad():
+            ref_logits_w = self.ref_model(input_ids=input_ids_w, attention_mask=attention_mask_w).logits[..., :-1, :]
+            ref_logits_l = self.ref_model(input_ids=input_ids_l, attention_mask=attention_mask_l).logits[..., :-1, :]
+            ref_logprobs_w = -F.cross_entropy(ref_logits_w.transpose(-1, -2), labels_w, reduction="none").sum(-1)  # negate NLLs to get log-probs
+            ref_logprobs_l = -F.cross_entropy(ref_logits_l.transpose(-1, -2), labels_l, reduction="none").sum(-1)
+
+        logits_w = model(input_ids=input_ids_w, attention_mask=attention_mask_w).logits[..., :-1, :]
+        logits_l = model(input_ids=input_ids_l, attention_mask=attention_mask_l).logits[..., :-1, :]
+        logprobs_w = -F.cross_entropy(logits_w.transpose(-1, -2), labels_w, reduction="none").sum(-1)
+        logprobs_l = -F.cross_entropy(logits_l.transpose(-1, -2), labels_l, reduction="none").sum(-1)
+
+        logits = self.args.beta * ((logprobs_w - ref_logprobs_w) - (logprobs_l - ref_logprobs_l))
+        loss = -F.logsigmoid(logits).mean(0)
+        return (loss, dict(logits=logits)) if return_outputs else loss
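
For reference, below is a minimal, self-contained sketch of the objective that `compute_loss` implements: per-sequence log-probabilities are obtained by summing per-token log-likelihoods over the response tokens, and the DPO loss is a logistic loss on the scaled difference of policy-versus-reference margins. The helper names (`sequence_logprob`, `dpo_loss`), the toy tensors, and the assumption that `constants.IGNORE_INDEX` equals -100 are illustrative only and are not part of this patch.

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed to match constants.IGNORE_INDEX used by the collator


def sequence_logprob(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Sum of per-token log-probabilities of `labels` under `logits`.

    `logits` has shape (batch, seq_len, vocab) and is already shifted so that
    logits[:, t] predicts labels[:, t]; positions labeled IGNORE_INDEX are skipped.
    """
    # cross_entropy returns negative log-likelihoods (zero at ignored positions), so negate the sum.
    nll = F.cross_entropy(logits.transpose(-1, -2), labels, ignore_index=IGNORE_INDEX, reduction="none")
    return -nll.sum(-1)


def dpo_loss(policy_logprob_w, policy_logprob_l, ref_logprob_w, ref_logprob_l, beta=0.1):
    """Direct Preference Optimization loss for a batch of (winning, losing) response pairs."""
    margin = beta * ((policy_logprob_w - ref_logprob_w) - (policy_logprob_l - ref_logprob_l))
    return -F.logsigmoid(margin).mean()


if __name__ == "__main__":
    batch, seq_len, vocab = 2, 6, 11
    labels = torch.randint(0, vocab, (batch, seq_len))
    labels[:, :2] = IGNORE_INDEX  # mask prompt tokens, as the SFT-style preprocessing does
    policy_w = sequence_logprob(torch.randn(batch, seq_len, vocab), labels)
    policy_l = sequence_logprob(torch.randn(batch, seq_len, vocab), labels)
    ref_w = sequence_logprob(torch.randn(batch, seq_len, vocab), labels)
    ref_l = sequence_logprob(torch.randn(batch, seq_len, vocab), labels)
    print(dpo_loss(policy_w, policy_l, ref_w, ref_l, beta=0.1).item())

Here `beta` plays the role of the `beta` training argument: it scales the implicit reward margins log pi(y|x) - log pi_ref(y|x), so larger values penalize deviation from the reference model more strongly.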