Commit: DPO (#81)

* chore: scaffold for dpo

* chore: scaffolding for data

* chore: more scaffolding for data

* chore: fix scaffolding for data

* chore: done dpo collator

* chore: done make data module

* chore: fix trainer

* fix: specify label names

* fix: eval_size

* chore: add trainer

* chore: doc efficient impl

* chore: doc efficient impl clarify

* fix: ref model sharing

* fix: simplify

* fix: add shebang

* fix: add dpo launcher

* fix: tweak hparams

* fix: bug in eval
lxuechen authored Nov 24, 2023
1 parent 9226444 commit 43333c7
Showing 10 changed files with 340 additions and 4 deletions.
138 changes: 138 additions & 0 deletions examples/dpo.py
@@ -0,0 +1,138 @@
import os
import pathlib
from dataclasses import dataclass, field
from typing import List, Literal

import pandas as pd
import transformers

from alpaca_farm import common, constants, data_utils, logging, utils
from alpaca_farm.rl.dpo_trainer import Trainer

pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", None)
pd.set_option("display.max_colwidth", None)

logger = logging.get_logger(__name__)


@dataclass
class ModelArguments:
model_name_or_path: str = field(
default=None, metadata={"help": "Name of a huggingface native pretrained model or path to a model on disk."}
)


@dataclass
class DataArguments:
dataset_path: str = field(default="tatsu-lab/alpaca_farm")
dataset_name: Literal["alpaca_human_preference", "alpaca_gpt4_preference", "alpaca_noisy_multi_preference"] = field(
default="alpaca_noisy_multi_preference",
metadata={"help": "Name of the dataset. Fetches the human or GPT-4 preference data."},
)
eval_size: int = field(
default=500,
metadata={"help": "Number of examples to split out from training to use for evaluation."},
)
prompt_dict_path: str = field(
default=pathlib.Path(__file__).parent / "prompts" / "v0_inputs_noinputs.json",
metadata={"help": "Path to the dictionary for the prompt to format examples."},
)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
pad_token: str = field(default=constants.DEFAULT_PAD_TOKEN)
cache_dir: str = field(default=constants.DEFAULT_CACHE_DIR)
wandb_project: str = field(default=constants.WANDB_PROJECT)
flash_attn: bool = field(default=False)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=512,
metadata={
"help": "Maximum sequence length. Sequences will be right padded to this length (and possibly truncated)."
"Enforcing a consistent max length ensures memory usage is constant and predictable."
},
)
padding: Literal["max_length", "longest"] = field(
default="longest",
metadata={
"help": "Padding strategy. If 'max_length', pads to `model_max_length` always; this might lead to some "
"redundant compute. If 'longest', pads to the longest sequence in the batch, capped by `model_max_length`."
},
)
resume_from_checkpoint: bool = field(default=False, metadata={"help": "If True, loads from the last checkpoint."})
use_fast_tokenizer: bool = field(
default=False,
metadata={
"help": "Use fast tokenizer if True. "
"Fast LLaMA tokenizer forces protobuf downgrade to 3.20.3. "
"Use fast tokenizer only if you can live with that."
},
)
beta: float = field(default=1e-1, metadata={"help": "DPO beta: coefficient scaling the implicit KL penalty against the reference model."})


def main():
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
os.environ["WANDB_PROJECT"] = training_args.wandb_project

# Load model on CPU to prevent upfront OOM.
model: transformers.PreTrainedModel = common.make_generative_lm(
model_name_or_path=model_args.model_name_or_path,
flash_attn=training_args.flash_attn,
fp16=training_args.fp16,
bf16=training_args.bf16,
config=transformers.AutoConfig.from_pretrained(model_args.model_name_or_path),
cache_dir=training_args.cache_dir,
low_cpu_mem_usage=True,
device_map=None,
)
common.let_model_save_mem_when_zero_grad(model)

tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right", # Ensures properly masking out the source tokens.
use_fast=training_args.use_fast_tokenizer,
)
tokenizer.padding = training_args.padding

# Collect special tokens. Only add if non-existent.
special_tokens_dict = dict(additional_special_tokens=[])
if tokenizer.pad_token is None:
special_tokens_dict["pad_token"] = training_args.pad_token
if tokenizer.eos_token is None:
special_tokens_dict["eos_token"] = constants.DEFAULT_EOS_TOKEN
if tokenizer.bos_token is None:
special_tokens_dict["bos_token"] = constants.DEFAULT_BOS_TOKEN
if tokenizer.unk_token is None:
special_tokens_dict["unk_token"] = constants.DEFAULT_UNK_TOKEN
utils.stable_resize_token_embeddings_and_tokenizer(model, tokenizer, special_tokens_dict)

data_module: dict = data_utils.make_dpo_data_module(
tokenizer=tokenizer,
data_args=data_args,
training_args=training_args,
)

trainer = Trainer(
model=model,
tokenizer=tokenizer,
args=training_args,
**data_module,
)

trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
logger.warning("hooray! training finished successfully! now on to model saving.", main_process_only=True)

trainer.save_state()
common.safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
logger.warning("hooray again! model saving worked.", main_process_only=True)


if __name__ == "__main__":
main()
34 changes: 34 additions & 0 deletions examples/scripts/dpo.sh
@@ -0,0 +1,34 @@
#!/bin/bash

output_dir=$1
run_name=$2
model_name_or_path=$3

torchrun --nproc_per_node=8 --master_port=1234 examples/dpo.py \
--model_name_or_path "${model_name_or_path}" \
--fp16 False \
--bf16 True \
--seed 42 \
--output_dir "${output_dir}" \
--num_train_epochs 2 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 16 \
--eval_steps 100 \
--save_strategy "steps" \
--save_steps 1000000000 \
--save_total_limit 1 \
--learning_rate 1e-6 \
--weight_decay 0.0 \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--evaluation_strategy "steps" \
--logging_steps 10 \
--wandb_project "alpaca_farm" \
--run_name "${run_name}" \
--tf32 True \
--flash_attn True \
--model_max_length 512 \
--ddp_timeout 1800 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap "LlamaDecoderLayer"
2 changes: 2 additions & 0 deletions examples/scripts/expiter.sh
@@ -1,3 +1,5 @@
#!/bin/bash

output_dir=$1
run_name=$2
model_name_or_path=$3
2 changes: 2 additions & 0 deletions examples/scripts/reward_modeling.sh
@@ -1,3 +1,5 @@
#!/bin/bash

output_dir=$1
run_name=$2
model_name_or_path=$3
2 changes: 2 additions & 0 deletions examples/scripts/rlhf_ppo.sh
@@ -1,3 +1,5 @@
#!/bin/bash

output_dir=$1
run_name=$2
reward_model_name_or_path=$3
2 changes: 2 additions & 0 deletions examples/scripts/rlhf_quark.sh
@@ -1,3 +1,5 @@
#!/bin/bash

output_dir=$1
run_name=$2
reward_model_name_or_path=$3
2 changes: 2 additions & 0 deletions examples/scripts/sft.sh
@@ -1,3 +1,5 @@
#!/bin/bash

output_dir=$1
run_name=$2
model_name_or_path=$3
79 changes: 78 additions & 1 deletion src/alpaca_farm/data_preprocessor.py
@@ -14,9 +14,10 @@

import copy
import dataclasses
from typing import Callable, Dict, Optional, Sequence, Union
from typing import Callable, Dict, List, Optional, Sequence, Union

import einops
import numpy as np
import pandas as pd
import torch
import transformers
@@ -270,6 +271,32 @@ def _merge_tokenization_metadata(metadata_list: Sequence[dict]) -> dict:
return packaged_data


def preprocess_for_dpo(
df: pd.DataFrame,
prompt_dict: dict,
tokenizer: transformers.PreTrainedTokenizer,
df_postprocessor=None,
verbose=True,
) -> dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]]:
output_1, output_2, preference = df["output_1"], df["output_2"], df["preference"]

df_w = df.assign(output=np.where(preference == 1, output_1, output_2))[["instruction", "input", "output"]]
df_l = df.assign(output=np.where(preference == 2, output_1, output_2))[["instruction", "input", "output"]]

tensors_w = preprocess_for_sft(
df=df_w, prompt_dict=prompt_dict, tokenizer=tokenizer, df_postprocessor=df_postprocessor, verbose=verbose
)
tensors_l = preprocess_for_sft(
df=df_l, prompt_dict=prompt_dict, tokenizer=tokenizer, df_postprocessor=df_postprocessor, verbose=verbose
)
return dict(
input_ids_w=tensors_w["input_ids"],
labels_w=tensors_w["labels"],
input_ids_l=tensors_l["input_ids"],
labels_l=tensors_l["labels"],
)
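
# ---------------------------------------------------------------------------
# Illustrative note (not part of this commit): on a toy preference row, the
# np.where selection above splits each pair into a preferred ("w") and a
# dispreferred ("l") SFT-style example, e.g.
#
#   row: {"instruction": "Say hi", "input": "", "output_1": "Hi!",
#         "output_2": "Hello there.", "preference": 2}
#   -> df_w gets output="Hello there."   (tokenized into input_ids_w / labels_w)
#   -> df_l gets output="Hi!"            (tokenized into input_ids_l / labels_l)
# ---------------------------------------------------------------------------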


def _get_generator(seed: int) -> torch.Generator:
rng = torch.Generator()
rng.manual_seed(seed)
@@ -526,3 +553,53 @@ def __len__(self):
class DataCollatorForStackableDataset(object):
def __call__(self, instances: Sequence[Dict]) -> Dict[str, Tensor]:
return {key: torch.stack([instance[key] for instance in instances]) for key in instances[0].keys()}


class DPODataset(Dataset):
def __init__(
self,
df: pd.DataFrame,
prompt_dict: dict,
tokenizer: transformers.PreTrainedTokenizer,
df_postprocessor: Optional[Callable] = None,
):
super(DPODataset, self).__init__()
self.tensors = preprocess_for_dpo(
df=df, prompt_dict=prompt_dict, tokenizer=tokenizer, df_postprocessor=df_postprocessor
)

def __len__(self):
return len(next(iter(self.tensors.values())))

def __getitem__(self, i) -> Dict[str, Tensor]:
return {key: value[i] for key, value in self.tensors.items()}


@dataclasses.dataclass
class DataCollatorForDPODataset(object):
tokenizer: transformers.PreTrainedTokenizer

def _pad_input_ids_and_labels(self, input_ids: List[Tensor], labels: List[Tensor]) -> tuple[Tensor, Tensor, Tensor]:
# This is the same thing as done in `DataCollatorForSFTDataset`; repeated here for readability.
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=constants.IGNORE_INDEX)
# When sequences are right padded, `attention_mask` is only useful for T5 training.
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
return input_ids, labels, attention_mask

def __call__(self, instances: Sequence[Dict]) -> Dict[str, Tensor]:
input_ids_w, labels_w, input_ids_l, labels_l = tuple(
[instance[key] for instance in instances] for key in ("input_ids_w", "labels_w", "input_ids_l", "labels_l")
)
input_ids_w, labels_w, attention_mask_w = self._pad_input_ids_and_labels(input_ids_w, labels_w)
input_ids_l, labels_l, attention_mask_l = self._pad_input_ids_and_labels(input_ids_l, labels_l)
return dict(
input_ids_w=input_ids_w,
labels_w=labels_w,
attention_mask_w=attention_mask_w,
input_ids_l=input_ids_l,
labels_l=labels_l,
attention_mask_l=attention_mask_l,
)
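
Once collated, each half of the batch (the *_w and *_l triples) is a right-padded causal-LM batch whose labels, per the SFT preprocessing this reuses, mark prompt and padding positions with IGNORE_INDEX. As a minimal sketch of how per-sequence response log-probabilities, the quantity a DPO loss consumes, could be read off such a batch: this assumes labels are aligned with input_ids in the usual causal-LM convention and that constants.IGNORE_INDEX is the Hugging Face default of -100; the names below are illustrative and not part of this commit.

import torch

IGNORE_INDEX = -100  # assumed value of constants.IGNORE_INDEX

def sequence_logps_sketch(model, input_ids, labels, attention_mask):
    """Sum of token log-probs over positions whose label is not IGNORE_INDEX."""
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # Standard causal shift: logits at position t predict the token at position t + 1.
    logits = logits[:, :-1, :]
    targets = labels[:, 1:]
    mask = targets.ne(IGNORE_INDEX)
    safe_targets = targets.masked_fill(~mask, 0)  # avoid gathering the -100 sentinel
    token_logps = torch.log_softmax(logits, dim=-1).gather(-1, safe_targets.unsqueeze(-1)).squeeze(-1)
    return (token_logps * mask).sum(dim=-1)  # shape (batch,): one log-prob per sequence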
32 changes: 29 additions & 3 deletions src/alpaca_farm/data_utils.py
@@ -21,8 +21,10 @@
from .data_preprocessor import (
BinaryRewardModelingDataset,
DataCollatorForBinaryRewardModelingDataset,
DataCollatorForDPODataset,
DataCollatorForSFTDataset,
DataCollatorForStackableDataset,
DPODataset,
QueryDataset,
SFTDataset,
split_train_into_train_and_eval,
@@ -35,7 +37,7 @@ def make_supervised_data_module(
tokenizer: transformers.PreTrainedTokenizer,
data_args,
training_args,
):
) -> dict:
prompt_dict = utils.jload(data_args.prompt_dict_path)

alpaca_instructions = datasets.load_dataset(data_args.dataset_path, data_args.dataset_name)
@@ -72,7 +74,7 @@ def make_binary_reward_modeling_data_module(
tokenizer: transformers.PreTrainedTokenizer,
data_args,
training_args,
):
) -> dict:
prompt_dict = utils.jload(data_args.prompt_dict_path)

alpaca_human_preference = datasets.load_dataset(data_args.dataset_path, data_args.dataset_name)
@@ -97,7 +99,7 @@ def make_rl_data_module(
tokenizer: transformers.PreTrainedTokenizer,
data_args,
training_args,
):
) -> dict:
prompt_dict = utils.jload(data_args.prompt_dict_path)

alpaca_instructions = datasets.load_dataset(data_args.dataset_path, data_args.dataset_name)
@@ -126,3 +128,27 @@ def make_rl_data_module(
prompt_postprocessor=prompt_postprocessor,
)
return dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=DataCollatorForStackableDataset())


def make_dpo_data_module(
tokenizer: transformers.PreTrainedTokenizer,
data_args,
training_args,
) -> dict:
prompt_dict = utils.jload(data_args.prompt_dict_path)

alpaca_human_preference = datasets.load_dataset(data_args.dataset_path, data_args.dataset_name)
train_df = pd.DataFrame(alpaca_human_preference["preference"])

train_dataset = DPODataset(
df=train_df,
prompt_dict=prompt_dict,
tokenizer=tokenizer,
)
train_dataset, eval_dataset = split_train_into_train_and_eval(
train_dataset=train_dataset,
eval_size=data_args.eval_size,
seed=training_args.seed,
)
data_collator = DataCollatorForDPODataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset, eval_dataset=eval_dataset, data_collator=data_collator)
[The diff for the tenth changed file (the remaining 51 additions) did not load on this page.]
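
The trainer that consumes these batches (imported in examples/dpo.py as alpaca_farm.rl.dpo_trainer.Trainer) is in the file whose diff did not load above. For orientation only, here is a sketch of the standard DPO objective such a trainer would optimize on (preferred, dispreferred) pairs; the function and argument names are illustrative and are not taken from this commit.

import torch.nn.functional as F

def dpo_loss_sketch(policy_logps_w, policy_logps_l, ref_logps_w, ref_logps_l, beta=0.1):
    # Per-example DPO loss: -log sigmoid(beta * [(log pi_w - log ref_w) - (log pi_l - log ref_l)]).
    # `beta` plays the role of TrainingArguments.beta (default 1e-1 in examples/dpo.py); each
    # argument is a tensor of per-sequence log-probabilities, e.g. as computed by the
    # log-prob sketch following the data_preprocessor.py diff above.
    logits = (policy_logps_w - ref_logps_w) - (policy_logps_l - ref_logps_l)
    return -F.logsigmoid(beta * logits).mean()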
