From 5672f16f8d6cf6d6f48df6a21661f3e05aaa17be Mon Sep 17 00:00:00 2001
From: Setepenre <pierre.delaunay.tr@gmail.com>
Date: Tue, 10 Sep 2024 10:50:26 -0400
Subject: [PATCH] new RLHF benchmark (#273)

* new RLHF benchmark

* Add RLHF config to standard

---------

Co-authored-by: pierre.delaunay <delaunap@rtx5.server.mila.quebec>
---
 benchmarks/rlhf/Makefile                      |  30 ++
 benchmarks/rlhf/README.md                     |  17 ++
 benchmarks/rlhf/benchfile.py                  |  41 +++
 benchmarks/rlhf/dev.yaml                      |  31 ++
 benchmarks/rlhf/main.py                       | 140 +++++++++
 benchmarks/rlhf/prepare.py                    |  56 ++++
 benchmarks/rlhf/requirements.cuda.txt         | 283 ++++++++++++++++++
 benchmarks/rlhf/requirements.in               |   6 +
 benchmarks/rlhf/voirfile.py                   |  38 +++
 config/base.yaml                              |  32 +-
 scripts/article/run_cuda.sh                   |   9 +-
 .../test_command_reg_one_node.txt             |  38 +++
 .../test_command_reg_two_nodes.txt            |  38 +++
 13 files changed, 753 insertions(+), 6 deletions(-)
 create mode 100644 benchmarks/rlhf/Makefile
 create mode 100644 benchmarks/rlhf/README.md
 create mode 100644 benchmarks/rlhf/benchfile.py
 create mode 100644 benchmarks/rlhf/dev.yaml
 create mode 100755 benchmarks/rlhf/main.py
 create mode 100755 benchmarks/rlhf/prepare.py
 create mode 100644 benchmarks/rlhf/requirements.cuda.txt
 create mode 100644 benchmarks/rlhf/requirements.in
 create mode 100644 benchmarks/rlhf/voirfile.py

diff --git a/benchmarks/rlhf/Makefile b/benchmarks/rlhf/Makefile
new file mode 100644
index 000000000..48f4da4a1
--- /dev/null
+++ b/benchmarks/rlhf/Makefile
@@ -0,0 +1,30 @@
+# Use global base if possible
+ifndef MILABENCH_BASE
+	MILABENCH_BASE="base"
+endif
+
+export MILABENCH_BASE
+
+BENCH_NAME=rlhf
+MILABENCH_CONFIG=dev.yaml
+MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE)
+
+all: install prepare single gpus nodes
+
+install:
+	milabench install $(MILABENCH_ARGS) --force
+
+prepare:
+	milabench prepare $(MILABENCH_ARGS)
+
+tests: install prepare
+	milabench run $(MILABENCH_ARGS)
+
+single:
+	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-single
+
+gpus:
+	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus
+
+nodes:
+	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-nodes
diff --git a/benchmarks/rlhf/README.md b/benchmarks/rlhf/README.md
new file mode 100644
index 000000000..9c22d45ca
--- /dev/null
+++ b/benchmarks/rlhf/README.md
@@ -0,0 +1,17 @@
+
+# RLHF
+
+Benchmarks RLHF training throughput using TRL's `PPOv2Trainer`: a policy
+model is optimized with PPO against a reward model on the
+`trl-internal-testing/descriptiveness-sentiment-trl-style` dataset.
+`prepare.py` downloads the models and the dataset ahead of the run.
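+
+A typical local workflow, assuming milabench is installed and this
+directory's `dev.yaml` is used as the config (the Makefile wraps the same
+commands):
+
+```bash
+milabench install --config dev.yaml --base base --force
+milabench prepare --config dev.yaml --base base
+milabench run --config dev.yaml --base base --select rlhf-single
+```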
diff --git a/benchmarks/rlhf/benchfile.py b/benchmarks/rlhf/benchfile.py
new file mode 100644
index 000000000..0cc83fbad
--- /dev/null
+++ b/benchmarks/rlhf/benchfile.py
@@ -0,0 +1,41 @@
+from milabench.pack import Package
+
+
+class Rlhf(Package):
+    # Requirements file installed by install(). It can be empty or absent.
+    base_requirements = "requirements.in"
+
+    # The preparation script called by prepare(). It must be executable,
+    # but it can be any type of script. It can be empty or absent.
+    prepare_script = "prepare.py"
+
+    # The main script called by run(). It must be a Python file. It has to
+    # be present.
+    main_script = "main.py"
+
+    # You can remove the functions below if you don't need to modify them.
+
+    def make_env(self):
+        # Return a dict of environment variables for prepare_script and
+        # main_script.
+        return super().make_env()
+
+    async def install(self):
+        await super().install()  # super() call installs the requirements
+
+    async def prepare(self):
+        await super().prepare()  # super() call executes prepare_script
+
+    def build_run_plan(self):
+        from milabench.commands import PackCommand, AccelerateAllNodes
+
+        plan = PackCommand(self, *self.argv, lazy=True)
+
+        # Voir instrumentation is intentionally disabled here: the
+        # dataloader is instrumented directly in main.py through
+        # benchmate's BenchObserver instead.
+
+        return AccelerateAllNodes(plan).use_stdout()
+
+
+__pack__ = Rlhf
diff --git a/benchmarks/rlhf/dev.yaml b/benchmarks/rlhf/dev.yaml
new file mode 100644
index 000000000..32893233a
--- /dev/null
+++ b/benchmarks/rlhf/dev.yaml
@@ -0,0 +1,31 @@
+
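+# Standalone config for developing this benchmark locally; full milabench
+# runs use the rlhf entries added to config/base.yaml.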
+rlhf_:
+  inherits: _defaults
+  definition: .
+  install-variant: unpinned
+  install_group: torch
+  plan:
+    method: per_gpu
+
+  argv:
+    --output_dir: "{milabench_extra}/output"
+    --model_name_or_path: EleutherAI/pythia-1b-deduped
+    --per_device_train_batch_size: 64
+    --logging_strategy: "no"
+    --log_level: "critical"
+    --bf16: true
+
+
+rlhf-single:
+  inherits: rlhf_
+  plan:
+    method: per_gpu
+
+
+rlhf-gpus:
+  inherits: rlhf_
+  plan:
+    method: njobs
+    n: 1
diff --git a/benchmarks/rlhf/main.py b/benchmarks/rlhf/main.py
new file mode 100755
index 000000000..1218674d8
--- /dev/null
+++ b/benchmarks/rlhf/main.py
@@ -0,0 +1,140 @@
+#!/usr/bin/env python
+
+import shutil
+
+from accelerate import PartialState
+from datasets import load_dataset
+from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    HfArgumentParser,
+)
+
+from trl import ModelConfig
+from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer
+from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
+
+
+class PPOv2TrainerInstrumented(PPOv2Trainer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        def batch_size_fn(batch):
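+            # Throughput is counted in tokens: rows * padded sequence length.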
+            x, y = batch['input_ids'].shape
+            return x * y
+
+        from benchmate.observer import BenchObserver
+        observer = BenchObserver(
+            batch_size_fn=batch_size_fn,
+            earlystop=70,
+            raise_stop_program=True,
+            stdout=True,
+        )
+
+        self.dataloader = observer.iterate(self.dataloader)
+
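+    # Completion generation, checkpointing and model saving are not part of
+    # the measured training loop; stub them out so only the PPO updates are
+    # timed.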
+    def generate_completions(self, sampling: bool = False):
+        pass
+
+    def _save_checkpoint(self, *args, **kwargs):
+        pass
+
+    def save_model(self, *args, **kwargs):
+        pass
+
+
+def main():
+    parser = HfArgumentParser((PPOv2Config, ModelConfig))
+    config, model_config = parser.parse_args_into_dataclasses()
+    # remove output_dir if exists
+    shutil.rmtree(config.output_dir, ignore_errors=True)
+
+    ################
+    # Model & Tokenizer
+    ################
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_config.model_name_or_path,
+        padding_side="left",
+        trust_remote_code=model_config.trust_remote_code,
+    )
+    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+    if tokenizer.chat_template is None:
+        tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
+    value_model = AutoModelForSequenceClassification.from_pretrained(
+        config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
+    )
+    reward_model = AutoModelForSequenceClassification.from_pretrained(
+        config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
+    )
+    ref_policy = AutoModelForCausalLM.from_pretrained(
+        config.sft_model_path, trust_remote_code=model_config.trust_remote_code
+    )
+    policy = AutoModelForCausalLM.from_pretrained(
+        config.sft_model_path, trust_remote_code=model_config.trust_remote_code
+    )
+    ################
+    # Dataset
+    ################
+    raw_datasets = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
+    eval_samples = 20
+    train_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples))
+    eval_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples, len(raw_datasets)))
+    dataset_text_field = "prompt"
+
+    def prepare_dataset(dataset, tokenizer):
+        """pre-tokenize the dataset before training; only collate during training"""
+
+        def tokenize(element):
+            outputs = tokenizer(
+                element[dataset_text_field],
+                padding=False,
+            )
+            return {"input_ids": outputs["input_ids"]}
+
+        return dataset.map(
+            tokenize,
+            batched=True,
+            remove_columns=dataset.column_names,
+            num_proc=config.dataset_num_proc,
+        )
+
+    # Tokenize on the local main process first; the other processes reuse the cached result.
+    # see: https://github.com/huggingface/trl/pull/1255
+    with PartialState().local_main_process_first():
+        train_dataset = prepare_dataset(train_dataset, tokenizer)
+        eval_dataset = prepare_dataset(eval_dataset, tokenizer)
+
+    ################
+    # Training
+    ################
+    trainer = PPOv2TrainerInstrumented(
+        config=config,
+        tokenizer=tokenizer,
+        policy=policy,
+        ref_policy=ref_policy,
+        reward_model=reward_model,
+        value_model=value_model,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+    )
+    trainer.train()
+    trainer.save_model(config.output_dir)
+    if config.push_to_hub:
+        trainer.push_to_hub()
+    trainer.generate_completions()
+
+
+if __name__ == "__main__":
+    from voir.phase import StopProgram
+    from benchmate.monitor import bench_monitor
+
+    try:
+        with bench_monitor():
+            main()
+    except StopProgram:
+        pass
diff --git a/benchmarks/rlhf/prepare.py b/benchmarks/rlhf/prepare.py
new file mode 100755
index 000000000..4c9aa631f
--- /dev/null
+++ b/benchmarks/rlhf/prepare.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python
+
+import shutil
+
+from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    HfArgumentParser,
+)
+from datasets import load_dataset
+from trl import ModelConfig
+from trl.trainer.ppov2_trainer import PPOv2Config
+from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
+
+
+if __name__ == "__main__":
+    parser = HfArgumentParser((PPOv2Config, ModelConfig))
+    config, model_config = parser.parse_args_into_dataclasses()
+
+    # remove output_dir if exists
+    shutil.rmtree(config.output_dir, ignore_errors=True)
+
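+    # Instantiating the tokenizer, the models and the dataset below warms the
+    # Hugging Face cache, so the benchmark run itself does not hit the network.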
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_config.model_name_or_path,
+        padding_side="left",
+        trust_remote_code=model_config.trust_remote_code,
+    )
+
+    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+    if tokenizer.chat_template is None:
+        tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
+
+    value_model = AutoModelForSequenceClassification.from_pretrained(
+        config.reward_model_path, 
+        trust_remote_code=model_config.trust_remote_code, 
+        num_labels=1
+    )
+    reward_model = AutoModelForSequenceClassification.from_pretrained(
+        config.reward_model_path, 
+        trust_remote_code=model_config.trust_remote_code, 
+        num_labels=1
+    )
+    ref_policy = AutoModelForCausalLM.from_pretrained(
+        config.sft_model_path,
+        trust_remote_code=model_config.trust_remote_code
+    )
+    policy = AutoModelForCausalLM.from_pretrained(
+        config.sft_model_path, 
+        trust_remote_code=model_config.trust_remote_code
+    )
+
+    raw_datasets = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
diff --git a/benchmarks/rlhf/requirements.cuda.txt b/benchmarks/rlhf/requirements.cuda.txt
new file mode 100644
index 000000000..764afb978
--- /dev/null
+++ b/benchmarks/rlhf/requirements.cuda.txt
@@ -0,0 +1,283 @@
+#
+# This file is autogenerated by pip-compile with Python 3.10
+# by the following command:
+#
+#    pip-compile --output-file=benchmarks/rlhf/requirements.cuda.txt .pin/tmp-constraints-cuda-rlhf.txt benchmarks/rlhf/requirements.in
+#
+--extra-index-url https://pypi.ngc.nvidia.com
+--extra-index-url https://download.pytorch.org/whl/cu121
+--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+--trusted-host pypi.ngc.nvidia.com
+
+accelerate==0.34.2
+    # via
+    #   -r benchmarks/rlhf/requirements.in
+    #   trl
+aiohappyeyeballs==2.4.0
+    # via aiohttp
+aiohttp==3.10.5
+    # via
+    #   datasets
+    #   fsspec
+aiosignal==1.3.1
+    # via aiohttp
+antlr4-python3-runtime==4.9.3
+    # via omegaconf
+asttokens==2.4.1
+    # via giving
+async-timeout==4.0.3
+    # via aiohttp
+attrs==24.2.0
+    # via aiohttp
+certifi==2024.8.30
+    # via requests
+charset-normalizer==3.3.2
+    # via requests
+codefind==0.1.7
+    # via ptera
+datasets==2.21.0
+    # via
+    #   -r benchmarks/rlhf/requirements.in
+    #   trl
+dill==0.3.8
+    # via
+    #   datasets
+    #   multiprocess
+docstring-parser==0.16
+    # via tyro
+executing==1.2.0
+    # via varname
+filelock==3.16.0
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   torch
+    #   transformers
+    #   triton
+frozenlist==1.4.1
+    # via
+    #   aiohttp
+    #   aiosignal
+fsspec[http]==2024.6.1
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   torch
+giving==0.4.2
+    # via
+    #   ptera
+    #   voir
+huggingface-hub==0.24.6
+    # via
+    #   accelerate
+    #   datasets
+    #   tokenizers
+    #   transformers
+idna==3.8
+    # via
+    #   requests
+    #   yarl
+jax[cuda12]==0.4.31
+    # via -r .pin/../constraints/extra/torch.cuda.txt
+jax-cuda12-pjrt==0.4.31
+    # via jax-cuda12-plugin
+jax-cuda12-plugin[with-cuda]==0.4.31
+    # via jax
+jaxlib==0.4.31
+    # via jax
+jinja2==3.1.4
+    # via torch
+markdown-it-py==3.0.0
+    # via rich
+markupsafe==2.1.5
+    # via jinja2
+mdurl==0.1.2
+    # via markdown-it-py
+ml-dtypes==0.4.0
+    # via
+    #   jax
+    #   jaxlib
+mpmath==1.3.0
+    # via sympy
+multidict==6.1.0
+    # via
+    #   aiohttp
+    #   yarl
+multiprocess==0.70.16
+    # via datasets
+networkx==3.3
+    # via torch
+numpy==2.1.1
+    # via
+    #   accelerate
+    #   datasets
+    #   jax
+    #   jaxlib
+    #   ml-dtypes
+    #   opt-einsum
+    #   pandas
+    #   pyarrow
+    #   scipy
+    #   transformers
+    #   trl
+    #   xformers
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   jax-cuda12-plugin
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via
+    #   jax-cuda12-plugin
+    #   torch
+nvidia-cuda-nvcc-cu12==12.6.68
+    # via jax-cuda12-plugin
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via
+    #   jax-cuda12-plugin
+    #   torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via
+    #   jax-cuda12-plugin
+    #   torch
+nvidia-cufft-cu12==11.0.2.54
+    # via
+    #   jax-cuda12-plugin
+    #   torch
+nvidia-curand-cu12==10.3.2.106
+    # via torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via
+    #   jax-cuda12-plugin
+    #   torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   jax-cuda12-plugin
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-ml-py==12.560.30
+    # via voir
+nvidia-nccl-cu12==2.20.5
+    # via
+    #   jax-cuda12-plugin
+    #   torch
+nvidia-nvjitlink-cu12==12.6.68
+    # via
+    #   jax-cuda12-plugin
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via torch
+omegaconf==2.3.0
+    # via voir
+opt-einsum==3.3.0
+    # via jax
+ovld==0.3.9
+    # via voir
+packaging==24.1
+    # via
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+pandas==2.2.2
+    # via datasets
+psutil==5.9.8
+    # via
+    #   accelerate
+    #   voir
+ptera==1.4.1
+    # via voir
+pyarrow==17.0.0
+    # via datasets
+pygments==2.18.0
+    # via rich
+python-dateutil==2.9.0.post0
+    # via pandas
+pytz==2024.1
+    # via pandas
+pyyaml==6.0.2
+    # via
+    #   accelerate
+    #   datasets
+    #   huggingface-hub
+    #   omegaconf
+    #   transformers
+reactivex==4.0.4
+    # via giving
+regex==2024.7.24
+    # via transformers
+requests==2.32.3
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+rich==13.8.1
+    # via
+    #   tyro
+    #   voir
+safetensors==0.4.5
+    # via
+    #   accelerate
+    #   transformers
+scipy==1.14.1
+    # via
+    #   jax
+    #   jaxlib
+shtab==1.7.1
+    # via tyro
+six==1.16.0
+    # via
+    #   asttokens
+    #   python-dateutil
+sympy==1.13.2
+    # via torch
+tokenizers==0.19.1
+    # via transformers
+torch==2.4.0+cu121
+    # via
+    #   -r benchmarks/rlhf/requirements.in
+    #   accelerate
+    #   trl
+    #   xformers
+tqdm==4.66.5
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   transformers
+transformers==4.44.2
+    # via
+    #   -r benchmarks/rlhf/requirements.in
+    #   trl
+triton==3.0.0
+    # via torch
+trl==0.10.1
+    # via -r benchmarks/rlhf/requirements.in
+typing-extensions==4.12.2
+    # via
+    #   huggingface-hub
+    #   multidict
+    #   reactivex
+    #   torch
+    #   tyro
+tyro==0.8.10
+    # via trl
+tzdata==2024.1
+    # via pandas
+urllib3==2.2.2
+    # via requests
+varname==0.10.0
+    # via giving
+voir==0.2.19
+    # via
+    #   -c .pin/../constraints/cuda.txt
+    #   -r benchmarks/rlhf/requirements.in
+xformers==0.0.27.post2
+    # via -r .pin/../constraints/extra/torch.cuda.txt
+xxhash==3.5.0
+    # via datasets
+yarl==1.11.1
+    # via aiohttp
diff --git a/benchmarks/rlhf/requirements.in b/benchmarks/rlhf/requirements.in
new file mode 100644
index 000000000..045bca09c
--- /dev/null
+++ b/benchmarks/rlhf/requirements.in
@@ -0,0 +1,6 @@
+voir>=0.2.19,<0.3
+torch
+trl
+accelerate
+transformers
+datasets
diff --git a/benchmarks/rlhf/voirfile.py b/benchmarks/rlhf/voirfile.py
new file mode 100644
index 000000000..d93f886cd
--- /dev/null
+++ b/benchmarks/rlhf/voirfile.py
@@ -0,0 +1,38 @@
+from dataclasses import dataclass
+
+from voir import configurable
+from voir.instruments import dash, early_stop, log, rate
+from benchmate.monitor import monitor_monogpu
+
+@dataclass
+class Config:
+    """voir configuration"""
+
+    # Whether to display the dash or not
+    dash: bool = False
+
+    # How often to log the rates
+    interval: str = "1s"
+
+    # Number of rates to skip before logging
+    skip: int = 5
+
+    # Number of rates to log before stopping
+    stop: int = 20
+
+    # Number of seconds between each gpu poll
+    gpu_poll: int = 3
+
+
+@configurable
+def instrument_main(ov, options: Config):
+    yield ov.phases.init
+
+    if options.dash:
+        ov.require(dash)
+
+    ov.require(
+        log("value", "progress", "rate", "units", "loss", "gpudata", context="task"),
+        early_stop(n=options.stop, key="rate", task="train"),
+        monitor_monogpu(poll_interval=options.gpu_poll),
+    )
diff --git a/config/base.yaml b/config/base.yaml
index 7bd21fc30..7bdaad591 100644
--- a/config/base.yaml
+++ b/config/base.yaml
@@ -774,4 +774,34 @@ llava-gpus:
   argv:
     --batch_size: 1
     --num_workers: 4
-    --gradient_accumulation_steps: 1
\ No newline at end of file
+    --gradient_accumulation_steps: 1
+
+
+rlhf_:
+  inherits: _defaults
+  definition: ../benchmarks/rlhf
+  install-variant: unpinned
+  install_group: torch
+  plan:
+    method: per_gpu
+  tags:
+    - rlhf
+    - llm
+  argv:
+    --output_dir: "{milabench_extra}/output"
+    --model_name_or_path: EleutherAI/pythia-1b-deduped
+    --per_device_train_batch_size: 64
+    --logging_strategy: "no"
+    --log_level: "critical"
+    --bf16: true
+
+rlhf-single:
+  inherits: rlhf_
+  plan:
+    method: per_gpu
+
+rlhf-gpus:
+  inherits: rlhf_
+  plan:
+    method: njobs
+    n: 1
diff --git a/scripts/article/run_cuda.sh b/scripts/article/run_cuda.sh
index 527adf6c0..1d0055dc0 100644
--- a/scripts/article/run_cuda.sh
+++ b/scripts/article/run_cuda.sh
@@ -77,21 +77,20 @@ else
 fi
 
 
 if [ "$MILABENCH_PREPARE" -eq 0 ]; then
     cd $MILABENCH_WORDIR
 
-    pip install xformers torch
-    milabench pin --variant cuda  --from-scratch
+    . $MILABENCH_WORDIR/env/bin/activate
 
+    # milabench pin --variant cuda  --from-scratch
     # milabench install --system $MILABENCH_WORDIR/system.yaml --force $ARGS
-    
     # milabench prepare --system $MILABENCH_WORDIR/system.yaml $ARGS
 
     #
     #   Run the benchmarks
-    # milabench run --system $MILABENCH_WORDIR/system.yaml "$@"
+    milabench run --system $MILABENCH_WORDIR/system.yaml $ARGS
 
     #
     #   Display report
-    # milabench report --runs $MILABENCH_WORDIR/results/runs
+    milabench report --runs $MILABENCH_WORDIR/results/runs
 fi
\ No newline at end of file
diff --git a/tests/test_command_reg/test_command_reg_one_node.txt b/tests/test_command_reg/test_command_reg_one_node.txt
index 2cbaa36a0..05a286f8a 100644
--- a/tests/test_command_reg/test_command_reg_one_node.txt
+++ b/tests/test_command_reg/test_command_reg_one_node.txt
@@ -571,3 +571,41 @@ time (
   wait
 )
 
+echo "---"
+echo "rlhf_"
+echo "====="
+time (
+  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  wait
+)
+
+echo "---"
+echo "rlhf-single"
+echo "==========="
+time (
+  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  wait
+)
+
+echo "---"
+echo "rlhf-gpus"
+echo "========="
+time (
+  $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=1 --multi_gpu --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=8 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-gpus/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  wait
+)
+
diff --git a/tests/test_command_reg/test_command_reg_two_nodes.txt b/tests/test_command_reg/test_command_reg_two_nodes.txt
index 0916f5890..c84460dea 100644
--- a/tests/test_command_reg/test_command_reg_two_nodes.txt
+++ b/tests/test_command_reg/test_command_reg_two_nodes.txt
@@ -574,3 +574,41 @@ time (
   wait
 )
 
+echo "---"
+echo "rlhf_"
+echo "====="
+time (
+  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf_/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  wait
+)
+
+echo "---"
+echo "rlhf-single"
+echo "==========="
+time (
+  CUDA_VISIBLE_DEVICES=0 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=1 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=2 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=3 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=4 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=5 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=6 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  CUDA_VISIBLE_DEVICES=7 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-single/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  wait
+)
+
+echo "---"
+echo "rlhf-gpus"
+echo "========="
+time (
+  $SRC/milabench/milabench/scripts/activator $BASE/venv/torch $BASE/cache accelerate launch --mixed_precision=bf16 --dynamo_backend=no --machine_rank=0 --num_machines=1 --multi_gpu --gradient_accumulation_steps=1 --num_cpu_threads_per_process=4 --main_process_ip=127.0.0.1 --main_process_port=29400 --num_processes=8 $SRC/milabench/benchmarks/rlhf/main.py --output_dir $BASE/extra/rlhf-gpus/output --model_name_or_path EleutherAI/pythia-1b-deduped --per_device_train_batch_size 64 --logging_strategy no --log_level critical --bf16 &
+  wait
+)
+