
Commit
new RLHF benchmark (#273)
* new RLHF benchmark

* Add RLHF config to standard

---------

Co-authored-by: pierre.delaunay <[email protected]>
Delaunay and pierre.delaunay authored Sep 10, 2024
1 parent e327768 commit 5672f16
Showing 13 changed files with 735 additions and 6 deletions.
31 changes: 31 additions & 0 deletions benchmarks/rlhf/Makefile
@@ -0,0 +1,31 @@
# Use global base if possible
ifndef MILABENCH_BASE
	MILABENCH_BASE="base"
endif

export MILABENCH_BASE

BENCH_NAME=rlhf
MILABENCH_CONFIG=dev.yaml
MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE)

all: install prepare single gpus nodes

install:
	milabench install $(MILABENCH_ARGS) --force

prepare:
	milabench prepare $(MILABENCH_ARGS)

tests: install prepare
	milabench run $(MILABENCH_ARGS)

single:
	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-single

gpus:
	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus

nodes:
	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-nodes
4 changes: 4 additions & 0 deletions benchmarks/rlhf/README.md
@@ -0,0 +1,4 @@

# Rlhf

Rewrite this README to explain what the benchmark is!
41 changes: 41 additions & 0 deletions benchmarks/rlhf/benchfile.py
@@ -0,0 +1,41 @@
from milabench.pack import Package


class Rlhf(Package):
    # Requirements file installed by install(). It can be empty or absent.
    base_requirements = "requirements.in"

    # The preparation script called by prepare(). It must be executable,
    # but it can be any type of script. It can be empty or absent.
    prepare_script = "prepare.py"

    # The main script called by run(). It must be a Python file. It has to
    # be present.
    main_script = "main.py"

    # You can remove the functions below if you don't need to modify them.

    def make_env(self):
        # Return a dict of environment variables for prepare_script and
        # main_script.
        return super().make_env()

    async def install(self):
        await super().install()  # super() call installs the requirements

    async def prepare(self):
        await super().prepare()  # super() call executes prepare_script

    def build_run_plan(self):
        from milabench.commands import PackCommand, AccelerateAllNodes

        main = self.dirs.code / self.main_script
        plan = PackCommand(self, *self.argv, lazy=True)

        if False:
            plan = VoirCommand(plan, cwd=main.parent)

        return AccelerateAllNodes(plan).use_stdout()


__pack__ = Rlhf
29 changes: 29 additions & 0 deletions benchmarks/rlhf/dev.yaml
@@ -0,0 +1,29 @@

rlhf_:
  inherits: _defaults
  definition: .
  install-variant: unpinned
  install_group: torch
  plan:
    method: per_gpu

  argv:
    --output_dir: "{milabench_extra}/output"
    --model_name_or_path: EleutherAI/pythia-1b-deduped
    --per_device_train_batch_size: 64
    --logging_strategy: "no"
    --log_level: "critical"
    --bf16: true


rlhf-single:
  inherits: rlhf_
  plan:
    method: per_gpu


rlhf-gpus:
  inherits: rlhf_
  plan:
    method: njobs
    n: 1
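
The argv mapping above is ultimately passed as command-line flags to main.py, where HfArgumentParser parses them into the PPOv2Config and ModelConfig dataclasses. A minimal sketch of that round trip follows; the exact field names depend on the installed trl/transformers versions, and the output path here is a placeholder:

    # Sketch only: how dev.yaml's argv entries reach the training script's dataclasses.
    from transformers import HfArgumentParser
    from trl import ModelConfig
    from trl.trainer.ppov2_trainer import PPOv2Config

    parser = HfArgumentParser((PPOv2Config, ModelConfig))
    config, model_config = parser.parse_args_into_dataclasses(args=[
        "--output_dir", "/tmp/rlhf-output",                      # hypothetical path
        "--model_name_or_path", "EleutherAI/pythia-1b-deduped",
        "--per_device_train_batch_size", "64",
        "--bf16", "true",
    ])
    assert config.per_device_train_batch_size == 64
    assert model_config.model_name_or_path == "EleutherAI/pythia-1b-deduped"
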
136 changes: 136 additions & 0 deletions benchmarks/rlhf/main.py
@@ -0,0 +1,136 @@
#!/usr/bin/env python

import shutil

from accelerate import PartialState
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
)

from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE


class PPOv2TrainerIntrumented(PPOv2Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def batch_size_fn(batch):
            x, y = batch['input_ids'].shape
            return x * y

        from benchmate.observer import BenchObserver
        observer = BenchObserver(
            batch_size_fn=batch_size_fn,
            earlystop=70,
            raise_stop_program=True,
            stdout=True,
        )

        self.dataloader = observer.iterate(self.dataloader)

    def generate_completions(self, sampling: bool = False):
        pass

    def _save_checkpoint(self, *args, **kwargs):
        pass

    def save_model(self, *args, **kwargs):
        pass


def main():
    parser = HfArgumentParser((PPOv2Config, ModelConfig))
    config, model_config = parser.parse_args_into_dataclasses()
    # remove output_dir if it exists
    shutil.rmtree(config.output_dir, ignore_errors=True)

    ################
    # Model & Tokenizer
    ################
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        padding_side="left",
        trust_remote_code=model_config.trust_remote_code,
    )
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
    value_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
    )
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
    )
    ref_policy = AutoModelForCausalLM.from_pretrained(
        config.sft_model_path, trust_remote_code=model_config.trust_remote_code
    )
    policy = AutoModelForCausalLM.from_pretrained(
        config.sft_model_path, trust_remote_code=model_config.trust_remote_code
    )
    ################
    # Dataset
    ################
    raw_datasets = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
    eval_samples = 20
    train_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples))
    eval_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples, len(raw_datasets)))
    dataset_text_field = "prompt"

    def prepare_dataset(dataset, tokenizer):
        """pre-tokenize the dataset before training; only collate during training"""

        def tokenize(element):
            outputs = tokenizer(
                element[dataset_text_field],
                padding=False,
            )
            return {"input_ids": outputs["input_ids"]}

        return dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names,
            num_proc=config.dataset_num_proc,
        )

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        train_dataset = prepare_dataset(train_dataset, tokenizer)
        eval_dataset = prepare_dataset(eval_dataset, tokenizer)

    ################
    # Training
    ################
    trainer = PPOv2TrainerIntrumented(
        config=config,
        tokenizer=tokenizer,
        policy=policy,
        ref_policy=ref_policy,
        reward_model=reward_model,
        value_model=value_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    trainer.save_model(config.output_dir)
    if config.push_to_hub:
        trainer.push_to_hub()
    trainer.generate_completions()


if __name__ == "__main__":
    from voir.phase import StopProgram
    from benchmate.monitor import bench_monitor

    try:
        with bench_monitor():
            main()
    except StopProgram:
        pass
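
The instrumentation in PPOv2TrainerIntrumented hinges on benchmate's BenchObserver wrapping the trainer's dataloader: batch_size_fn reports batch rows times sequence length (tokens per batch), and after roughly 70 steps the observer ends the run by raising StopProgram, which the __main__ guard above swallows. A rough, hypothetical sketch of that wrapping pattern (illustrative only, not the actual benchmate implementation):

    # Illustrative sketch of an observer-style dataloader wrapper.
    class ToyObserver:
        def __init__(self, batch_size_fn, earlystop):
            self.batch_size_fn = batch_size_fn  # e.g. rows * sequence length per batch
            self.earlystop = earlystop          # number of batches before stopping
            self.items_seen = 0

        def iterate(self, loader):
            for step, batch in enumerate(loader):
                self.items_seen += self.batch_size_fn(batch)  # accumulate throughput units
                yield batch
                if step + 1 >= self.earlystop:
                    break  # the real observer raises a StopProgram-style exception here
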
54 changes: 54 additions & 0 deletions benchmarks/rlhf/prepare.py
@@ -0,0 +1,54 @@
#!/usr/bin/env python

import shutil

from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
)
from datasets import load_dataset
from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE


if __name__ == "__main__":
    parser = HfArgumentParser((PPOv2Config, ModelConfig))
    config, model_config = parser.parse_args_into_dataclasses()

    # remove output_dir if it exists
    shutil.rmtree(config.output_dir, ignore_errors=True)

    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        padding_side="left",
        trust_remote_code=model_config.trust_remote_code,
    )

    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE

    value_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_path,
        trust_remote_code=model_config.trust_remote_code,
        num_labels=1
    )
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        config.reward_model_path,
        trust_remote_code=model_config.trust_remote_code,
        num_labels=1
    )
    ref_policy = AutoModelForCausalLM.from_pretrained(
        config.sft_model_path,
        trust_remote_code=model_config.trust_remote_code
    )
    policy = AutoModelForCausalLM.from_pretrained(
        config.sft_model_path,
        trust_remote_code=model_config.trust_remote_code
    )

    raw_datasets = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
