Commit message:

* chore: scaffold for dpo
* chore: scaffolding for data
* chore: more scaffolding for data
* chore: fix scaffolding for data
* chore: done dpo collator
* chore: done make data module
* chore: fix trainer
* fix: specify label names
* fix: eval_size
* chore: add trainer
* chore: doc efficient impl
* chore: doc efficient impl clarify
* fix: ref model sharing
* fix: simplify
* fix: add shebang
* fix: add dpo launcher
* fix: tweak hparams
* fix: bug in eval
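For context, DPO trains the policy directly on preference pairs against a frozen reference model; the `beta` hyperparameter introduced below scales the implicit KL penalty toward that reference. A minimal sketch of the standard DPO objective in PyTorch (illustrative names only, not the actual code in `alpaca_farm.rl.dpo_trainer`):

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Pairwise DPO loss.

    Each argument is a (batch,) tensor of per-sequence log-probabilities.
    `beta` plays the same role as the `beta` training argument in this commit:
    it scales the implicit KL penalty against the frozen reference policy.
    """
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # Maximize the margin between chosen and rejected implicit rewards.
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()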
Showing 10 changed files with 340 additions and 4 deletions.
examples/dpo.py (new file, +138 lines):
import os
import pathlib
from dataclasses import dataclass, field
from typing import List, Literal

import pandas as pd
import transformers

from alpaca_farm import common, constants, data_utils, logging, utils
from alpaca_farm.rl.dpo_trainer import Trainer

pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", None)
pd.set_option("display.max_colwidth", None)

logger = logging.get_logger(__name__)


@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        default=None, metadata={"help": "Name of a huggingface native pretrained model or path to a model on disk."}
    )


@dataclass
class DataArguments:
    dataset_path: str = field(default="tatsu-lab/alpaca_farm")
    dataset_name: Literal["alpaca_human_preference", "alpaca_gpt4_preference", "alpaca_noisy_multi_preference"] = field(
        default="alpaca_noisy_multi_preference",
        metadata={"help": "Name of the dataset. Fetches the human or GPT-4 preference data."},
    )
    eval_size: int = field(
        default=500,
        metadata={"help": "Number of examples to split out from training to use for evaluation."},
    )
    prompt_dict_path: str = field(
        default=pathlib.Path(__file__).parent / "prompts" / "v0_inputs_noinputs.json",
        metadata={"help": "Path to the dictionary for the prompt to format examples."},
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    pad_token: str = field(default=constants.DEFAULT_PAD_TOKEN)
    cache_dir: str = field(default=constants.DEFAULT_CACHE_DIR)
    wandb_project: str = field(default=constants.WANDB_PROJECT)
    flash_attn: bool = field(default=False)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded to this length (and possibly truncated). "
            "Enforcing a consistent max length ensures memory usage is constant and predictable."
        },
    )
    padding: Literal["max_length", "longest"] = field(
        default="longest",
        metadata={
            "help": "Padding strategy. If 'max_length', pads to `model_max_length` always; this might lead to some "
            "redundant compute. If 'longest', pads to the longest sequence in the batch, capped by `model_max_length`."
        },
    )
    resume_from_checkpoint: bool = field(default=False, metadata={"help": "If True, loads from the last checkpoint."})
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={
            "help": "Use fast tokenizer if True. "
            "Fast LLaMA tokenizer forces protobuf downgrade to 3.20.3. "
            "Use fast tokenizer only if you can live with that."
        },
    )
    beta: float = field(default=1e-1, metadata={"help": "Beta for the KL divergence."})


def main():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    os.environ["WANDB_PROJECT"] = training_args.wandb_project

    # Load model on CPU to prevent upfront OOM.
    model: transformers.PreTrainedModel = common.make_generative_lm(
        model_name_or_path=model_args.model_name_or_path,
        flash_attn=training_args.flash_attn,
        fp16=training_args.fp16,
        bf16=training_args.bf16,
        config=transformers.AutoConfig.from_pretrained(model_args.model_name_or_path),
        cache_dir=training_args.cache_dir,
        low_cpu_mem_usage=True,
        device_map=None,
    )
    common.let_model_save_mem_when_zero_grad(model)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",  # Ensures properly masking out the source tokens.
        use_fast=training_args.use_fast_tokenizer,
    )
    tokenizer.padding = training_args.padding

    # Collect special tokens. Only add if non-existent.
    special_tokens_dict = dict(additional_special_tokens=[])
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = training_args.pad_token
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = constants.DEFAULT_EOS_TOKEN
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = constants.DEFAULT_BOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = constants.DEFAULT_UNK_TOKEN
    utils.stable_resize_token_embeddings_and_tokenizer(model, tokenizer, special_tokens_dict)

    data_module: dict = data_utils.make_dpo_data_module(
        tokenizer=tokenizer,
        data_args=data_args,
        training_args=training_args,
    )

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **data_module,
    )

    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    logger.warning("hooray! training finished successfully! now on to model saving.", main_process_only=True)

    trainer.save_state()
    common.safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
    logger.warning("hooray again! model saving worked.", main_process_only=True)


if __name__ == "__main__":
    main()
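The trainer needs per-sequence log-probabilities of the chosen and rejected responses from both the policy and the shared reference model to evaluate the DPO loss. A common way to compute them from causal-LM logits, assuming the collator marks prompt and padding positions with -100 in the labels (an assumption about the data module, not something shown in this commit):

import torch

def sequence_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Sum of token log-probabilities per sequence.

    logits: (batch, seq_len, vocab); labels: (batch, seq_len) with -100 on
    positions that should not contribute (prompt and padding).
    """
    # Shift so that tokens < n predict token n (standard causal-LM alignment).
    logits = logits[:, :-1, :]
    labels = labels[:, 1:].clone()
    mask = labels.ne(-100)
    labels[~mask] = 0  # any valid index; masked out below
    per_token = torch.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return (per_token * mask).sum(dim=-1)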
DPO launcher script (new file, +34 lines):
#!/bin/bash

output_dir=$1
run_name=$2
model_name_or_path=$3

torchrun --nproc_per_node=8 --master_port=1234 examples/dpo.py \
  --model_name_or_path "${model_name_or_path}" \
  --fp16 False \
  --bf16 True \
  --seed 42 \
  --output_dir "${output_dir}" \
  --num_train_epochs 2 \
  --per_device_train_batch_size 1 \
  --per_device_eval_batch_size 4 \
  --gradient_accumulation_steps 16 \
  --eval_steps 100 \
  --save_strategy "steps" \
  --save_steps 1000000000 \
  --save_total_limit 1 \
  --learning_rate 1e-6 \
  --weight_decay 0.0 \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --evaluation_strategy "steps" \
  --logging_steps 10 \
  --wandb_project "alpaca_farm" \
  --run_name "${run_name}" \
  --tf32 True \
  --flash_attn True \
  --model_max_length 512 \
  --ddp_timeout 1800 \
  --fsdp "full_shard auto_wrap" \
  --fsdp_transformer_layer_cls_to_wrap "LlamaDecoderLayer"
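With these flags the effective batch size works out to 8 processes × 1 pair per device × 16 gradient-accumulation steps = 128 preference pairs per optimizer step, and `--save_steps 1000000000` effectively disables intermediate checkpoints, so only the final model is written at the end of training.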
One of the existing scripts, @@ -1,3 +1,5 @@ (shebang and blank line added at the top):

#!/bin/bash

output_dir=$1
run_name=$2
model_name_or_path=$3
Another existing script, @@ -1,3 +1,5 @@ (same shebang fix):

#!/bin/bash

output_dir=$1
run_name=$2
model_name_or_path=$3
Another existing script, @@ -1,3 +1,5 @@ (same shebang fix):

#!/bin/bash

output_dir=$1
run_name=$2
reward_model_name_or_path=$3
Another existing script, @@ -1,3 +1,5 @@ (same shebang fix):

#!/bin/bash

output_dir=$1
run_name=$2
reward_model_name_or_path=$3
Another existing script, @@ -1,3 +1,5 @@ (same shebang fix):

#!/bin/bash

output_dir=$1
run_name=$2
model_name_or_path=$3
(Diffs for the remaining changed files did not load.)