From 0e1453aed114cd46c8b8251c7666e1ff009e2081 Mon Sep 17 00:00:00 2001
From: zhaohui
Date: Wed, 29 May 2024 02:59:22 +0000
Subject: [PATCH 01/37] add rlhf

---
 examples/rlhf/four_model_8gpu.py | 188 ++
 examples/rlhf/four_model_vllm_8gpu.py | 196 +++
 examples/rlhf/quick_start.md | 38 +
 requirements/rlhf.txt | 3 +
 setup.py | 2 +
 xtuner/entry_point.py | 12 +-
 xtuner/rlhf/__init__.py | 0
 xtuner/rlhf/config/__init__.py | 0
 xtuner/rlhf/config/config.py | 110 ++
 xtuner/rlhf/config/config_consts.py | 18 +
 xtuner/rlhf/config/config_utils.py | 71 +
 xtuner/rlhf/coordinator.py | 99 ++
 xtuner/rlhf/dataset/__init__.py | 0
 xtuner/rlhf/dataset/base.py | 237 +++
 .../open_datasets/Anthropic_hh_rlhf.py | 88 +
 xtuner/rlhf/dataset/open_datasets/__init__.py | 0
 xtuner/rlhf/dataset/txt_loader.py | 325 ++++
 xtuner/rlhf/envs/__init__.py | 0
 xtuner/rlhf/envs/txt_env.py | 220 +++
 xtuner/rlhf/logger.py | 91 +
 xtuner/rlhf/loss/__init__.py | 0
 xtuner/rlhf/loss/actor_loss.py | 76 +
 xtuner/rlhf/loss/critic_loss.py | 70 +
 xtuner/rlhf/loss/pretrain_loss.py | 65 +
 xtuner/rlhf/main.py | 169 ++
 xtuner/rlhf/model_backend/__init__.py | 0
 .../rlhf/model_backend/cuda_memory_stats.py | 52 +
 xtuner/rlhf/model_backend/dist_utils.py | 63 +
 xtuner/rlhf/model_backend/generate_utils.py | 167 ++
 xtuner/rlhf/model_backend/hf_model_runner.py | 898 ++++++++++
 xtuner/rlhf/model_backend/models/__init__.py | 0
 .../models/configuration_internlm2.py | 159 ++
 .../models/critical_and_reward.py | 110 ++
 .../models/modeling_internlm2_p.py | 1536 +++++++++++++++++
 xtuner/rlhf/model_backend/net_utils.py | 31 +
 xtuner/rlhf/model_backend/ray_actor_group.py | 19 +
 xtuner/rlhf/model_backend/ray_actor_mixin.py | 92 +
 xtuner/rlhf/model_backend/ray_utils.py | 36 +
 .../rlhf/model_backend/vllm_model_runner.py | 347 ++++
 xtuner/rlhf/model_backend/vllm_worker_wrap.py | 77 +
 xtuner/rlhf/model_server/__init__.py | 0
 .../rlhf/model_server/actor_model_server.py | 99 ++
 xtuner/rlhf/model_server/base_model_server.py | 170 ++
 .../rlhf/model_server/critic_model_server.py | 9 +
 xtuner/rlhf/model_server/ref_model_server.py | 5 +
 .../rlhf/model_server/reward_model_server.py | 43 +
 xtuner/rlhf/policy_output.py | 174 ++
 xtuner/rlhf/repeaters/__init__.py | 0
 xtuner/rlhf/repeaters/base.py | 311 ++++
 xtuner/rlhf/timer.py | 27 +
 xtuner/rlhf/tokenizer/__init__.py | 0
 xtuner/rlhf/tokenizer/tokenizer_utils.py | 88 +
 xtuner/rlhf/trainer/__init__.py | 0
 xtuner/rlhf/trainer/ppo.py | 173 ++
 xtuner/rlhf/utils.py | 65 +
 xtuner/tools/tokenize_ftdp_datasets.py | 2 +-
 56 files changed, 6827 insertions(+), 4 deletions(-)
 create mode 100644 examples/rlhf/four_model_8gpu.py
 create mode 100644 examples/rlhf/four_model_vllm_8gpu.py
 create mode 100644 examples/rlhf/quick_start.md
 create mode 100644 requirements/rlhf.txt
 create mode 100644 xtuner/rlhf/__init__.py
 create mode 100644 xtuner/rlhf/config/__init__.py
 create mode 100644 xtuner/rlhf/config/config.py
 create mode 100644 xtuner/rlhf/config/config_consts.py
 create mode 100644 xtuner/rlhf/config/config_utils.py
 create mode 100644 xtuner/rlhf/coordinator.py
 create mode 100644 xtuner/rlhf/dataset/__init__.py
 create mode 100644 xtuner/rlhf/dataset/base.py
 create mode 100644 xtuner/rlhf/dataset/open_datasets/Anthropic_hh_rlhf.py
 create mode 100644 xtuner/rlhf/dataset/open_datasets/__init__.py
 create mode 100644 xtuner/rlhf/dataset/txt_loader.py
 create mode 100644 xtuner/rlhf/envs/__init__.py
 create mode 100644 xtuner/rlhf/envs/txt_env.py
 create mode 100644 xtuner/rlhf/logger.py
 create mode 100644 xtuner/rlhf/loss/__init__.py
 create mode 100644
xtuner/rlhf/loss/actor_loss.py create mode 100644 xtuner/rlhf/loss/critic_loss.py create mode 100644 xtuner/rlhf/loss/pretrain_loss.py create mode 100644 xtuner/rlhf/main.py create mode 100644 xtuner/rlhf/model_backend/__init__.py create mode 100644 xtuner/rlhf/model_backend/cuda_memory_stats.py create mode 100644 xtuner/rlhf/model_backend/dist_utils.py create mode 100644 xtuner/rlhf/model_backend/generate_utils.py create mode 100644 xtuner/rlhf/model_backend/hf_model_runner.py create mode 100644 xtuner/rlhf/model_backend/models/__init__.py create mode 100644 xtuner/rlhf/model_backend/models/configuration_internlm2.py create mode 100644 xtuner/rlhf/model_backend/models/critical_and_reward.py create mode 100644 xtuner/rlhf/model_backend/models/modeling_internlm2_p.py create mode 100644 xtuner/rlhf/model_backend/net_utils.py create mode 100644 xtuner/rlhf/model_backend/ray_actor_group.py create mode 100644 xtuner/rlhf/model_backend/ray_actor_mixin.py create mode 100644 xtuner/rlhf/model_backend/ray_utils.py create mode 100644 xtuner/rlhf/model_backend/vllm_model_runner.py create mode 100644 xtuner/rlhf/model_backend/vllm_worker_wrap.py create mode 100644 xtuner/rlhf/model_server/__init__.py create mode 100644 xtuner/rlhf/model_server/actor_model_server.py create mode 100644 xtuner/rlhf/model_server/base_model_server.py create mode 100644 xtuner/rlhf/model_server/critic_model_server.py create mode 100644 xtuner/rlhf/model_server/ref_model_server.py create mode 100644 xtuner/rlhf/model_server/reward_model_server.py create mode 100644 xtuner/rlhf/policy_output.py create mode 100644 xtuner/rlhf/repeaters/__init__.py create mode 100644 xtuner/rlhf/repeaters/base.py create mode 100644 xtuner/rlhf/timer.py create mode 100644 xtuner/rlhf/tokenizer/__init__.py create mode 100644 xtuner/rlhf/tokenizer/tokenizer_utils.py create mode 100644 xtuner/rlhf/trainer/__init__.py create mode 100644 xtuner/rlhf/trainer/ppo.py create mode 100644 xtuner/rlhf/utils.py diff --git a/examples/rlhf/four_model_8gpu.py b/examples/rlhf/four_model_8gpu.py new file mode 100644 index 000000000..2d96a832d --- /dev/null +++ b/examples/rlhf/four_model_8gpu.py @@ -0,0 +1,188 @@ +import torch + +tokenizer_config = dict( + pad_token_id=0, + eos_token_id=92542, + padding_side='left', +) + +rollout_config = dict( + actor_micro_bs=32, + reward_micro_bs=32, + clip_reward_min=-5, + clip_reward_max=5, + max_new_tokens=10, + async_reward=True, + generate_kwargs={ + 'do_sample': True, + 'temperature': 1.0, + 'top_k': 0, + 'top_p': 0.9, + 'pad_token_id': 0, + 'eos_token_id': 92542, + 'early_stopping': True, + 'num_beams': 1, + 'min_new_tokens': 1, + }) + +repeater_config = dict( + actor_micro_bs=8, + ref_micro_bs=8, + critic_micro_bs=32, + reward_scale=False, + fine_grained_rm=False, + value_ema=False, + kl_coeff=0.01, + gamma=1.0, + gae_lambda=0.99, + answer_end_id=92542, + norm_rewards=True, +) +train_config = dict( + ppo_minibatch=64, + value_minibatch=64, + actor_micro_bs=2, + critic_micro_bs=2, + pretrain_step=0, + save_interval=800, +) + +critic_model_path = 'internlm/internlm2-chat-1_8b-sft' + +model_configs = dict( + actor=dict( + model_path='internlm/internlm2-chat-1_8b-sft', + model_type='actor', + use_flash_attn=False, + trainer_config=dict( + trainer_type='huggingface', + torch_dtype=torch.float32, + train_kwargs=dict( + micro_bsz=1, + lr=1e-6, + total_steps=1e9, + lr_decay_rate=1, + loss_type='per_seq', + ), + parallel=dict( + data=dict(size=2, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, 
interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + 'zero_optimization': { + 'stage': 2, + 'offload_param': { + 'device': 'none' + }, + 'reduce_bucket_size': 'auto', + 'zero_hpz_partition_size': 1, + 'zero_quantized_weights': False, + 'zero_quantized_gradients': False, + 'stage3_gather_16bit_weights_on_model_save': True, + }, + 'bf16': { + 'enabled': True + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': 2, + 'gradient_accumulation_steps': 16, + 'train_batch_size': 64 + }), + generator_config=dict(shared_with_trainer=True, ), + ), + reference=dict( + model_path='internlm/internlm2-chat-1_8b-sft', + model_type='reference', + use_flash_attn=False, + trainer_config=dict( + torch_dtype=torch.float32, + trainer_type='huggingface', + parallel=dict( + data=dict(size=2, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), + critic=dict( + model_path=critic_model_path, + model_type='critic', + use_flash_attn=False, + trainer_config=dict( + torch_dtype='auto', + trainer_type='huggingface', + train_kwargs=dict( + micro_bsz=1, + lr=5e-6, + total_steps=1e9, + lr_decay_rate=1, + loss_type='per_seq', + ), + parallel=dict( + data=dict(size=2, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + 'zero_optimization': { + 'stage': 2, + 'offload_param': { + 'device': 'none' + }, + 'reduce_bucket_size': 'auto', + 'zero_hpz_partition_size': 1, + 'zero_quantized_weights': False, + 'zero_quantized_gradients': False + }, + 'bf16': { + 'enabled': True + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': 2, + 'gradient_accumulation_steps': 16, + 'train_batch_size': 64 + }), + ), + reward=dict( + model_path=critic_model_path, + model_type='reward', + use_flash_attn=False, + trainer_config=dict( + trainer_type='huggingface', + torch_dtype='auto', + parallel=dict( + data=dict(size=2, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), +) + +dataset_config = { + 'num_samples_each_epoch': + 64, + 'max_seq_len': + 1024, + 'random_seed': + 1024, + 'ppo_datas': [ + 'Anthropic/hh-rlhf/helpful-base::1.0', + 'Anthropic/hh-rlhf/harmless-base::0.5', + ], +} diff --git a/examples/rlhf/four_model_vllm_8gpu.py b/examples/rlhf/four_model_vllm_8gpu.py new file mode 100644 index 000000000..654f57691 --- /dev/null +++ b/examples/rlhf/four_model_vllm_8gpu.py @@ -0,0 +1,196 @@ +import torch + +tokenizer_config = dict( + pad_token_id=0, + eos_token_id=92542, + padding_side='left', +) + +rollout_config = dict( + actor_micro_bs=32, + reward_micro_bs=32, + clip_reward_min=-5, + clip_reward_max=5, + max_new_tokens=10, + async_reward=True, + generate_kwargs={ + 'do_sample': True, + 'temperature': 1.0, + 'top_k': 0, + 'top_p': 0.9, + 'pad_token_id': 0, + 'eos_token_id': 92542, + 'early_stopping': True, + 'num_beams': 1, + 'min_new_tokens': 1, + }) + +repeater_config = dict( + actor_micro_bs=8, + ref_micro_bs=8, + critic_micro_bs=32, + reward_scale=False, + fine_grained_rm=False, + value_ema=False, + kl_coeff=0.01, + gamma=1.0, + gae_lambda=0.99, + answer_end_id=92542, + norm_rewards=True, +) +train_config = 
dict( + ppo_minibatch=64, + value_minibatch=64, + actor_micro_bs=2, + critic_micro_bs=2, + pretrain_step=0, + save_interval=800, +) +critic_model_path = 'internlm/internlm2-chat-1_8b-sft' + +model_configs = dict( + actor=dict( + model_path='internlm/internlm2-chat-1_8b-sft', + model_type='actor', + use_flash_attn=False, + trainer_config=dict( + trainer_type='huggingface', + torch_dtype=torch.float32, + train_kwargs=dict( + micro_bsz=1, + lr=1e-6, + total_steps=1e9, + lr_decay_rate=1, + loss_type='per_seq', + ), + parallel=dict( + data=dict(size=2, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + 'zero_optimization': { + 'stage': 2, + 'offload_param': { + 'device': 'none' + }, + 'reduce_bucket_size': 'auto', + 'zero_hpz_partition_size': 1, + 'zero_quantized_weights': False, + 'zero_quantized_gradients': False, + 'stage3_gather_16bit_weights_on_model_save': True, + }, + 'bf16': { + 'enabled': True + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': 2, + 'gradient_accumulation_steps': 16, + 'train_batch_size': 64 + }), + generator_config=dict( + shared_with_trainer=False, + generator_type='vllm', + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=2, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), + reference=dict( + model_path='internlm/internlm2-chat-1_8b-sft', + model_type='reference', + use_flash_attn=False, + trainer_config=dict( + torch_dtype=torch.float32, + trainer_type='huggingface', + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), + critic=dict( + model_path=critic_model_path, + model_type='critic', + use_flash_attn=False, + trainer_config=dict( + torch_dtype='auto', + trainer_type='huggingface', + train_kwargs=dict( + micro_bsz=1, + lr=5e-6, + total_steps=1e9, + lr_decay_rate=1, + loss_type='per_seq', + ), + parallel=dict( + data=dict(size=2, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + 'zero_optimization': { + 'stage': 2, + 'offload_param': { + 'device': 'none' + }, + 'reduce_bucket_size': 'auto', + 'zero_hpz_partition_size': 1, + 'zero_quantized_weights': False, + 'zero_quantized_gradients': False + }, + 'bf16': { + 'enabled': True + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': 2, + 'gradient_accumulation_steps': 16, + 'train_batch_size': 64 + }), + ), + reward=dict( + model_path=critic_model_path, + model_type='reward', + use_flash_attn=False, + trainer_config=dict( + trainer_type='huggingface', + torch_dtype='auto', + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), +) + +dataset_config = { + 'num_samples_each_epoch': + 64, + 'max_seq_len': + 1024, + 'random_seed': + 1024, + 'ppo_datas': [ + 'Anthropic/hh-rlhf/helpful-base::1.0', + 'Anthropic/hh-rlhf/harmless-base::0.5', + ], +} diff --git a/examples/rlhf/quick_start.md b/examples/rlhf/quick_start.md new file mode 100644 index 000000000..823b08ad2 --- 
/dev/null +++ b/examples/rlhf/quick_start.md @@ -0,0 +1,38 @@ +## Quick Start + +### step1: Environment setup + +``` +# Install PyTorch +pip install torch==2.1.2+cu118 torchvision --index-url https://download.pytorch.org/whl/cu118 + +# Install the xtuner RLHF module +git clone https://github.com/2581543189/xtuner.git +cd xtuner +git checkout rlhf +pip install .[rlhf] +``` + +### step2: Launch an RLHF task with a single engine (huggingface) + +``` +# Launch the task +xtuner rlhf -c examples/rlhf/four_model_8gpu.py +``` + +### step3: Launch an RLHF task with dual engines (vllm + huggingface) + +``` +# Install vllm +export VLLM_VERSION=0.3.3 +export PYTHON_VERSION=310 +pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118 +pip uninstall xformers -y +pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118 +pip uninstall cupy-cuda12x -y +pip install cupy-cuda11x==12.1 +python -m cupyx.tools.install_library --library nccl --cuda 11.x + +# Launch the task +xtuner rlhf -c examples/rlhf/four_model_vllm_8gpu.py +``` diff --git a/requirements/rlhf.txt b/requirements/rlhf.txt new file mode 100644 index 000000000..53f8bbc63 --- /dev/null +++ b/requirements/rlhf.txt @@ -0,0 +1,3 @@ +-r requirements/deepspeed.txt +loguru +ray[default,train]==2.9.1 diff --git a/setup.py b/setup.py index 7a95dfab4..8a3c0a5eb 100644 --- a/setup.py +++ b/setup.py @@ -132,6 +132,8 @@ def gen_packages_items(): 'modelscope': parse_requirements('requirements/runtime.txt') + parse_requirements('requirements/modelscope.txt'), + 'rlhf': + parse_requirements('requirements/rlhf.txt'), }, zip_safe=False, entry_points={'console_scripts': ['xtuner = xtuner:cli']}) diff --git a/xtuner/entry_point.py b/xtuner/entry_point.py index 2af774fd3..404263546 100644 --- a/xtuner/entry_point.py +++ b/xtuner/entry_point.py @@ -12,7 +12,7 @@ # Define valid modes MODES = ('list-cfg', 'copy-cfg', 'log-dataset', 'check-custom-dataset', 'train', 'test', 'chat', 'convert', 'preprocess', 'mmbench', - 'eval_refcoco') + 'eval_refcoco', 'rlhf') CLI_HELP_MSG = \ f""" @@ -207,6 +207,11 @@ def eval_refcoco(): return eval_refcoco.__file__ +def rlhf(): + from xtuner.rlhf import main as rlhf_main + return rlhf_main.__file__ + + modes = { 'list-cfg': list_cfg, 'copy-cfg': copy_cfg, @@ -230,14 +235,15 @@ def eval_refcoco(): '-h': preprocess_help_msg }, 'eval_refcoco': eval_refcoco, - 'list-dataset-format': list_dataset_format + 'list-dataset-format': list_dataset_format, + 'rlhf': rlhf, } HELP_FUNCS = [preprocess_help_msg, convert_help_msg] MAP_FILE_FUNCS = [ list_cfg, copy_cfg, log_dataset, check_custom_dataset, train, test, chat, mmbench, pth_to_hf, merge, split, arxiv_preprocess, eval_refcoco, - convert_refcoco, list_dataset_format + convert_refcoco, list_dataset_format, rlhf ] diff --git a/xtuner/rlhf/__init__.py b/xtuner/rlhf/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/rlhf/config/__init__.py b/xtuner/rlhf/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/rlhf/config/config.py b/xtuner/rlhf/config/config.py new file mode 100644 index 000000000..038aa0aa1 --- /dev/null +++ b/xtuner/rlhf/config/config.py @@ -0,0 +1,110 @@ +# flake8: noqa: E501 +#!/usr/bin/env python + +# Adapted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/config.py + +import inspect +import sys +from importlib.machinery import SourceFileLoader +from pathlib import Path + + +class Config(dict): +
"""This is a wrapper class for dict objects so that values of which can be + accessed as attributes. + + Args: + config (dict): The dict object to be wrapped. + """ + + def __init__(self, config: dict = None): + if config is not None: + for k, v in config.items(): + self._add_item(k, v) + + def __missing__(self, key): + raise KeyError(key) + + def __getattr__(self, key): + try: + value = super().__getitem__(key) + return value + except KeyError: + raise AttributeError(key) + + def __setattr__(self, key, value): + super().__setitem__(key, value) + + def _add_item(self, key, value): + if isinstance(value, dict): + self.__setattr__(key, Config(value)) + else: + self.__setattr__(key, value) + + def update(self, config): + assert isinstance( + config, + (Config, dict)), 'can only update dictionary or Config objects.' + for k, v in config.items(): + self._add_item(k, v) + return self + + @staticmethod + def from_file(filename: str): + """Reads a python file and constructs a corresponding :class:`Config` + object. + + Args: + filename (str): Name of the file to construct the return object. + + Returns: + :class:`Config`: A :class:`Config` object constructed with information in the file. + + Raises: + AssertionError: Raises an AssertionError if the file does not exist, or the file is not .py file + """ + + # check config path + if isinstance(filename, str): + filepath = Path(filename).absolute() + elif isinstance(filename, Path): + filepath = filename.absolute() + + assert filepath.exists( + ), f'{filename} is not found, please check your configuration path' + + # check extension + extension = filepath.suffix + assert extension == '.py', 'only .py files are supported' + + # import the config as module + remove_path = False + if filepath.parent not in sys.path: + sys.path.insert(0, (filepath)) + remove_path = True + + module_name = filepath.stem + source_file = SourceFileLoader( + fullname=str(module_name), path=str(filepath)) + module = source_file.load_module() + + # load into config + config = Config() + + for k, v in module.__dict__.items(): + if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v): + continue + else: + config._add_item(k, v) + + # NOTE: variables which starts with __, is a module or class declaration are omitted in config file + # remove module + del sys.modules[module_name] + if remove_path: + sys.path.pop(0) + + return config + + +class ConfigException(Exception): + pass diff --git a/xtuner/rlhf/config/config_consts.py b/xtuner/rlhf/config/config_consts.py new file mode 100644 index 000000000..a54be13db --- /dev/null +++ b/xtuner/rlhf/config/config_consts.py @@ -0,0 +1,18 @@ +# keywords for config files + +# model type (actor, critic, reward, reference, ...) 
for `model_type` +MODEL_TYPE_ACTOR = 'actor' +MODEL_TYPE_REFERENCE = 'reference' +MODEL_TYPE_REWARD = 'reward' +MODEL_TYPE_CRITIC = 'critic' + +# training or generation engines for `trainer_type` and `generator_type` +ENGINE_HUGGINGFACE = 'huggingface' +ENGINE_INTERNEVO = 'internevo' +ENGINE_VLLM = 'vllm' +ENGINE_LMDEPLOY = 'lmdeploy' + +# plugins for trainer engine (e.g., huggingface accelerate) +ENGINE_PLUGIN_DDP = 'ddp' +ENGINE_PLUGIN_FSDP = 'fsdp' +ENGINE_PLUGIN_DEEPSPEED = 'deepspeed' diff --git a/xtuner/rlhf/config/config_utils.py b/xtuner/rlhf/config/config_utils.py new file mode 100644 index 000000000..6a74d32e2 --- /dev/null +++ b/xtuner/rlhf/config/config_utils.py @@ -0,0 +1,71 @@ +from loguru import logger + + +def get_gpu_requirement(trainer_config: dict) -> int: + # Calculates the number of GPUs required for a given trainer configuration. + num_gpus = 1 + if 'parallel' in trainer_config: + parallel = trainer_config['parallel'] + data = parallel.get('data', {'size': 1}) + tensor = parallel.get('tensor', {'size': 1}) + pipeline = parallel.get('pipeline', {'size': 1}) + num_gpus = data['size'] * tensor['size'] * pipeline['size'] + return num_gpus + + +def get_resource_requirement(model_configs: dict) -> dict: + """Analyzes resource requirements for a list of model configs and returns a + dictionary with the total number of GPUs and CPUs required. + + Args: + model_configs (dict): A dictionary containing model configurations. + + Returns: + dict: A dictionary with the total number of GPUs and CPUs required. + """ + + resources = {'num_gpus': 0} + for name, model_config in model_configs.items(): + if 'trainer_config' not in model_config: + logger.warning(f'{name} has no trainer_config. SKIP.') + continue + trainer_config = model_config['trainer_config'] + num_gpus = get_gpu_requirement(trainer_config) + + if 'generator_config' in model_config: + generator_config = model_config['generator_config'] + if not generator_config.get( + 'shared_with_trainer'): # None or False + num_gpus += get_gpu_requirement(generator_config) + + resources['num_gpus'] += num_gpus + + resources['num_cpus'] = resources['num_gpus'] * 10 + return resources + + +def get_dp_size(trainer_config: dict) -> int: + dp_size = 1 + if 'parallel' in trainer_config: + parallel = trainer_config['parallel'] + data = parallel.get('data', {'size': 1}) + dp_size = data['size'] + return dp_size + + +def get_tp_size(trainer_config: dict) -> int: + tp_size = 1 + if 'parallel' in trainer_config: + parallel = trainer_config['parallel'] + data = parallel.get('tensor', {'size': 1}) + tp_size = data['size'] + return tp_size + + +def get_pp_size(trainer_config: dict) -> int: + pp_size = 1 + if 'parallel' in trainer_config: + parallel = trainer_config['parallel'] + data = parallel.get('pipeline', {'size': 1}) + pp_size = data['size'] + return pp_size diff --git a/xtuner/rlhf/coordinator.py b/xtuner/rlhf/coordinator.py new file mode 100644 index 000000000..2d6fbbc4c --- /dev/null +++ b/xtuner/rlhf/coordinator.py @@ -0,0 +1,99 @@ +from pathlib import Path + +import ray +from loguru import logger + +from .config.config_consts import (MODEL_TYPE_ACTOR, MODEL_TYPE_CRITIC, + MODEL_TYPE_REFERENCE, MODEL_TYPE_REWARD) +from .config.config_utils import get_resource_requirement +from .model_server.actor_model_server import ActorModelServer +from .model_server.base_model_server import BaseModelServer +from .model_server.critic_model_server import CriticModelServer +from .model_server.ref_model_server import RefModelServer +from 
.model_server.reward_model_server import RewardModelServer + +ROOT_PATH = Path(__file__).parents[1].resolve() + + +class Coordinator: + + def __init__(self, cluster_address: str, model_configs: dict): + self.cluster_address = cluster_address + self.model_configs = model_configs + self.model_dict = dict() + self.context_type: str = None # "client" or "server" + self.context: ray._private.workers.BaseContext = None + + resources = get_resource_requirement(self.model_configs) + logger.info(f'Required resources: {resources}') + runtime_env = {'working_dir': ROOT_PATH} + logger.info(f'working_dir (root_path): {ROOT_PATH}') + + try: + client_context = ray.init( + address=self.cluster_address, + runtime_env=runtime_env, + ignore_reinit_error=True, + ) + logger.info( + f'Connected to a running ray cluster at {self.cluster_address}' + ) + self.context_type = 'client' + self.context = client_context + + except ConnectionError: + logger.info( + f'Error connecting to {self.cluster_address}, try initializing a new ray cluster.' # noqa: E501 + ) + ray_context = ray.init( + address=None, + resources=resources, + runtime_env=runtime_env, + ignore_reinit_error=True, + ) + node_ip_address = ray_context.address_info['node_ip_address'] + logger.info(f'Initialize a ray cluster at {node_ip_address}') + self.context_type = 'server' + self.context = ray_context + + def create_models(self) -> dict[str, BaseModelServer]: + self.model_dict = {} + for model_name, model_config in self.model_configs.items(): + model_type = model_config['model_type'] + if model_type == MODEL_TYPE_ACTOR: + self.model_dict[model_name] = ActorModelServer( + model_name, model_config) + elif model_type == MODEL_TYPE_CRITIC: + self.model_dict[model_name] = CriticModelServer( + model_name, model_config) + elif model_type == MODEL_TYPE_REWARD: + self.model_dict[model_name] = RewardModelServer( + model_name, model_config) + elif model_type == MODEL_TYPE_REFERENCE: + self.model_dict[model_name] = RefModelServer( + model_name, model_config) + else: + raise NotImplementedError(f'Unknown model_type: {model_type}') + self._schedule() + return self.model_dict + + def _schedule(self): + for model_name, model in self.model_dict.items( + ): # naive serial initialize + model.initialize_async() + logger.info( + f'{model_name} {model.__class__.__name__}.is_initialized: {model.is_initialized}' # noqa: E501 + ) + for model_name, model in self.model_dict.items( + ): # naive serial initialize + model.initialize_get() + logger.info( + f'{model_name} {model.__class__.__name__}.is_initialized: {model.is_initialized}' # noqa: E501 + ) + + def clean_up(self): + for _, model_server in self.model_dict.items(): + if model_server.trainer is not None: + model_server.trainer.release_resources() + if model_server.generator is not None: + model_server.generator.release_resources() diff --git a/xtuner/rlhf/dataset/__init__.py b/xtuner/rlhf/dataset/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/rlhf/dataset/base.py b/xtuner/rlhf/dataset/base.py new file mode 100644 index 000000000..b64107d8f --- /dev/null +++ b/xtuner/rlhf/dataset/base.py @@ -0,0 +1,237 @@ +"""Basic datasets implement.""" + +import gzip +import json +import random +from contextlib import contextmanager + +import numpy as np +from torch.utils.data import ConcatDataset, Dataset, IterableDataset, Subset + + +@contextmanager +def open_file(filename): + """Construct a file handler. + + The handler can read a normal file or a file compressed by `gzip`. 
+ """ + if filename.endswith('.gz'): + fp = gzip.open(filename, 'rt') + else: + fp = open(filename, encoding='utf-8') + yield fp + fp.close() + + +class InfiniteDataset(IterableDataset): + """Load infinite data from original dataset with shuffle.""" + + def __init__(self, dataset, rng=None): + self.data = list(iter(dataset)) + self.indices = list(range(len(self.data))) + if rng is None: + rng = random.Random() + self.rng = rng + + def __iter__(self): + while True: + self.rng.shuffle(self.indices) + for i in self.indices: + yield self.data[i] + + +class FileDataset(IterableDataset): + """Single json file dataset.""" + + def __init__(self, + filename, + tokenizer, + sys_meta='default', + rm_meta='default'): + self._filename = filename + self.tokenizer = tokenizer + self.data_list = [] + self.sys_meta = sys_meta + self.rm_meta = rm_meta + with open_file(self._filename) as fin: + for lineno, line in enumerate(fin): + data = json.loads(line) + self.data_list.append(data) + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index: int): + data = self.data_list[index] + try: + self.tokenizer.apply_chat_template(data, tokenize=True) + return { + 'data': data, + 'sys_meta': self.sys_meta, + 'rm_meta': self.rm_meta + } + except Exception: + print(f'[data tokenize check] skip dirty data: {data}') + return None + + def __iter__(self): + with open_file(self._filename) as fin: + for lineno, line in enumerate(fin): + data = json.loads(line) + try: + self.tokenizer.apply_chat_template(data, tokenize=True) + except Exception: + print(f'[data tokenize check] skip dirty data: {data}') + continue + if data is None: + continue + yield { + 'data': data, + 'sys_meta': self.sys_meta, + 'rm_meta': self.rm_meta + } + + +class OpensourceDataset(IterableDataset): + """Opensource dataset.""" + + def __init__(self, + filename, + tokenizer, + sys_meta='default', + rm_meta='default'): + self._filename = filename + self.tokenizer = tokenizer + self.sys_meta = sys_meta + self.rm_meta = rm_meta + assert 'Anthropic' in filename or 'openai' in filename, '[Coming soon] currently only support loading Anthropic and openai opensource datasets...' 
# noqa: E501 + if 'Anthropic' in filename: + from .open_datasets.Anthropic_hh_rlhf import AnthropicHhrlhf + self.data_list = AnthropicHhrlhf(path=filename) + elif 'openai' in filename: + pass + else: + raise NotImplementedError() + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index: int): + data = self.data_list[index] + try: + self.tokenizer.apply_chat_template(data, tokenize=True) + return { + 'data': data, + 'sys_meta': self.sys_meta, + 'rm_meta': self.rm_meta + } + except Exception: + print(f'[data tokenize check] skip dirty data: {data}') + return None + + def __iter__(self): + for lineno, data in enumerate(self.data_list): + if data is None: + continue + try: + self.tokenizer.apply_chat_template(data, tokenize=True) + except Exception: + print(f'[data tokenize check] skip dirty data: {data}') + continue + yield { + 'data': data, + 'sys_meta': self.sys_meta, + 'rm_meta': self.rm_meta + } + + +class MultiSourceDatset(IterableDataset): + """Multiple source dataset.""" + + def __init__(self, + task_groups, + sub_dataset_type='file', + tokenizer=None, + random_seed=1024, + ratio_within_datas=True): + self._task_group = [] + for _task in task_groups: + file_path, extra_info = _task.split('::')[0], _task.split('::')[1] + prob = float(extra_info.split('[')[0]) + sys_meta = 'default' + rm_meta = 'default' + if '[META]:' in extra_info: + sys_meta = extra_info.split('[META]:')[-1].split('[')[0] + if '[REWARD_META]:' in extra_info: + rm_meta = extra_info.split('[REWARD_META]:')[-1].split('[')[0] + if prob > 0: + self._task_group.append({ + 'prob': prob, + 'filepath': file_path, + 'sys_meta': sys_meta, + 'rm_meta': rm_meta + }) + print( + f'[DataLoader] Load {_task} with prob:{prob}, sys_meta type: {sys_meta}, reward meta: {rm_meta}' # noqa: E501 + ) + else: + print( + f'[DataLoader] Warning skip file, prob of {file_path} is {prob} ...' 
# noqa: E501 + ) + assert len(self._task_group) > 0, 'No data to be trained' + if sub_dataset_type == 'file': + for task in self._task_group: + filepath = task['filepath'] + if '.json' in filepath: + task['dataset'] = FileDataset(filepath, tokenizer, + task['sys_meta'], + task['rm_meta']) + else: + # loading opensource datasets + print(f'Try loading {filepath} from huggingface ...') + task['dataset'] = OpensourceDataset( + filepath, tokenizer, task['sys_meta'], task['rm_meta']) + else: + raise NotImplementedError('Cannot support filelist now.') + self.random_seed = random_seed + self.ratio_within_datas = ratio_within_datas + + if self.ratio_within_datas: + sum_prob = sum([task['prob'] for task in self._task_group]) + for task in self._task_group: + task['prob'] = task['prob'] / sum_prob + else: + datasets = [] + for i, task in enumerate(self._task_group): + task['dataset'] = self._get_subset_by_ratio( + task['dataset'], task['prob'], random_seed) + datasets.append(task['dataset']) + + self.all_dataset = ConcatDataset(datasets) + self.iter_all_dataset = iter(self.all_dataset) + + def _get_subset_by_ratio(self, dataset: Dataset, ratio: float, seed: int): + np_random = np.random.RandomState(seed) + indices = np.arange(len(dataset)) + np_random.shuffle(indices) + subset_indices = indices[:int(len(dataset) * ratio)] + subset_indices = list(subset_indices) + return Subset(dataset, subset_indices) + + def __iter__(self): + """sample data one task by probs.""" + if self.ratio_within_datas: + rng = random.Random(self.random_seed) + probs = [task['prob'] for task in self._task_group] + # Initialize task iterator + for task in self._task_group: + task['iterator'] = iter(task['dataset']) + while True: + task = rng.choices(self._task_group, weights=probs)[0] + try: + yield from task['iterator'] + except StopIteration: + task['iterator'] = iter(task['dataset']) + yield from task['iterator'] + else: + yield next(self.iter_all_dataset) diff --git a/xtuner/rlhf/dataset/open_datasets/Anthropic_hh_rlhf.py b/xtuner/rlhf/dataset/open_datasets/Anthropic_hh_rlhf.py new file mode 100644 index 000000000..636cb84b5 --- /dev/null +++ b/xtuner/rlhf/dataset/open_datasets/Anthropic_hh_rlhf.py @@ -0,0 +1,88 @@ +import os +import re +from typing import List + +import datasets +from torch.utils.data import Dataset + + +def deduplicate(data: List[str]): + """Deduplicate data while preserving order. + + Refer to https://stackoverflow.com/questions/9792664/converting-a-list-to-a-set-changes-element-order # noqa: E501 + """ + return list(dict.fromkeys(data).keys()) + + +class AnthropicHhrlhf(Dataset): + """ + helpful-base: + train: 42,537 + test: 2,312 + + harmless-base: + train: 43,835 + test: 2,354 + + helpful-online: + train: 22,007 + test: 1,137 + + helpful-rejection-sampled: + train: 52,421 + test: 2,749 + + red-team-attempts: + train: 38,961 + """ + + def __init__(self, + path: str = 'Anthropic/hh-rlhf/helpful-base', + test=False): + super().__init__() + parts = path.split('/') + assert 'Anthropic' in parts and 'hh-rlhf' in parts, f'{self.__class__.__name__}: {path}' # noqa: E501 + if parts.index('hh-rlhf') == len(parts) - 1: + data_dir = None + else: + data_dir = parts[-1] + if os.path.exists('data/' + path): + raw_datasets = datasets.load_from_disk('data/' + path) + else: + print( + f'loading Anthropic/hh-rlhf data_dir={data_dir} from huggingface ...' 
# noqa: E501 + ) + raw_datasets = datasets.load_dataset( + 'Anthropic/hh-rlhf', data_dir=data_dir, trust_remote_code=True) + raw_datasets.save_to_disk('data/' + path) + if test: + raw_data_list = raw_datasets['test']['chosen'] + else: + raw_data_list = raw_datasets['train']['chosen'] + raw_data_list = [d for d in raw_data_list if d is not None] + raw_data_list = deduplicate(raw_data_list) + self.data_list = [ + self.format_chatml(prompt) for prompt in raw_data_list + ] + self.name = self.__class__.__name__ + '-' + data_dir if data_dir else self.__class__.__name__ # noqa: E501 + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index: int): + return self.data_list[index] + + @staticmethod + def format_chatml(string): + pattern = r'(\n\nHuman|\n\nAssistant)(.+?)(?=(\n\nHuman|\n\nAssistant|$))' # noqa: E501 + matches = re.findall(pattern, string, re.DOTALL) + messages = [] + for match in matches: + role, content = match[0].strip(), match[1].strip() + if role == 'Human': + messages.append({'role': 'user', 'content': content[2:]}) + elif role == 'Assistant': + messages.append({'role': 'assistant', 'content': content[2:]}) + else: + raise NotImplementedError('role must in Human or Assistant') + return messages diff --git a/xtuner/rlhf/dataset/open_datasets/__init__.py b/xtuner/rlhf/dataset/open_datasets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/rlhf/dataset/txt_loader.py b/xtuner/rlhf/dataset/txt_loader.py new file mode 100644 index 000000000..5f524206f --- /dev/null +++ b/xtuner/rlhf/dataset/txt_loader.py @@ -0,0 +1,325 @@ +"""Finetuning dataset.""" +import random +from dataclasses import dataclass +from typing import List + +import numpy as np +from torch.utils.data import DataLoader, IterableDataset, RandomSampler + +from .base import InfiniteDataset, MultiSourceDatset + + +@dataclass +class Message: + message: List[dict] + sys_meta: str = 'default' + rm_meta: str = 'default' + token_ids: List[int] = None + mes_type: str = 'ppo' + + +class TxtMessageDataset(IterableDataset): + """Create sequences from dataset. + + Args: + sample_strategy (str) ["in_batch", "in_data"]: + "in_batch": + sample data by ratio for every single training batch + "in_data": + merge all data by ratio first and then sample training batch + """ + + def __init__(self, + ppo_datas: list[str] = None, + pt_datas: list[str] = None, + tokenizer=None, + max_seq_len: int = 4096, + num_samples_each_epoch: int = 64, + pt_data_samples: int = 0, + random_seed: int = 110, + sample_strategy: str = 'in_batch', + ratio_within_datas: bool = True, + **kwargs): + + assert sample_strategy in [ + 'in_batch', 'in_data' + ], f"sample_strategy should in ['in_batch', 'in_data'], but got {sample_strategy}" # noqa: E501 + self.sample_strategy = sample_strategy + assert ppo_datas is not None, '[Data error] Specify your data task config' # noqa: E501 + self.tokenizer = tokenizer + assert self.tokenizer.chat_template is not None, 'Make sure tokenizer has chat_template.' 
# noqa: E501 + + self.ppo_message_dataset = MultiSourceDatset( + task_groups=ppo_datas, + sub_dataset_type='file', + tokenizer=self.tokenizer, + ratio_within_datas=ratio_within_datas) + if pt_data_samples is not None and pt_data_samples != 0: + assert pt_datas is not None, f'[PT DATA error] samples num {pt_data_samples}, while pt_datas is None' # noqa: E501 + self.pt_message_dataset = MultiSourceDatset( + task_groups=pt_datas, + sub_dataset_type='file', + tokenizer=self.tokenizer, + ratio_within_datas=ratio_within_datas) + self.pt_data_per_epoch = pt_data_samples + self.ppo_data_per_epoch = num_samples_each_epoch - self.pt_data_per_epoch # noqa: E501 + else: + self.pt_message_dataset = None + self.pt_data_per_epoch = 0 + self.ppo_data_per_epoch = num_samples_each_epoch + + self.max_seq_len = max_seq_len + self.num_samples_each_epoch = num_samples_each_epoch + + self.random_seed = random_seed + self.rng = random.Random(self.random_seed) + np.random.seed(self.random_seed) + random.seed(self.random_seed) + + if self.sample_strategy == 'in_batch': + self._init_in_batch() + elif self.sample_strategy == 'in_data': + self._init_in_data() + else: + raise NotImplementedError( + f"sample_strategy should in ['in_batch', 'in_data'], but got {sample_strategy}" # noqa: E501 + ) + + self.epoch_index = 0 + + def _init_in_data(self): + print( + '========================= Init in data sampler =========================' # noqa: E501 + ) + if self.pt_data_per_epoch != 0: + assert hasattr(self.pt_message_dataset, 'all_dataset') + pt_sampler = RandomSampler(self.pt_message_dataset.all_dataset) + self.pt_dataloader = iter( + DataLoader( + self.pt_message_dataset.all_dataset, + collate_fn=lambda x: x, + sampler=pt_sampler, + batch_size=self.pt_data_per_epoch)) + print( + f'[PT data] pretrain data per epoch: {self.pt_data_per_epoch}') + + assert hasattr(self.ppo_message_dataset, 'all_dataset') + prompt_sampler = RandomSampler(self.ppo_message_dataset.all_dataset) + self.prompt_dataloader = iter( + DataLoader( + self.ppo_message_dataset.all_dataset, + collate_fn=lambda x: x, + sampler=prompt_sampler, + batch_size=self.ppo_data_per_epoch)) + + print(f'[PPO data] ppo data per epoch: {self.ppo_data_per_epoch}') + print( + f'[Txt] Training dataset initialized, random seed {self.random_seed}.\n' # noqa: E501 + ) + + def yield_in_data(self): + print( + '========================= yield data from data sampler =========================' # noqa: E501 + ) + batch_sequence = [] + ppo_sequence, pt_sequence = [], [] + if self.pt_data_per_epoch != 0: + pt_batch_messages = next(self.pt_dataloader) + for index, message in enumerate(pt_batch_messages): + sequence = self._postprocess_sequence(message, mes_type='pt') + if sequence is not None: + assert sequence.mes_type == 'pt', f'Data type should be pt, but get {sequence.mes_type}' # noqa: E501 + pt_sequence.append(sequence) + if len(pt_sequence) == self.pt_data_per_epoch: + break + assert len( + pt_sequence + ) == self.pt_data_per_epoch, f'{len(pt_sequence)} != {self.pt_data_per_epoch}' # noqa: E501 + + ppo_batch_messages = next(self.prompt_dataloader) + for index, message in enumerate(ppo_batch_messages): + sequence = self._postprocess_sequence(message, mes_type='ppo') + if sequence is not None: + assert sequence.mes_type == 'ppo', f'Data type should be ppo. 
but get {sequence.mes_type}' # noqa: E501 + ppo_sequence.append(sequence) + if len(ppo_sequence) == self.ppo_data_per_epoch: + break + if len(ppo_sequence) < self.ppo_data_per_epoch: + missed = self.ppo_data_per_epoch - len(ppo_sequence) + print( + f'[Warning] {missed} dirty data, use {missed} data from sampled data...' # noqa: E501 + ) + for i in range(missed): + ppo_sequence.append(ppo_sequence[i]) + + assert len( + ppo_sequence + ) == self.ppo_data_per_epoch, f'{len(ppo_sequence)} == {self.ppo_data_per_epoch}' # noqa: E501 + + print( + f'prepare TxtMessageDataset done: {len(ppo_sequence)} ppo & {len(pt_sequence)} pretrain, for epoch {self.epoch_index}.' # noqa: E501 + ) + batch_sequence = ppo_sequence + pt_sequence + assert len( + batch_sequence + ) == self.num_samples_each_epoch, '[Epoch {self.epoch_index}] Wrong data len' # noqa: E501 + return batch_sequence + + def _init_in_batch(self): + print( + '========================= Init in batch sampler =========================' # noqa: E501 + ) + samples_cnts = [] + pt_data_len = 0 + if self.pt_data_per_epoch != 0: + for task in self.pt_message_dataset._task_group: + task['target_num_each_epoch'] = int( + task['prob'] * self.pt_data_per_epoch + 0.5) + 1 + inner_dataset = InfiniteDataset(task['dataset'], self.rng) + task['iterator'] = iter(inner_dataset) + samples_cnts.append(task['target_num_each_epoch']) + print( + f"[PT data] {task['filepath']}: task prob: {task['prob']}, " # noqa: E501 + f'ori number of messages: {len(inner_dataset.data)}, ' + f"target_num_each_epoch: {task['target_num_each_epoch']}" + ) # noqa: E501 + pt_data_len = sum(samples_cnts) + assert pt_data_len >= self.pt_data_per_epoch, f'Make sure there are enough pretrain data, {pt_data_len} >= {self.pt_data_per_epoch}' # noqa: E501 + print( + f'[PT data] pretrain data per epoch: {self.pt_data_per_epoch}, sampled {pt_data_len}' # noqa: E501 + ) + for task in self.ppo_message_dataset._task_group: + task['target_num_each_epoch'] = int( + task['prob'] * self.ppo_data_per_epoch + 0.5) + 1 + inner_dataset = InfiniteDataset(task['dataset'], self.rng) + task['iterator'] = iter(inner_dataset) + samples_cnts.append(task['target_num_each_epoch']) + print(f"{task['filepath']}: task prob: {task['prob']}, " + f'ori number of messages: {len(inner_dataset.data)}, ' + f"target_num_each_epoch: {task['target_num_each_epoch']}") + assert ( + sum(samples_cnts) - pt_data_len + ) >= self.ppo_data_per_epoch, 'Make sure there are enough ppo datas' + print( + f'[PPO data] ppo data per epoch: {self.ppo_data_per_epoch}, sampled: {sum(samples_cnts) - pt_data_len}' # noqa: E501 + ) + + if sum(samples_cnts) <= self.num_samples_each_epoch: + print( + f'[Txt loader] Warning!!! sample nums {sum(samples_cnts)} <= samples {self.num_samples_each_epoch}' # noqa: E501 + ) + print( + f'[Txt] Training dataset initialized, random seed {self.random_seed}.\n' # noqa: E501 + ) + + def yield_in_batch(self): + print( + '========================= yield data from batch sampler =========================' # noqa: E501 + ) + batch_sequence = [] + ppo_sequence, pt_sequence = [], [] + + # epoch_rng only use in this epoch. 
+ epoch_rng = np.random.RandomState(self.epoch_index) + # prepare epoch data + if self.pt_data_per_epoch != 0: + pt_batch_messages = [] + for task in self.pt_message_dataset._task_group: + messages = [] + for _ in range(task['target_num_each_epoch']): + messages.append(next(task['iterator'])) + print( + f"[PT] prepare {len(messages)} data from {task['filepath']}" # noqa: E501 + ) + epoch_rng.shuffle(messages) + pt_batch_messages.extend(messages) + epoch_rng.shuffle(pt_batch_messages) + for index, message in enumerate(pt_batch_messages): + sequence = self._postprocess_sequence(message, mes_type='pt') + if sequence is not None: + assert sequence.mes_type == 'pt', f'Data type should be pt, but get {sequence.mes_type}' # noqa: E501 + pt_sequence.append(sequence) + if len(pt_sequence) == self.pt_data_per_epoch: + break + assert len( + pt_sequence + ) == self.pt_data_per_epoch, f'{len(pt_sequence)} != {self.pt_data_per_epoch}' # noqa: E501 + + ppo_batch_messages = [] + for task in self.ppo_message_dataset._task_group: + messages = [] + for _ in range(task['target_num_each_epoch']): + messages.append(next(task['iterator'])) + print( + f"[PPO] prepare {len(messages)} data from {task['filepath']}") + epoch_rng.shuffle(messages) + ppo_batch_messages.extend(messages) + epoch_rng.shuffle(ppo_batch_messages) + for index, message in enumerate(ppo_batch_messages): + sequence = self._postprocess_sequence(message, mes_type='ppo') + if sequence is not None: + assert sequence.mes_type == 'ppo', f'Data type should be ppo. but get {sequence.mes_type}' # noqa: E501 + ppo_sequence.append(sequence) + if len(ppo_sequence) == self.ppo_data_per_epoch: + break + assert len( + ppo_sequence + ) == self.ppo_data_per_epoch, f'{len(ppo_sequence)} == {self.ppo_data_per_epoch}' # noqa: E501 + + print( + f'prepare TxtMessageDataset done: {len(ppo_sequence)} ppo & {len(pt_sequence)} pretrain, for epoch {self.epoch_index}.' # noqa: E501 + ) + batch_sequence = ppo_sequence + pt_sequence + assert len( + batch_sequence + ) == self.num_samples_each_epoch, '[Epoch {self.epoch_index}] Wrong data len' # noqa: E501 + return batch_sequence + + def __iter__(self): + while True: + if self.sample_strategy == 'in_batch': + yield self.yield_in_batch() + elif self.sample_strategy == 'in_data': + yield self.yield_in_data() + + self.epoch_index += 1 + + def _postprocess_sequence(self, message, mes_type='ppo'): + """Post process sequence: tokenization & truncation.""" + message_data = message['data'] + new_meaasage_data = [] + if mes_type == 'ppo': + for _ in reversed(range(len(message_data))): + if message_data[_]['role'] == 'user': + new_meaasage_data = message_data[:_ + 1] + break + assert new_meaasage_data[-1][ + 'role'] == 'user', f'ppo data last role must user, {new_meaasage_data}' # noqa: E501 + token_ids = self.tokenizer.apply_chat_template( + new_meaasage_data, + tokenize=True, + add_generation_prompt=True, + return_tensors='pt') + elif mes_type == 'pt': + for _ in reversed(range(len(message_data))): + if message_data[_]['role'] == 'assistant': + new_meaasage_data = message_data[:_ + 1] + break + assert new_meaasage_data[-1][ + 'role'] == 'assistant', f'pretrain data last role must assistant, {new_meaasage_data}' # noqa: E501 + token_ids = self.tokenizer.apply_chat_template( + new_meaasage_data, + tokenize=True, + add_generation_prompt=False, + return_tensors='pt') + if token_ids.shape[-1] <= 4 or token_ids.shape[-1] > self.max_seq_len: + print( + f'[TXT Loader] Warning, {mes_type} message {message} is too short or long, skipped...' 
# noqa: E501 + ) + return None + return Message( + message=new_meaasage_data, + token_ids=token_ids, + sys_meta=message['sys_meta'], + rm_meta=message['rm_meta'], + mes_type=mes_type) diff --git a/xtuner/rlhf/envs/__init__.py b/xtuner/rlhf/envs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/rlhf/envs/txt_env.py b/xtuner/rlhf/envs/txt_env.py new file mode 100644 index 000000000..1a01d0741 --- /dev/null +++ b/xtuner/rlhf/envs/txt_env.py @@ -0,0 +1,220 @@ +import time +from copy import deepcopy + +import torch +from loguru import logger +from torch.utils.data import IterableDataset + +from ..model_server.base_model_server import BaseModelServer + +META_PROMPT = { + 'default': + '', + 'helpful': + 'Consider the helpfulness of the given sentence. ' + 'A helpful sentence is one that provides value, offers insight, or assists the user in some way. ' # noqa: E501 + 'It may solve a problem, answer a question, provide relevant information, or offer guidance. ' # noqa: E501 + "It should directly address the user's needs or inquiries and is ideally actionable, " # noqa: E501 + 'meaning the user can use the information to make decisions or take steps towards their goal. ' # noqa: E501 + 'The score should be higher for sentences that are more beneficial and informative to the user. ' # noqa: E501 + 'Here is the conversation:\n\n', + 'harmless': + 'Evaluate the harmlessness of the provided sentence. ' + 'A harmless sentence is one that does not offend, insult, or cause distress to the user, ' # noqa: E501 + "and does not mislead or give false information. It respects the user's feelings, beliefs, " # noqa: E501 + 'and values, and does not provoke conflict or negative emotions. ' + 'Additionally, a harmless sentence should not incite harmful actions or behaviors. ' # noqa: E501 + 'It should uphold the principles of respect, honesty, and fairness. ' + 'Sentences that are more considerate, respectful, and accurate should receive higher scores on this scale. ' # noqa: E501 + 'Here is the conversation:\n\n', + 'hhh': + 'Evaluate the given sentence based on its helpfulness, harmlessness, and honesty. ' # noqa: E501 + 'A helpful sentence provides value, solves problems, or aids the user in an actionable manner. ' # noqa: E501 + "A harmless sentence respects the user's feelings and beliefs, avoids causing distress, " # noqa: E501 + 'and does not mislead or incite harmful behaviors. An honest sentence delivers reliable and true information, ' # noqa: E501 + 'presents facts objectively, and demonstrates integrity and authenticity. Higher scores should be assigned ' # noqa: E501 + 'to sentences that embody these characteristics more strongly. ' + 'Here is the conversation:\n\n', + 'summarization': + 'As a language model performing a summarization task, your goal is to generate a summary that ' # noqa: E501 + 'accurately, succinctly, and coherently encapsulates the key details of the source text. Ensure relevance to ' # noqa: E501 + 'the original material, completeness of main points, and logical structure. Maintain conciseness and high ' # noqa: E501 + 'linguistic standards. Ensure only the summary is outputted, refraining from adding extraneous comments or ' # noqa: E501 + 'remarks. Here is the original material:\n\n', + 'reddit': + 'Imagine you are a knowledgeable and friendly Reddit user. ' + 'A fellow Redditor has just shared a post seeking feedback, advice, or input. 
' # noqa: E501 + 'Please read the post and provide a thoughtful, informative, and respectful response, ' # noqa: E501 + 'just as if you were replying on the platform. Here is the post:\n\n', + 'latex': + 'When mathematical content appears in the conversation, please use latex format to express the mathematical content. Here is the conversation:\n\n', # noqa: E501 + 'math_ci': + "Integrate step-by-step reasoning and Python code to solve math problems using the following guidelines:\n- Just write jupyter code to solve the problem without giving your thought;\n- Present the final result in LaTeX using a '\\boxed\\{{}}' without any units. \n", # noqa: E501 +} + + +class TxtEnv: + """A generic RL environment to generate textual sequences.""" + + def __init__( + self, + dataloader: IterableDataset, + max_new_tokens: int = 1024, + actor_micro_bs: int = 32, + reward_micro_bs: int = 32, + clip_reward_min: int = -5, + clip_reward_max: int = 5, + reward_function: BaseModelServer = None, + async_reward: bool = True, + generate_kwargs: dict = None, + **_ignored, + ): + """ + Args: + dataloader (IterableDataset): generate rl data iteratively + reward_function: reward function that computes scalar reward for each episode # noqa: E501 + """ + self.dataloader: IterableDataset = iter(dataloader) + self.reward_function: BaseModelServer = reward_function + self._cur_messagess = [] + self.max_new_tokens = max_new_tokens + self.actor_micro_bs = actor_micro_bs + self.reward_micro_bs = reward_micro_bs + self.clip_reward_min = clip_reward_min + self.clip_reward_max = clip_reward_max + self.async_reward = async_reward + self.generate_kwargs: dict = generate_kwargs + + def rollout(self, policy_model: BaseModelServer, display=False): + sample_data = deepcopy(next(self.dataloader)) + ppo_input_messages = [] + pt_input_messages = [] + for data in sample_data: + if data.sys_meta != 'default': + message = deepcopy([{ + 'role': 'system', + 'content': META_PROMPT[data.sys_meta] + }] + data.message) + else: + message = deepcopy(data.message) + if data.mes_type == 'ppo': + ppo_input_messages.append(message) + elif data.mes_type == 'pt': + pt_input_messages.append(message) + else: + raise TypeError(f'Wrong message type {data.mes_type}') + # ppo data + s_t = time.time() + print(f'[For Generate]: {ppo_input_messages[0]}') + trajectories = policy_model.generate( + inputs=ppo_input_messages, + micro_batch_size=self.actor_micro_bs, + step=self.max_new_tokens, + output_str=True, + generate_kwargs=self.generate_kwargs) + logger.info( + f'[actor generate] duration: {round(time.time() - s_t, 2)} s, len(inputs): {len(ppo_input_messages)} ' # noqa: E501 + ) + + if self.async_reward: + reward_output_ref = self.get_reward_async(sample_data, + trajectories) + trajectories['reward_output_ref'] = reward_output_ref + else: + rewards = self.get_reward(sample_data, trajectories) + clipped_rewards = torch.clamp( + rewards, min=self.clip_reward_min, max=self.clip_reward_max) + trajectories['rewards'] = rewards + trajectories['clipped_rewards'] = clipped_rewards + + # pretrain data + if len(pt_input_messages) > 0: + pt_inputs = [ + policy_model.tokenizer.apply_chat_template( + mes, + tokenize=False, + add_generation_prompt=False, + return_tensors='pt') for mes in pt_input_messages + ] + trajectories.pt_data = policy_model.tokenizer( + pt_inputs, return_tensors='pt', padding=True) + print( + f'[TxtEnv & {policy_model.__class__.__name__}] gets {len(pt_input_messages)} pretrain episodes.' 
# noqa: E501 + ) + + return trajectories + + # default get_reward() is blocking. get_reward_async() needs to call get_reward_collect() # noqa: E501 + def get_reward_async(self, sample_data, policyout): + s_t = time.time() + rm_input_messages = [] + for i in range(len(sample_data)): + if sample_data[i].rm_meta != 'default': + cur_rm_data = [{ + 'role': 'system', + 'content': META_PROMPT[sample_data[i].rm_meta] + }] + sample_data[i].message + [{ + 'role': + 'assistant', + 'content': + policyout.output_ans_str[i] + }] + else: + cur_rm_data = sample_data[i].message + [{ + 'role': + 'assistant', + 'content': + policyout.output_ans_str[i] + }] + rm_input_messages.append(cur_rm_data) + + print(f'[For Reward]: {rm_input_messages[0]}') + reward_output_ref = self.reward_function.infer_async( + rm_input_messages, + output_logprobs=False, + micro_batch_size=self.reward_micro_bs) + logger.info( + f'[reward infer] async duration: {round(time.time() - s_t, 2)} s') + return reward_output_ref + + def get_reward_collect(self, reward_output_ref): + s_t = time.time() + rm_out = self.reward_function.infer_get(reward_output_ref) + logger.info( + f'[reward infer] async wait duration: {round(time.time() - s_t, 2)} s' # noqa: E501 + ) + rewards = rm_out.logits.squeeze(-1) + return rewards + + def get_reward(self, sample_data, policyout): + s_t = time.time() + rm_input_messages = [] + for i in range(len(sample_data)): + if sample_data[i].rm_meta != 'default': + cur_rm_data = [{ + 'role': 'system', + 'content': META_PROMPT[sample_data[i].rm_meta] + }] + sample_data[i].message + [{ + 'role': + 'assistant', + 'content': + policyout.output_ans_str[i] + }] + else: + cur_rm_data = sample_data[i].message + [{ + 'role': + 'assistant', + 'content': + policyout.output_ans_str[i] + }] + rm_input_messages.append(cur_rm_data) + + print(f'[For Reward]: {rm_input_messages[0]}') + rm_out = self.reward_function.infer( + rm_input_messages, + output_logprobs=False, + micro_batch_size=self.reward_micro_bs) + logger.info( + f'[reward infer] duration: {round(time.time() - s_t, 2)} s') + rewards = rm_out.logits.squeeze(-1) + return rewards diff --git a/xtuner/rlhf/logger.py b/xtuner/rlhf/logger.py new file mode 100644 index 000000000..d774f2923 --- /dev/null +++ b/xtuner/rlhf/logger.py @@ -0,0 +1,91 @@ +# Adapted from +# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py +"""Logging configuration.""" +import logging +import sys +from functools import wraps +from time import perf_counter + +_FORMAT = '%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' +_DATE_FORMAT = '%m-%d %H:%M:%S' + + +class NewLineFormatter(logging.Formatter): + """Adds logging prefix to newlines to align multi-line messages.""" + + def __init__(self, fmt, datefmt=None): + logging.Formatter.__init__(self, fmt, datefmt) + + def format(self, record): + msg = logging.Formatter.format(self, record) + if record.message != '': + parts = msg.split(record.message) + msg = msg.replace('\n', '\r\n' + parts[0]) + return msg + + +_root_logger = logging.getLogger('marl') +_default_handler = None + + +def _setup_logger(): + _root_logger.setLevel(logging.DEBUG) + global _default_handler + if _default_handler is None: + _default_handler = logging.StreamHandler(sys.stdout) + _default_handler.flush = sys.stdout.flush # type: ignore + _default_handler.setLevel(logging.INFO) + _root_logger.addHandler(_default_handler) + fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) + _default_handler.setFormatter(fmt) + # Setting 
this will avoid the message + # being propagated to the parent logger. + _root_logger.propagate = False + + +# The logger is initialized when the module is imported. +# This is thread-safe as the module is only imported once, +# guaranteed by the Python GIL. +_setup_logger() + + +def init_logger(name: str): + # Use the same settings as above for root logger + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.addHandler(_default_handler) + logger.propagate = False + return logger + + +def log_decorator(logger): + """ + Usage: + @log_decorator(logger) + def func(a, b, ...): + return 1 / 0 + + """ + + def decorator(func): + + @wraps(func) + def wrapper(*args, **kwargs): + logger.info('----------- LOG DECORATOR -----------') + logger.info( + f'CALLED {func.__name__} ARGS: {args}; KWARGS:{kwargs}') + bgn = perf_counter() + try: + result = func(*args, **kwargs) + end = perf_counter() + dur = end - bgn + logger.info( + f'{func.__name__} RESULT: {result}; DURATION: {dur:4f}s') + return result + except Exception as e: + logger.exception(f'{func.__name__}: {e}') + logger.info('----------- LOG DECORATOR -----------') + + return wrapper + + return decorator diff --git a/xtuner/rlhf/loss/__init__.py b/xtuner/rlhf/loss/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/rlhf/loss/actor_loss.py b/xtuner/rlhf/loss/actor_loss.py new file mode 100644 index 000000000..cd81c97db --- /dev/null +++ b/xtuner/rlhf/loss/actor_loss.py @@ -0,0 +1,76 @@ +from typing import Any + +import torch + +from ..policy_output import logprobs_from_logits + + +class ActorLoss(torch.nn.Module): + """Loss function for actor model.""" + + def __init__(self, cliprange: float = 0.2, loss_type: str = 'per_seq'): + super().__init__() + self.cliprange = cliprange + self.loss_type = loss_type + assert self.loss_type in ['per_token', 'per_seq'] + + def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask, + loss_factor): + ratio = (logprobs - old_logprobs).exp() + pg_loss1 = -ratio * advantages + pg_loss2 = -ratio.clamp(1 - self.cliprange, + 1 + self.cliprange) * advantages + if self.loss_type == 'per_seq': + pg_loss = (torch.max(pg_loss1, pg_loss2) * mask).sum() / mask.sum() + elif self.loss_type == 'per_token': + pg_loss = torch.sum( + torch.max(pg_loss1, pg_loss2) * mask) * loss_factor + else: + raise RuntimeError( + f"ActorLoss.loss_type must be ['per_seq', 'per_token'], got {self.loss_type}" # noqa: E501 + ) + return pg_loss.mean() + + def forward(self, logits: torch.Tensor, labels: dict[str, Any]): + """Forward function of ActorLoss. + + Args: + logits (Tensor): Forward result of the model. Its shape may be varied. # noqa: E501 + For packed forward: (micro_bsz * seqlen, 1), where micro_bsz = 1 # noqa: E501 + For non packed forward: (micro_bsz, seqlen, 1) + + labels (tuple[dict]): Label values which are split by pipeline + schedule into pieces. The length of the list is micro_bsz. Each + element is a dict, representing labels to a batch. + + Note: + The parameter `labels` seems strange because of pj-colossalai's + pipeline schedule mechanism. Labels are delivered to colosslai.Engine # noqa: E501 + in List format, so pipeline schedule split it into micro_bsz pieces, # noqa: E501 + and deliver them to loss_fn by `*args`. 
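+ For reference, the clipped PPO surrogate that `actor_loss_fn` above
+ computes (shorthand restatement of the code only, no additional behaviour):
+ ratio = exp(logprobs - old_logprobs)
+ pg_loss = max(-ratio * advantages,
+ -clamp(ratio, 1 - cliprange, 1 + cliprange) * advantages)
+ 'per_seq' : loss = (pg_loss * mask).sum() / mask.sum()
+ 'per_token': loss = (pg_loss * mask).sum() * loss_factor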
+ + Returns: + Tensor: Return the final loss + """ + assert logits.ndim == 3 + mask = labels['mask'] # (micro_bsz, seqlen) + + assert logits.shape[0] == labels['input_ids'].shape[0] + input_ids = labels['input_ids'] # (micro_bsz, seqlen) + old_logprobs = labels['old_logprobs'] # (micro_bsz, seqlen) + advantages = labels['advantages'] # (micro_bsz, seqlen) + loss_factor = labels['loss_factor'] + + logpy = logprobs_from_logits( + logits=logits[:, :-1, :], labels=input_ids[:, 1:], gather=True) + num_actions = mask.size(1) + logprobs = logpy[:, -num_actions:] + + loss = self.actor_loss_fn( + logprobs=logprobs, + old_logprobs=old_logprobs, + advantages=advantages, + mask=mask, + loss_factor=loss_factor, + ) + return loss diff --git a/xtuner/rlhf/loss/critic_loss.py b/xtuner/rlhf/loss/critic_loss.py new file mode 100644 index 000000000..877c21c28 --- /dev/null +++ b/xtuner/rlhf/loss/critic_loss.py @@ -0,0 +1,70 @@ +from typing import Any + +import torch + + +class CriticLoss(torch.nn.Module): + """Loss function for critic model.""" + + def __init__(self, + cliprange_value: float = 100, + loss_type: str = 'per_seq'): + super().__init__() + self.cliprange_value = cliprange_value + self.loss_type = loss_type + assert self.loss_type in ['per_token', 'per_seq'] + + def critic_loss_fn(self, values, old_values, returns, mask, loss_factor): + values_clipped = old_values + (values - old_values).clamp( + -self.cliprange_value, self.cliprange_value) + vf_loss1 = (values_clipped - returns)**2 + vf_loss2 = (values - returns)**2 + + if self.loss_type == 'per_seq': + vf_loss = (torch.max(vf_loss1, vf_loss2) * mask).sum() / mask.sum() + elif self.loss_type == 'per_token': + vf_loss = torch.sum( + torch.max(vf_loss1, vf_loss2) * mask * loss_factor) + else: + raise RuntimeError( + f"CriticLoss.loss_type must be ['per_seq', 'per_token'], got {self.loss_type}" # noqa: E501 + ) + return 0.5 * vf_loss.mean() + + def forward(self, values: torch.Tensor, labels: dict[str, Any]): + """Forward function of CriticLoss. + + Args: + values (Tensor): Forward result of the model. Its shape may be varied. # noqa: E501 + For packed forward: (micro_bsz * seqlen, 1), where micro_bsz = 1 # noqa: E501 + For non packed forward: (micro_bsz, seqlen, 1) + + labels (Tuple[dict]): Label values which are split by pipeline + schedule into pieces. The length of the list is micro_bsz. Each + element is a dict, representing labels to a batch. + + Note: + The parameter `labels` seems strange because of pj-colossalai's + pipeline schedule mechanism. Labels are delivered to colosslai.Engine # noqa: E501 + in List format, so pipeline schedule split it into micro_bsz pieces, # noqa: E501 + and deliver them to loss_fn by `*args`. 
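+ For reference, the clipped value objective that `critic_loss_fn` above
+ computes (shorthand restatement of the code only, no additional behaviour):
+ values_clipped = old_values + clamp(values - old_values,
+ -cliprange_value, cliprange_value)
+ vf_loss = max((values_clipped - returns)**2, (values - returns)**2)
+ 'per_seq' : loss = 0.5 * (vf_loss * mask).sum() / mask.sum()
+ 'per_token': loss = 0.5 * (vf_loss * mask * loss_factor).sum()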
+ + Returns: + Tensor: Return the final loss + """ + assert values.ndim == 2 + mask = labels['mask'] # (micro_bsz, seqlen) + num_actions = mask.size(1) + values = values[:, -num_actions:] + + old_values = labels['old_values'] # (micro_bsz, seqlen) + returns = labels['returns'] # (micro_bsz, seqlen) + loss_factor = labels['loss_factor'] + loss = self.critic_loss_fn( + values=values, + old_values=old_values, + returns=returns, + mask=mask, + loss_factor=loss_factor, + ) + return loss diff --git a/xtuner/rlhf/loss/pretrain_loss.py b/xtuner/rlhf/loss/pretrain_loss.py new file mode 100644 index 000000000..fe08d2a0b --- /dev/null +++ b/xtuner/rlhf/loss/pretrain_loss.py @@ -0,0 +1,65 @@ +import torch +from loguru import logger + +try: + from flash_attn.losses.cross_entropy import \ + CrossEntropyLoss as FlashCrossEntropyLoss +except ImportError: + pass + + +# Adapted from: https://gitlab.pjlab.org.cn/openmmlab/bigmodel/rl3m/-/blob/main/rl3m/layers/loss.py#L37 # noqa: E501 +class FlashGPTLMLoss(torch.nn.Module): + """Loss function for flash GPT Language Model.""" + + def __init__(self, parallel_output=True, label_smoothing=0): + super().__init__() + + if label_smoothing is not None and label_smoothing != 0: + logger.warning(f'Use label_smoothing: {label_smoothing}') + self.label_smoothing = label_smoothing + + if parallel_output: + # The loss in this place is bound to the gather_output initialized by VocabParallelClassifier1D # noqa: E501 + self.loss_fn = FlashCrossEntropyLoss( + reduction='mean', + inplace_backward=True, + process_group=None, + label_smoothing=label_smoothing, + ) + else: + # Here, the output will gather output is set in the model, so use ordinary loss # noqa: E501 + self.loss_fn = torch.nn.CrossEntropyLoss( + reduction='mean', label_smoothing=label_smoothing) + + def forward(self, *args): + if len(args) == 3: + # residual is to match prenorm + logits, _, labels = args + elif len(args) == 2: + # When using postnorm + logits, labels = args + else: + raise RuntimeError( + f'The number of criterion inputs are:{len(args)}') + shift_logits = logits.contiguous().view(-1, logits.size(-1)) + shift_labels = labels.contiguous().view(-1) + loss = self.loss_fn(shift_logits, shift_labels) + # There is no need to consider the ignore_index problem here, because the loss calculation will be # noqa: E501 + # calculated through the calculation range, and -100 must be outside this range, so there is no problem # noqa: E501 + + return loss + + +# Adapted from: https://gitlab.pjlab.org.cn/openmmlab/bigmodel/rl3m/-/blob/main/rl3m/layers/loss.py#L37 # noqa: E501 +class PretrainLoss(FlashGPTLMLoss): + """Modified from pretrain/sft loss, but with a loss factor term to balance + with ppo policy loss.""" + + def __init__(self, *args, loss_factor=1.0, **kwargs): + super().__init__(*args, **kwargs) + self.loss_factor = loss_factor + + def forward(self, *args, **kwargs): + loss = super().forward(*args, **kwargs) + return loss * self.loss_factor diff --git a/xtuner/rlhf/main.py b/xtuner/rlhf/main.py new file mode 100644 index 000000000..cf3812fcc --- /dev/null +++ b/xtuner/rlhf/main.py @@ -0,0 +1,169 @@ +import argparse +import json +import os +import time + +import numpy as np +from loguru import logger + +from xtuner.rlhf.config.config import Config +from xtuner.rlhf.coordinator import Coordinator +from xtuner.rlhf.dataset.txt_loader import TxtMessageDataset +from xtuner.rlhf.envs.txt_env import TxtEnv +from xtuner.rlhf.repeaters.base import BaseRepeater +from xtuner.rlhf.tokenizer.tokenizer_utils import 
get_tokenizer +from xtuner.rlhf.trainer.ppo import PPOTrainer + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train LLM') + parser.add_argument( + '-c', + '--config', + help='config file name or path.', + type=str, + default='examples/rlhf/four_model_8gpu.py') + parser.add_argument( + '-w', + '--work_dir', + help='the dir to save logs and models', + type=str, + default=None) + parser.add_argument( + '-a', '--address', help='ray head address', type=str, default='auto') + args = parser.parse_args() + return args + + +def validate_config(config: Config): + assert config['model_configs'] is not None + assert config['model_configs']['actor'] is not None + assert config['model_configs']['actor']['model_path'] is not None + assert config['dataset_config'] is not None + assert config['rollout_config'] is not None + assert config['rollout_config']['generate_kwargs'] is not None + assert config['rollout_config']['max_new_tokens'] is not None + + +if __name__ == '__main__': + args = parse_args() + assert args.config is not None, 'config should not be None' + work_dir = args.work_dir + if work_dir is None: + work_dir = os.getcwd() + work_dir = os.path.abspath(work_dir) + logger.info(f'using work_dir: {work_dir}') + os.makedirs(work_dir, exist_ok=True) + + logger.add( + f'{work_dir}/train.log', + filter=lambda record: record['extra'].get('name') == 'train') + logger.add( + f'{work_dir}/rollout.log', + filter=lambda record: record['extra'].get('name') == 'rollout') + logger_train = logger.bind(name='train') + + configs_path = args.config + config = Config.from_file(configs_path) + logger.info('#################### CONFIG BGN ####################') + for k, v in config.items(): + logger.info(f'{k}: {v}') + logger.info('#################### CONFIG END ####################') + + # init dataset + model_path = config['model_configs']['actor']['model_path'] + tokenizer_config = config.get('tokenizer_config', {}) + for model_type in config['model_configs'].keys(): + if 'tokenizer_config' not in config['model_configs'][model_type]: + config['model_configs'][model_type][ + 'tokenizer_config'] = tokenizer_config + tokenizer = get_tokenizer( + model_path, trust_remote_code=True, **tokenizer_config) + dataset_config = config['dataset_config'] + dataset_config['tokenizer'] = tokenizer + txt_loader = TxtMessageDataset(**dataset_config) + + # init model + cluster_address = args.address + if cluster_address != 'auto': + cluster_address = f'ray://{cluster_address}:10001' + logger.info(f'cluster_address={cluster_address}') + coordinator = Coordinator(cluster_address, config['model_configs']) + model_dict = coordinator.create_models() + sft_model = model_dict['reference'] + actor_model = model_dict['actor'] + reward_model = model_dict['reward'] + critic_model = model_dict['critic'] + + # init txt env + + rollout_config = config.get('rollout_config', {}) + txt_env = TxtEnv( + dataloader=txt_loader, + reward_function=reward_model, + **rollout_config, + ) + # init repeater + repeater_config = config.get('repeater_config', {}) + rl_repeater = BaseRepeater( + sft_model=sft_model, + **repeater_config, + ) + # init trainer + train_config = config.get('train_config', {}) + ppo = PPOTrainer( + policy_model=actor_model, value_model=None, **train_config) + pretrain_step = train_config['pretrain_step'] + save_interval = train_config['save_interval'] + np.set_printoptions(threshold=np.inf) + step = 1 + while True: + s_t = time.time() + trajectories = txt_env.rollout(policy_model=actor_model) + # deal with 
trajectories + trajectories = rl_repeater.process( + trajectories, + policy_model=actor_model, + value_model=critic_model, + sft_model=None, + env=txt_env) + + # # for value & policy learn + value_loss_ref = ppo.value_learn_async(trajectories, critic_model) + + ppo_loss = 0.0 + if pretrain_step <= 0: + ppo_loss, pt_loss = ppo.policy_learn(trajectories, actor_model) + logger_train.info( + f'[Policy Train] Step: {step}, ppo loss: {ppo_loss}, pretrain loss: {pt_loss}' # noqa: E501 + ) + + value_loss = ppo.value_learn_get(value_loss_ref, critic_model) + logger_train.info( + f'[Value Train] step: {step}, value loss: {value_loss}') + logger_train.info(f'rewards: {trajectories.rewards.mean()}') + pretrain_step -= 1 + + if config['rollout_config'].get('write_to_file', True): + with open(f'{work_dir}/rollout.log', 'a') as file: + file.write(f'generates: {trajectories.output_str}') + summaries = dict( + reward_mean=trajectories.rewards.mean().item(), + reward_std=trajectories.rewards.std().item(), + new_tokens_mean=trajectories.action_mask.sum( + -1).float().mean().item(), + new_tokens_std=trajectories.action_mask.sum( + -1).float().std().item(), + kl=trajectories.kl.mean().item(), + entropy=trajectories.entropy.mean().item(), + step=step, + policy_loss=ppo_loss, + critic_loss=value_loss, + ) + with open(f'{work_dir}/train.log.jsonl', 'a') as f: + f.write(json.dumps(summaries) + '\n') + + step += 1 + logger_train.info(f'[end to end] duration: {time.time() - s_t} s') + if step % save_interval == 0: + actor_model.save_model(f'{work_dir}/ckpt/{step}/') diff --git a/xtuner/rlhf/model_backend/__init__.py b/xtuner/rlhf/model_backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/rlhf/model_backend/cuda_memory_stats.py b/xtuner/rlhf/model_backend/cuda_memory_stats.py new file mode 100644 index 000000000..cdf195fad --- /dev/null +++ b/xtuner/rlhf/model_backend/cuda_memory_stats.py @@ -0,0 +1,52 @@ +from loguru import logger + +GB_SHIFT = 30 +MB_SHIFT = 20 + + +class CudaMemoryStats(dict): + # see: https://pytorch.org/docs/stable/generated/torch.cuda.memory_stats.html # noqa: E501 + # def add_memory_stats(self, key, device): + # import torch + # status = torch.cuda.memory_stats(device=device) + # self.__setattr__(key, status) + + @property + def num_gpus(self): + return len(self.keys()) + + @property + def total_current_bytes(self): + CURRENT_BYTE_KEY = 'allocated_bytes.all.current' + total = 0 + for _, v in self.items(): + total += v.get(CURRENT_BYTE_KEY, 0) + return total + + @property + def total_current_gb(self): + return self.total_current_bytes >> GB_SHIFT + + @property + def total_current_mb(self): + return self.total_current_bytes >> MB_SHIFT + + @property + def avg_current_bytes(self): + return self.total_current_bytes / self.num_gpus if self.num_gpus != 0 else 0 # noqa: E501 + + def __repr__(self): + return f'CudaMemoryStats: {self.num_gpus} GPU takes {self.total_current_mb} MiB' # noqa: E501 + + +def merge_cuda_memory_stats_list( + dict_list: list[CudaMemoryStats]) -> CudaMemoryStats: + if isinstance(dict_list, CudaMemoryStats): + logger.warning('dict_list is a CudaMemoryStatus instead of a list') + return dict_list + memory_stats_dict: CudaMemoryStats = dict_list[0] + assert isinstance(memory_stats_dict, CudaMemoryStats) + if len(dict_list) > 1: + for m in dict_list[1:]: + memory_stats_dict.update(m) + return memory_stats_dict diff --git a/xtuner/rlhf/model_backend/dist_utils.py b/xtuner/rlhf/model_backend/dist_utils.py new file mode 100644 index 
000000000..30e63a229 --- /dev/null +++ b/xtuner/rlhf/model_backend/dist_utils.py @@ -0,0 +1,63 @@ +from datetime import timedelta +from typing import Any, Optional, Union + +from torch.distributed.distributed_c10d import (Backend, PrefixStore, Store, + _new_process_group_helper, + _world, default_pg_timeout, + rendezvous) + + +# Adapted from https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py # noqa: E501 +def init_process_group( + backend: Union[str, Backend] = None, + init_method: Optional[str] = None, + timeout: Optional[timedelta] = None, + world_size: int = -1, + rank: int = -1, + store: Optional[Store] = None, + group_name: str = '', + pg_options: Optional[Any] = None, +): + assert (store is None) or ( + init_method is None), 'Cannot specify both init_method and store.' + + if store is not None: + assert world_size > 0, 'world_size must be positive if using store' + assert rank >= 0, 'rank must be non-negative if using store' + elif init_method is None: + init_method = 'env://' + + if backend: + backend = Backend(backend) + else: + backend = Backend('undefined') + + if timeout is None: + timeout = default_pg_timeout + + # backward compatible API + if store is None: + rendezvous_iterator = rendezvous( + init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + store = PrefixStore(group_name, store) + + pg = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + pg_options=pg_options, + timeout=timeout, + ) + + pg = pg[0] if isinstance(pg, tuple) else pg + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + + return pg diff --git a/xtuner/rlhf/model_backend/generate_utils.py b/xtuner/rlhf/model_backend/generate_utils.py new file mode 100644 index 000000000..e88995d28 --- /dev/null +++ b/xtuner/rlhf/model_backend/generate_utils.py @@ -0,0 +1,167 @@ +from typing import Optional, Union + +import torch +from transformers import PreTrainedTokenizer + + +def get_question_answer_mask( + input_ids: torch.Tensor, + output_ids: torch.Tensor, + tokenizer_pad_token_id: int, + generate_pad_token_id: int = None, +): + """ + Example: + input_ids = torch.tensor([[0, 1, 9]]) + output_ids = torch.tensor([[0, 1, 9, 2, 3, 4, 5]]) + tokenizer_pad_token_id = 0 # set 0 as neither question or answer + generate_pad_token_id = None + expected_qst_mask = torch.tensor([[0, 1, 1, 0, 0, 0, 0]]) + expected_ans_mask = torch.tensor([[0, 0, 0, 1, 1, 1, 1]]) + """ + # seq_mask yields zero where token == pad_token_id + seq_mask = output_ids.not_equal(tokenizer_pad_token_id).int() + if generate_pad_token_id is not None: + seq_mask *= output_ids.not_equal(generate_pad_token_id).int() + + question_len = input_ids.shape[-1] + question_mask = seq_mask.clone() + question_mask[:, question_len:] = 0 + answer_mask = seq_mask.clone() + answer_mask[:, :question_len] = 0 + return question_mask, answer_mask + + +def partition_by_micro_batch_size( + input_ids: Union[list[str], torch.Tensor, list[int]], + micro_batch_size: int, + attention_mask: torch.Tensor = None, + labels: Optional[Union[list[torch.Tensor], torch.Tensor, + dict[str, torch.Tensor]]] = None, +) -> list[dict[str, torch.Tensor]]: + micro_batches: list[dict[str, torch.Tensor]] = [] + batch_size = input_ids.shape[0] if isinstance( + input_ids, torch.Tensor) else 
len(input_ids) + if micro_batch_size <= 0 or batch_size == micro_batch_size: + micro_batch = {} + micro_batch['input_ids'] = input_ids + micro_batch['attention_mask'] = attention_mask + micro_batch['labels'] = labels + micro_batches.append(micro_batch) + return micro_batches + if micro_batch_size > batch_size: + micro_batch_size = batch_size + + num_splits = int(batch_size // micro_batch_size) + ( + batch_size % micro_batch_size > 0) + if isinstance(input_ids, torch.Tensor): + input_ids_split = torch.split(input_ids, micro_batch_size, dim=0) + else: + input_ids_split = [ + input_ids[i:i + micro_batch_size] + for i in range(0, len(input_ids), micro_batch_size) + ] + attention_mask_split = ( + torch.split(attention_mask, micro_batch_size, dim=0) + if attention_mask is not None else [None for _ in range(num_splits)]) + labels_split = ( + partition_label_by_micro_batch_size(labels, micro_batch_size, + num_splits) + if labels is not None else [None for _ in range(num_splits)]) + for i in range(num_splits): + micro_batch = {} + micro_batch['input_ids'] = input_ids_split[i] + micro_batch['attention_mask'] = attention_mask_split[i] + micro_batch['labels'] = labels_split[i] + micro_batches.append(micro_batch) + return micro_batches + + +def partition_label_by_micro_batch_size( + labels: Union[list[torch.Tensor], torch.Tensor, dict[str, torch.Tensor]], + micro_batch_size: int, + num_splits: int = 1, +): + if isinstance(labels, torch.Tensor): + return torch.split(labels, micro_batch_size, dim=0) + if isinstance(labels, list): + return [ + labels[i:i + micro_batch_size] + for i in range(0, len(labels), micro_batch_size) + ] + if isinstance(labels, dict): + split = [{} for _ in range(num_splits)] + for key in labels.keys(): + if key == 'loss_factor': + for i in range(num_splits): + split[i][key] = labels[key] + else: + tensors = partition_label_by_micro_batch_size( + labels[key], micro_batch_size) + for i in range(num_splits): + split[i][key] = tensors[i] + return split + + +def partition_list_by_micro_batch_size( + input_ids: list[torch.Tensor], + micro_batch_size: list[int], + labels: list[torch.Tensor], + attention_mask: Optional[list[torch.Tensor]] = None, + loss_weights: Optional[list[float]] = None, +) -> list[dict]: + length = len(input_ids) + batch_size = input_ids[0].shape[0] + num_splits = int(batch_size // micro_batch_size[0]) + ( + batch_size % micro_batch_size[0] > 0) + micro_batches = [[{} for i in range(length)] for _ in range(num_splits)] + if loss_weights is None: + loss_weights = [None for _ in range(length)] + if attention_mask is None: + attention_mask = [None for _ in range(length)] + for i in range(length): + sub_input_ids = input_ids[i] + sub_attention_mask = attention_mask[i] + sub_labels = labels[i] + sub_loss_weights = loss_weights[i] + sub_micro_batches = partition_by_micro_batch_size( + sub_input_ids, micro_batch_size[i], sub_attention_mask, sub_labels) + for micro_batch_index, sub_micro_batch in enumerate(sub_micro_batches): + micro_batches[micro_batch_index][i]['input_ids'] = sub_micro_batch[ + 'input_ids'] + micro_batches[micro_batch_index][i][ + 'attention_mask'] = sub_micro_batch['attention_mask'] + micro_batches[micro_batch_index][i]['labels'] = sub_micro_batch[ + 'labels'] + micro_batches[micro_batch_index][i][ + 'loss_weights'] = sub_loss_weights + return micro_batches + + +def merge_loss_list(loss_list_mb: list[list[torch.Tensor]]): + micro_batch_num = len(loss_list_mb) + loss_num = len(loss_list_mb[0]) + loss_list = [i for i in range(loss_num)] + for loss_index in 
range(loss_num): + losses = [] + for batch_index in range(micro_batch_num): + losses.append(loss_list_mb[batch_index][loss_index]) + loss_list[loss_index] = sum(losses) / micro_batch_num + return loss_list + + +def get_answer_str( + tokenizer: PreTrainedTokenizer, + output_ids: torch.Tensor, + answer_mask: torch.Tensor, +): + answer_ids = output_ids * answer_mask + zero_mask = answer_ids.eq(0) + answer_ids = zero_mask * tokenizer.all_special_ids[0] + answer_ids + + answer_str = tokenizer.batch_decode( + answer_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + return answer_str diff --git a/xtuner/rlhf/model_backend/hf_model_runner.py b/xtuner/rlhf/model_backend/hf_model_runner.py new file mode 100644 index 000000000..ca6a826f5 --- /dev/null +++ b/xtuner/rlhf/model_backend/hf_model_runner.py @@ -0,0 +1,898 @@ +import os +import socket +from typing import Optional, Union + +import ray +import torch +from accelerate import Accelerator +from accelerate.utils import FullyShardedDataParallelPlugin +from loguru import logger +from ray.util.placement_group import placement_group as create_placement_group +from ray.util.placement_group import remove_placement_group +from torch.nn.modules.loss import _Loss +from torch.optim.lr_scheduler import _LRScheduler +from transformers import AutoModelForCausalLM, PreTrainedModel +from transformers import get_scheduler as transformers_get_scheduler +from transformers.dynamic_module_utils import init_hf_modules +from transformers.generation.utils import GenerateDecoderOnlyOutput + +from ..config.config_consts import (ENGINE_PLUGIN_DDP, ENGINE_PLUGIN_DEEPSPEED, + ENGINE_PLUGIN_FSDP) +from ..config.config_utils import get_dp_size, get_gpu_requirement +from ..policy_output import (PolicyOutput, concat_policy_outputs, + logprobs_from_logits) +from ..tokenizer import tokenizer_utils +from ..utils import set_seed +from .dist_utils import init_process_group +from .generate_utils import (get_answer_str, get_question_answer_mask, + merge_loss_list, partition_by_micro_batch_size, + partition_list_by_micro_batch_size) +from .ray_actor_group import RayActorGroup +from .ray_actor_mixin import RayActorMixin +from .ray_utils import DEFAULT_NUM_CPUS, DEFAULT_NUM_GPUS, create_ray_actors + +DEFAULT_NEW_TOKENS = 64 +MAXIMUM_NEW_TOKENS = 1024 +""" +HfModelRunner can be individually called by other process +HfModelRunnerRayActor is called by ModelServer with .remote() +""" + + +class HfModelRunner: + """ModelTrainer is capable of training, inference, and generation.""" + + def __init__(self, model_config): + self.model_config: dict = model_config + + def initialize(self): + # 0. Environment + envs = self.model_config.get('envs', {}) + for key, value in envs.items(): + os.environ[key] = value + + # 1. 
Model + model_path = self.model_config.get('model_path') + self.model_type = self.model_config.get('model_type', '').lower() + torch_dtype = self.model_config.get('torch_dtype', 'auto') + use_flash_attn = self.model_config.get('use_flash_attn', None) + model_class = self.model_config.get('model_class', + AutoModelForCausalLM) + self.model: PreTrainedModel = model_class.from_pretrained( + pretrained_model_name_or_path=model_path, + device_map='auto', + torch_dtype=torch_dtype, + trust_remote_code=True, + attn_implementation='flash_attention_2' + if use_flash_attn else None, + ) + + # Graident checkpointing + gradient_checkpointing = self.model_config.get( + 'gradient_checkpointing', False) + if gradient_checkpointing: + self.model.gradient_checkpointing_enable() + self.vocab_size = self.model.config.vocab_size + + # 2. Tokenizer + tokenizer_path = self.model_config.get('tokenizer_path', model_path) + tokenizer_config = self.model_config.get('tokenizer_config', {}) + self.tokenizer = tokenizer_utils.get_tokenizer( + tokenizer_path, trust_remote_code=True, **tokenizer_config) + + # 3. Trainer + parallel: dict = self.model_config['parallel'] + assert parallel['tensor']['size'] == 1 # TODO: support TP + assert parallel['pipeline']['size'] == 1 # TODO: support PP + self.step = 0 + self.zero_stage = 1 + mixed_precision = self.model_config.get('mixed_precision', None) + if parallel['data'].get('mode') == ENGINE_PLUGIN_FSDP: + self.accelerator = Accelerator( + fsdp_plugin=FullyShardedDataParallelPlugin()) + self.zero_stage = 3 + elif parallel['data'].get('mode') == ENGINE_PLUGIN_DEEPSPEED: + from accelerate import DeepSpeedPlugin + + ds_config = self.model_config['deepspeed_config'] # requisite + self.accelerator = Accelerator( + deepspeed_plugin=DeepSpeedPlugin(ds_config)) + self.zero_stage = ds_config['zero_optimization']['stage'] + else: + self.accelerator = Accelerator(mixed_precision=mixed_precision) + self.zero_stage = 0 + + train_kwargs = self.model_config.get('train_kwargs') + if train_kwargs is None: # requires no training + self.device = self.accelerator.device + logger.info( + f'[{self.model_type}] __init__() done without train_kwargs.') + return + optimizer_type = train_kwargs.get('optimizer', torch.optim.AdamW) + learning_rate = train_kwargs.get('lr', 1e-5) + self.clip_grad_norm = train_kwargs.get('clip_grad_norm', 1.0) + self.optimizer: torch.optim.Optimizer = optimizer_type( + params=self.model.parameters(), + lr=learning_rate, + ) + + lr_scheduler_type = train_kwargs.get('lr_scheduler', 'linear') + lr_scheduler_kwargs = train_kwargs.get( + 'lr_scheduler_kwargs', + { + 'num_warmup_steps': 0, + 'num_training_steps': 10000000000 + }, + ) + self.lr_scheduler: _LRScheduler = transformers_get_scheduler( + lr_scheduler_type, + optimizer=self.optimizer, + **lr_scheduler_kwargs, + ) + self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( # noqa: E501 + self.model, self.optimizer, self.lr_scheduler) + + # Others + self.device = self.accelerator.device + set_seed(self.model_config.get('seed')) + if mixed_precision is not None: + self.info_rank0( + f'[{self.model_type}]: Enable mixed_precision = {mixed_precision}' # noqa: E501 + ) + if gradient_checkpointing: + self.info_rank0( + f'[{self.model_type}]: Enable gradient_checkpointing') + self.info_rank0( + f'[{self.model_type}] __init__() done with optimizer {self.optimizer.optimizer}.' 
# noqa: E501 + ) + + # Training + def compute_loss_and_backward( + self, + input_ids: Union[list[torch.Tensor], torch.Tensor], + labels: Optional[Union[list[torch.Tensor], torch.Tensor, + dict[str, torch.Tensor]]] = None, + attention_mask: Optional[Union[list[torch.Tensor], + torch.Tensor]] = None, + criterion: Optional[Union[list[_Loss], _Loss]] = None, + loss_weights: Optional[list[float]] = None, + gradient_accumulation_steps=1, + **_ignored, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + """ + criterion: _Loss class, e.g., torch.nn.CrossEntropyLoss() + """ + if isinstance(input_ids, torch.Tensor): # returns torch.Tensor + # rarely, since self.train() changes all input_ids to [input_ids] + loss = self.compute_loss(input_ids, labels, attention_mask, + criterion) + self.accelerator.backward(loss) + return loss + + elif type(input_ids) == list: # returns list[torch.Tensor] + # multiple inputs grouped to compute loss, see: + # https://stackoverflow.com/questions/53994625/how-can-i-process-multi-loss-in-pytorch + assert ( + len(input_ids) == len(labels) == len(criterion) == + len(attention_mask) == len(loss_weights) + ), f'{len(input_ids)} {len(labels)} {len(criterion)} {len(attention_mask)} {len(loss_weights)} must equal' # noqa: E501 + loss_list = [0 for _ in range(len(input_ids))] + loss_weights = [ + x / float(len(loss_weights)) for x in loss_weights + ] # to 1 + + loss_sum = 0 + for i in range(len(input_ids)): + with self.accelerator.autocast(): + loss = self.compute_loss(input_ids[i], labels[i], + attention_mask[i], criterion[i]) + loss_sum += loss * loss_weights[i] + loss_list[i] = loss + self.accelerator.backward(loss_sum) + return loss_list + + else: + raise NotImplementedError(f'unknown input {input_ids}') + + def compute_loss( + self, + input_ids: torch.Tensor, + labels: Optional[Union[torch.Tensor, dict[str, torch.Tensor]]] = None, + attention_mask: Optional[torch.Tensor] = None, + criterion: Optional[_Loss] = None, + loss_weight: Optional[float] = None, + **_ignored, + ) -> torch.Tensor: + input_ids = input_ids.to(self.device) + labels = input_ids.clone() if labels is None else labels + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + batch = { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'position_ids': position_ids.to(self.device) + } + self.model.train() + + if criterion is None: + # OPT. A) Default settings + assert isinstance( + labels, torch.Tensor + ), 'Please pass in `criterion` for non-tensor labels' + batch['labels'] = labels.to(self.device) + fwd_output = self.model(**batch, use_cache=False) + loss = fwd_output.loss + elif isinstance(labels, torch.Tensor): + # OPT. B) Use preset loss functions, e.g., torch.nn.CrossEntropyLoss() # noqa: E501 + # Adopted from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1199 # noqa: E501 + logits: torch.Tensor = self.model(**batch, use_cache=False).logits + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_logits = shift_logits.view(-1, self.vocab_size) + shift_labels = shift_labels.view(-1) + shift_labels = shift_labels.to( + shift_logits.device) # enable model para + # loss_fct = criterion() + loss = criterion(shift_logits, shift_labels) + elif isinstance(labels, dict): + # OPT. 
C) Use customized loss function, see loss/actor_loss.py + logits: torch.Tensor = self.model( + **batch, use_cache=False, return_dict=True).logits + # loss_fct = criterion() + for k, v in labels.items(): + labels[k] = v.to(self.device) + loss = criterion(logits, labels) + else: + raise ValueError(f'labels of unsupported type: {type(labels)}') + + if loss_weight is not None: + loss *= loss_weight + return loss + + def parameter_update(self, step_interval=1): + self.info_rank0(f'[{self.model_type}] self.parameter_update()') + self.step += 1 + if self.step % step_interval == 0: + self.accelerator.clip_grad_norm_(self.model.parameters(), + self.clip_grad_norm) + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + def train( + self, + input_ids: Union[list[torch.Tensor], torch.Tensor], + labels: Optional[Union[list[torch.Tensor], torch.Tensor, + dict[str, torch.Tensor]]] = None, + attention_mask: Optional[Union[list[torch.Tensor], + torch.Tensor]] = None, + criterion: Optional[Union[list[_Loss], _Loss]] = None, + loss_weights: Optional[Union[list[float], float]] = None, + step_interval: int = 1, + # None means using the entire input as one batch + micro_batch_size: Optional[Union[list[int], int]] = None, + debug=False, + **_ignored, + ): + return_list = True + + if isinstance(input_ids, torch.Tensor): + input_ids = [input_ids] + labels = [labels] + attention_mask = [attention_mask] + criterion = [criterion] + loss_weights = [1] if loss_weights is None else [loss_weights] + micro_batch_size = None if micro_batch_size is None else [ + micro_batch_size + ] + return_list = False + + if micro_batch_size is None: + for i in range(len(input_ids)): + self.info_rank0( + f'[{self.model_type}] train input_ids[{i}] shape[{input_ids[i].shape}]' # noqa: E501 + ) + origin_loss = self.compute_loss_and_backward( + input_ids, labels, attention_mask, criterion, loss_weights) + else: + assert isinstance(input_ids, list) + micro_batches = partition_list_by_micro_batch_size( + input_ids, micro_batch_size, labels, attention_mask, + loss_weights) + origin_loss_list_mb = [] + for index, micro_batch in enumerate(micro_batches): + input_ids_mb = [] + attention_mask_mb = [] + labels_mb = [] + loss_weights_mb = [] + for i in range(len(micro_batch)): + input_ids_mb.append(micro_batch[i]['input_ids'].to( + self.device)) + attention_mask_mb.append( + micro_batch[i]['attention_mask'].to(self.device)) + labels_mb.append(micro_batch[i]['labels']) + loss_weights_mb.append(micro_batch[i]['loss_weights']) + if index == 0: + for i in range(len(input_ids_mb)): + self.info_rank0( + f'[{self.model_type}] will train input_ids_mb[{i}] shape[{input_ids_mb[i].shape}] * {len(micro_batches)} times' # noqa: E501 + ) + origin_loss_mb = self.compute_loss_and_backward( + input_ids_mb, + labels_mb, + attention_mask_mb, + criterion, + loss_weights_mb, + gradient_accumulation_steps=len(micro_batches), + ) + origin_loss_list_mb.append(origin_loss_mb) + if debug: + set_seed(1234) + origin_loss = merge_loss_list(origin_loss_list_mb) + + self.parameter_update(step_interval) + return origin_loss if return_list else origin_loss[0] + + # Inference + @torch.no_grad() + def _infer( + self, + input_ids: torch.Tensor, + attention_mask=None, + output_logprobs=True, + output_logits=False, + output_attentions=False, + output_hidden_states=False, + infer_kwargs: Optional[dict] = {}, + **_ignored, + ) -> PolicyOutput: + assert isinstance(input_ids, torch.Tensor) + position_ids = attention_mask.long().cumsum(-1) - 1 + 
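# NOTE: position ids are rebuilt from the attention mask so left-padded
+ # prompts stay aligned: real tokens get 0..n-1 from the cumsum above, while
+ # padding positions are pinned to a dummy value of 1 by the masked_fill below. +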
position_ids.masked_fill_(attention_mask == 0, 1) + model_output = self.model( + input_ids.to(self.device), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids.to(self.device), + return_dict=True, + **infer_kwargs, + ) + + output = PolicyOutput() + if output_logits: + output['logits'] = model_output['logits'] + if output_attentions: + output['attentions'] = model_output['attentions'] + if output_hidden_states: + output['hidden_states'] = model_output['hidden_states'] + if output_logprobs: + log_probs = logprobs_from_logits( + logits=model_output['logits'][:, :-1, :], + labels=input_ids[:, 1:], + gather=True, + ) + output['logprobs'] = log_probs + output.to('cpu') + return output + + @torch.no_grad() + def infer( + self, + inputs: Union[torch.Tensor, list[dict], list[list[dict]]], + micro_batch_size: Optional[ + int] = -1, # -1: use the entire input as one batch + tokenizer=None, # Only used for reward models + attention_mask=None, + output_logprobs=False, + output_logits=True, + output_attentions=False, + output_hidden_states=False, + infer_kwargs: Optional[dict] = {}, + debug=False, + **_ignored, + ) -> PolicyOutput: + self.info_rank0( + f'[{self.model_type}] self.infer() kwargs: {infer_kwargs}') + if not isinstance(inputs, torch.Tensor): + input_ids, attention_mask = tokenizer_utils.encode( + inputs, self.tokenizer) + else: + input_ids = inputs + + input_ids = input_ids.to(self.device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.device) + # returns entire-input-as-one-batch inference results + if micro_batch_size < 0: + self.info_rank0( + f'[{self.model_type}] infer() input_ids.shape: {input_ids.shape}' # noqa: E501 + ) + return self._infer( + input_ids, + attention_mask, + output_logprobs, + output_logits, + output_attentions, + output_hidden_states, + infer_kwargs, + ) + + # Otherwise, partition the input into micro batches and run inference on each micro batch separately # noqa: E501 + micro_batches = partition_by_micro_batch_size(input_ids, + micro_batch_size, + attention_mask) + policy_outputs = [] + for index, micro_batch in enumerate(micro_batches): + input_ids_mb = micro_batch['input_ids'] + attention_mask_mb = micro_batch['attention_mask'] + if index == 0: + self.info_rank0( + f'[{self.model_type}] will infer() input_ids_mb.shape: {input_ids_mb.shape} * {len(micro_batches)} times' # noqa: E501 + ) + policy_output_mb = self._infer( + input_ids_mb, + attention_mask_mb, + output_logprobs, + output_logits, + output_attentions, + output_hidden_states, + infer_kwargs, + ) + policy_outputs.append(policy_output_mb) + if debug: + self.set_seed(1234) + # Concatenate the policy outputs from each micro batch and return the result # noqa: E501 + return concat_policy_outputs(policy_outputs) + + # Generate + @torch.no_grad() + def _generate( + self, + input_ids: torch.Tensor, + attention_mask=None, + step=-1, + output_str=True, + output_logits=False, + output_attentions=False, + output_hidden_states=False, + generate_kwargs: Optional[dict] = {}, + ) -> PolicyOutput: + assert isinstance(input_ids, torch.Tensor) + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): + model = self.accelerator.unwrap_model(self.model) + else: + model = self.model + + max_new_tokens = ( + MAXIMUM_NEW_TOKENS + if 'eos_token_id' in generate_kwargs else DEFAULT_NEW_TOKENS) + max_new_tokens = step if step > 0 else max_new_tokens + + # TODO: stop if meeting eos_token_id + 
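# Note on the cap chosen above: MAXIMUM_NEW_TOKENS applies when the caller
+ # passes eos_token_id in generate_kwargs, DEFAULT_NEW_TOKENS otherwise, and a
+ # positive `step` overrides either value. +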
model_output: GenerateDecoderOnlyOutput = model.generate( + input_ids.to(model.device), + use_cache=True, + max_new_tokens=max_new_tokens, + return_dict_in_generate=True, + output_logits=output_logits, # transformers >= 4.38.2 + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + attention_mask=attention_mask, + **generate_kwargs, + ) + + output_ids = model_output['sequences'] + self.info_rank0( + f'generate input_ids shape:[{input_ids.shape}], output_ids shape:[{output_ids.shape}]' # noqa: E501 + ) + output = PolicyOutput(output_ids=output_ids) + # masks + output['question_mask'], output[ + 'answer_mask'] = get_question_answer_mask( + input_ids, + output_ids, + tokenizer_pad_token_id=self.tokenizer.pad_token_id, + generate_pad_token_id=generate_kwargs.get('pad_token_id'), + ) + output['attention_mask'] = output.question_mask + output.answer_mask + output['action_mask'] = output['attention_mask'][:, + input_ids.size(1) - + 1:-1] + + if output_logits: + output['logits'] = model_output['logits'] # tuple(torch.Tensor, ) + if output_attentions: + output['attentions'] = model_output['attentions'] + if output_hidden_states: + output['hidden_states'] = model_output['hidden_states'] + if output_str: # customized post processing + output['output_str'] = self.tokenizer.batch_decode( + output_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + output['output_ans_str'] = get_answer_str( + tokenizer=self.tokenizer, + output_ids=output_ids, + answer_mask=output.answer_mask, + ) + + output.to('cpu') + return output + + # Generate + @torch.no_grad() + def generate( + self, + inputs: Union[torch.Tensor, str, list[str]], + micro_batch_size: Optional[ + int] = -1, # -1: use the entire input as one batch + attention_mask=None, + step=-1, + output_str=True, + output_logits=False, + output_attentions=False, + output_hidden_states=False, + chat_template=None, + generate_kwargs: Optional[dict] = {}, + debug=False, + **_ignored, + ) -> PolicyOutput: + self.info_rank0( + f'[{self.model_type}] self.generate() kwargs: {generate_kwargs}') + if not isinstance(inputs, torch.Tensor): + input_ids, attention_mask = tokenizer_utils.encode( + inputs, self.tokenizer, add_generation_prompt=True) + else: + input_ids = inputs + input_ids = input_ids.to(self.device) + if attention_mask is not None: + assert isinstance(attention_mask, torch.Tensor) + attention_mask = attention_mask.to(self.device) + + if micro_batch_size < 0: + return self._generate( + input_ids, + attention_mask, + step, + output_str, + output_logits, + output_attentions, + output_hidden_states, + generate_kwargs, + ) + + micro_batches = partition_by_micro_batch_size(input_ids, + micro_batch_size, + attention_mask) + policy_outputs = [] + for micro_batch in micro_batches: + input_ids_mb = micro_batch['input_ids'] + attention_mask_mb = micro_batch['attention_mask'] + policy_output_mb = self._generate( + input_ids_mb, + attention_mask_mb, + step, + output_str, + output_logits, + output_attentions, + output_hidden_states, + generate_kwargs, + ) + policy_outputs.append(policy_output_mb) + if debug: + self.set_seed(1234) + + padding_token_map = {'output_ids': self.tokenizer.pad_token_id} + return concat_policy_outputs(policy_outputs, padding_token_map) + + def get_model(self): + parallel: dict = self.model_config['parallel'] + dp = parallel['data'].get('size') + dp_mode = parallel['data'].get('mode') + if dp > 1 and dp_mode != ENGINE_PLUGIN_DDP: + raise ('please use get_state_dict instead when using parallel') + 
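# get_model() returns the locally unwrapped module, so it is only meaningful
+ # when data parallel does not shard the weights; sharded setups should use
+ # get_state_dict() instead, as the guard above enforces. +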
_model = self.accelerator.unwrap_model(self.model) + return _model + + def get_state_dict(self): + state_dict = self.accelerator.get_state_dict(self.model) + if not self.accelerator.is_main_process: + return None + return state_dict + + def set_seed(self, seed=None): + set_seed(seed) + + def save_model(self, path): + if not self.accelerator.is_main_process: + self.accelerator.get_state_dict(self.model) + return + unwrapped_model = self.accelerator.unwrap_model(self.model) + if not os.path.exists(path): + os.makedirs(path) + unwrapped_model.save_pretrained( + path, + is_main_process=True, + save_function=self.accelerator.save, + state_dict=self.accelerator.get_state_dict(self.model), + ) + logger.info(f'save model to {path}') + + def info_rank0(self, content): + if self.accelerator.is_main_process: + logger.info(content) + + +# Adapted from https://github.com/OpenLLMAI/OpenRLHF/blob/v0.2.5/openrlhf/trainer/ray/ppo_actor.py # noqa: E501 +class HfModelRunnerRayActor(HfModelRunner, RayActorMixin): + """A ray.remote Actor Class initialized by HfModelRunnerRayActorGroup, + extending HfModelRunner with ray related method via RayActorMixin.""" + + def init_process_group(self, generator): + if self.accelerator.is_main_process: + # init process groups for vllm engine + master_address = ray._private.services.get_node_ip_address() + with socket.socket() as sock: + sock.bind(('', 0)) + master_port = sock.getsockname()[1] + + world_size = generator.dp_size * generator.tp_size + 1 + refs = [ + engine.init_process_group.remote( + master_address, + master_port, + i * generator.tp_size + 1, + world_size, + 'vllm', + ) for i, engine in enumerate(generator.ray_actors) + ] + self._model_update_group = init_process_group( + backend='nccl', + init_method=f'tcp://{master_address}:{master_port}', + world_size=world_size, + rank=0, + group_name='vllm', + ) + ray.get(refs) + + def broadcast_model_to_generator(self, generator): + # TODO: Support Pytorch FSDP. 
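+ # Summary of the broadcast below (a description of this method body, not new
+ # behaviour): every trainer rank walks model.named_parameters() in the same
+ # order; the main process first tells each vLLM engine the dtype/shape to
+ # expect via engine.update_weight.remote(...), then broadcasts param.data over
+ # the 'vllm' process group; under DeepSpeed ZeRO-3 the parameter is gathered
+ # with GatheredParameters first so the full tensor exists on rank 0.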
+ if self.model_config['parallel']['data'].get( + 'mode') == ENGINE_PLUGIN_FSDP: + raise NotImplementedError('FSDP is not supported yet.') + logger.info('Broadcast BEGIN') + model = self.accelerator.unwrap_model(self.model) + for name, param in model.named_parameters(): + if self.accelerator.is_main_process: + shape = param.shape if self.zero_stage != 3 else param.ds_shape + + for engine in generator.ray_actors: + engine.update_weight.remote( + name, dtype=param.dtype, shape=shape) + + if self.zero_stage != 3: + if self.accelerator.is_main_process: + torch.distributed.broadcast( + param.data, 0, group=self._model_update_group) + else: + from deepspeed.runtime.zero.partition_parameters import \ + GatheredParameters + + with GatheredParameters([param]): + if self.accelerator.is_main_process: + torch.distributed.broadcast( + param.data, 0, group=self._model_update_group) + + logger.info('Broadcast END') + + +class HfModelRunnerRayActorGroup(RayActorGroup): + """HfModelRunnerRayActorGroup manages a list of HfModelRunnerRayActor + create ray actors.""" + + # avoid ModuleNotFoundError: No module named 'transformers_modules' + # refer to https://github.com/vllm-project/vllm/pull/871 + init_hf_modules() + + def __init__(self, name: str, config: dict): + super().__init__(name, config) + self.released = True + num_gpus = get_gpu_requirement(config) + self.dp_size = get_dp_size(config) + self.tokenizer_pad_token_id = config.tokenizer_config.pad_token_id + bundles = [{ + 'CPU': DEFAULT_NUM_CPUS, + 'GPU': DEFAULT_NUM_GPUS + } for _ in range(num_gpus)] + self.placement_group = create_placement_group(bundles) + self.ray_actors: list[HfModelRunnerRayActor] = create_ray_actors( + name_prefix=name, + config=config, + placement_group=self.placement_group, + trainer_class=ray.remote( + num_cpus=DEFAULT_NUM_CPUS, + num_gpus=DEFAULT_NUM_GPUS)(HfModelRunnerRayActor), + ) + self.released = False + + master_ip = ray.get(self.ray_actors[0].get_metadata.remote()).node_ip + master_port = ray.get(self.ray_actors[0].get_free_port.remote()) + ray.get([ + actor.inject_distribute_env.remote( + master_ip=master_ip, + master_port=master_port, + rank_id=rank, + world_size=len(self.ray_actors), + ) for rank, actor in enumerate(self.ray_actors) + ]) + self.initialize_ref = [ + actor.initialize.remote() for actor in self.ray_actors + ] + + def initialize_get(self): + if self.initialize_ref is not None: + ray.get(self.initialize_ref) + else: + logger.info( + 'self.initialize_get None, maybe self.generator==self.trainer') + self.initialize_ref = None + + # Training + def train_async(self, input_ids, labels, attention_mask, *args, **kwargs): + if isinstance(input_ids, torch.Tensor): + micro_batch_size = input_ids.shape[0] // self.dp_size + ( + input_ids.shape[0] % self.dp_size > 0 + ) # round up division, i.e., math.ceil(a / b) + micro_batches = partition_by_micro_batch_size( + input_ids, micro_batch_size, attention_mask, labels) + assert len(micro_batches) == self.dp_size + return [ + self.ray_actors[index].train.remote( + input_ids=micro_batch['input_ids'], + attention_mask=micro_batch['attention_mask'], + labels=micro_batch['labels'], + *args, + **kwargs, + ) for index, micro_batch in enumerate(micro_batches) + ] + elif isinstance(input_ids, list): + """a list of tensors whose training loss will be taken average.""" + assert isinstance(input_ids[0], torch.Tensor) + micro_batch_size = [i for i in range(len(input_ids))] + for index, input_id in enumerate(input_ids): + micro_batch_size[ + index] = input_id[index].shape[0] // 
self.dp_size + ( + input_id[index].shape[0] % self.dp_size > 0 + ) # round up division, i.e., math.ceil(a / b) + micro_batches = partition_list_by_micro_batch_size( + input_ids, self.dp_size, attention_mask, labels) + object_refs = [] + for index, micro_batch in enumerate(micro_batches): + input_ids_mb = [] + attention_mask_mb = [] + labels_mb = [] + loss_weights_mb = [] + assert len(micro_batch) == self.dp_size + for i in range(len(micro_batch)): + input_ids_mb.append(micro_batch[i]['input_ids']) + attention_mask_mb.append(micro_batch[i]['attention_mask']) + labels_mb.append(micro_batch[i]['labels']) + loss_weights_mb.append(micro_batch[i]['loss_weights']) + + object_ref = self.ray_actors[index].train.remote( + inputs=input_ids_mb, + attention_mask=attention_mask_mb, + labels=labels_mb, + loss_weights=loss_weights_mb, + *args, + **kwargs, + ) + object_refs.append(object_ref) + return object_ref + + def train_get(self, object_refs, timeout=None): + losses = ray.get(object_refs, timeout=timeout) + return sum(losses) / len(losses) + + def train(self, *args, **kwargs): + object_refs = self.train_async(*args, **kwargs) + return self.train_get(object_refs) + + # Inference + def infer_async(self, input_ids, attention_mask, *args, **kwargs): + micro_batch_size = input_ids.shape[0] // self.dp_size + ( + input_ids.shape[0] % self.dp_size > 0 + ) # round up division, i.e., math.ceil(a / b) + micro_batches = partition_by_micro_batch_size(input_ids, + micro_batch_size, + attention_mask) + assert len(micro_batches) == self.dp_size + return [ + self.ray_actors[index].infer.remote( + inputs=micro_batch['input_ids'], + attention_mask=micro_batch['attention_mask'], + *args, + **kwargs, + ) for index, micro_batch in enumerate(micro_batches) + ] + + def infer_get(self, object_refs, timeout=None): + outputs = ray.get(object_refs, timeout=timeout) + return concat_policy_outputs(outputs) + + def infer(self, *args, **kwargs): + object_refs = self.infer_async(*args, **kwargs) + return self.infer_get(object_refs) + + # Generation + def generate_async(self, input_ids, attention_mask, *args, **kwargs): + micro_batch_size = input_ids.shape[0] // self.dp_size + ( + input_ids.shape[0] % self.dp_size > 0 + ) # round up division, i.e., math.ceil(a / b) + micro_batches = partition_by_micro_batch_size(input_ids, + micro_batch_size, + attention_mask) + assert len(micro_batches) == self.dp_size + return [ + self.ray_actors[index].generate.remote( + inputs=micro_batch['input_ids'], + attention_mask=micro_batch['attention_mask'], + *args, + **kwargs, + ) for index, micro_batch in enumerate(micro_batches) + ] + + def generate_get(self, object_refs, timeout=None): + outputs = ray.get(object_refs, timeout=timeout) + padding_token_map = { + 'output_ids': self.config.tokenizer_config.pad_token_id + } + return concat_policy_outputs(outputs, padding_token_map) + + def generate(self, *args, **kwargs): + object_refs = self.generate_async(*args, **kwargs) + return self.generate_get(object_refs) + + # Others + def get_model(self): + return self.ray_actors[0].get_model.remote() + + def get_state_dict(self): + state_dicts = [ + actor.get_state_dict.remote() for actor in self.ray_actors + ] + return state_dicts[0] + + def set_seed(self, seed=None): + ray.get([actor.set_seed.remote(seed) for actor in self.ray_actors]) + + def release_resources(self): + """release ray resources.""" + if self.released: + return + for actor in self.ray_actors: + try: + ray.kill(actor=actor, no_restart=True) + except BaseException as exp: + logger.error(f'failed 
to kill ray actor {actor}. {exp}') + remove_placement_group(self.placement_group) + self.released = True + + def save_model(self, path): + ray.get([actor.save_model.remote(path) for actor in self.ray_actors]) + + def init_process_group(self, generator): + refs = [ + hfm.init_process_group.remote(generator) + for i, hfm in enumerate(self.ray_actors) + ] + ray.get(refs) + + def broadcast_model_to_generator(self, generator: None): + refs = [ + hfm.broadcast_model_to_generator.remote(generator) + for i, hfm in enumerate(self.ray_actors) + ] + ray.get(refs) diff --git a/xtuner/rlhf/model_backend/models/__init__.py b/xtuner/rlhf/model_backend/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/rlhf/model_backend/models/configuration_internlm2.py b/xtuner/rlhf/model_backend/models/configuration_internlm2.py new file mode 100644 index 000000000..c76e1407f --- /dev/null +++ b/xtuner/rlhf/model_backend/models/configuration_internlm2.py @@ -0,0 +1,159 @@ +# flake8: noqa: E501 +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/configuration_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""InternLM2 model configuration.""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +# Modified from transformers.model.llama.configuration_llama.LlamaConfig +class InternLM2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate + an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the InternLM2-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`InternLM2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + Example: + + """ + model_type = 'internlm2' + _auto_class = 'AutoConfig' + + def __init__( # pylint: disable=W0102 + self, + vocab_size=103168, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act='silu', + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + reward_token_id=92527, + two_linear_reward_head=False, + tie_word_embeddings=False, + bias=True, + rope_theta=10000, + rope_scaling=None, + attn_implementation='eager', + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.bias = bias + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + + self.attn_implementation = attn_implementation + if self.attn_implementation is None: + self.attn_implementation = 'eager' + + self.reward_token_id = reward_token_id + self.two_linear_reward_head = two_linear_reward_head + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """Validate the `rope_scaling` configuration.""" + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, + dict) or len(self.rope_scaling) != 2: + raise ValueError( + '`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, ' + f'got {self.rope_scaling}') + rope_scaling_type = self.rope_scaling.get('type', None) + rope_scaling_factor = 
self.rope_scaling.get('factor', None) + if rope_scaling_type is None or rope_scaling_type not in [ + 'linear', 'dynamic' + ]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance( + rope_scaling_factor, float) or rope_scaling_factor < 1.0: + raise ValueError( + f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}" + ) diff --git a/xtuner/rlhf/model_backend/models/critical_and_reward.py b/xtuner/rlhf/model_backend/models/critical_and_reward.py new file mode 100644 index 000000000..bb1e3697a --- /dev/null +++ b/xtuner/rlhf/model_backend/models/critical_and_reward.py @@ -0,0 +1,110 @@ +from typing import Optional + +import torch +import torch.nn as nn +from transformers import AutoConfig, AutoModel +from transformers.dynamic_module_utils import get_class_from_dynamic_module +from transformers.modeling_outputs import SequenceClassifierOutputWithPast + + +def _get_model_class(model_name_or_path: str): + config = AutoConfig.from_pretrained( + model_name_or_path, trust_remote_code=True) + config_class = type(config) + if config_class in AutoModel._model_mapping: + model_class = AutoModel._model_mapping[type(config)] + model_base_class = model_class.__base__ + return model_class, model_base_class + + if 'AutoModel' in config.auto_map: + module_file, causal_model_name = config.auto_map['AutoModel'].split( + '.') + elif 'AutoModelForCausalLM' in config.auto_map: + module_file, causal_model_name = config.auto_map[ + 'AutoModelForCausalLM'].split('.') + else: + raise Exception( + f'config of {model_name_or_path} has no AutoModel or AutoModelForCausalLM in auto_map' # noqa: E501 + ) + + model_class_name = (causal_model_name.split('For')[0] + 'Model' + ) # e.g., "InternLM2Model" + model_class = get_class_from_dynamic_module( + f'{module_file}.{model_class_name}', model_name_or_path) + model_base_class_name = (causal_model_name.split('For')[0] + + 'PreTrainedModel' + ) # e.g., "InternLM2PreTrainedModel" + model_base_class = get_class_from_dynamic_module( + f'{module_file}.{model_base_class_name}', model_name_or_path) + return model_class, model_base_class + + +def get_critic_model(model_name_or_path: str, head_name): + model_class, model_base_class = _get_model_class(model_name_or_path) + + class CriticModel(model_base_class): + supports_gradient_checkpointing = True + + def __init__(self, config: AutoConfig): + super().__init__(config) + self.model = model_class(config) + self.head_name = head_name + setattr(self, head_name, + nn.Linear(config.hidden_size, 1, bias=False)) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **_ignored, + ) -> torch.Tensor: + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = outputs[0] + logits = getattr(self, + self.head_name)(hidden_states).squeeze(-1)[:, :-1] + + return SequenceClassifierOutputWithPast(logits=logits, ) + + return CriticModel + + +def get_reward_model(model_name_or_path: str, head_name): + model_class, model_base_class = _get_model_class(model_name_or_path) + + class RewardModel(model_base_class): + supports_gradient_checkpointing = True + + def __init__(self, config: AutoConfig): + super().__init__(config) + self.model = model_class(config) + self.head_name = head_name + setattr(self, head_name, + nn.Linear(config.hidden_size, 1, 
bias=False)) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **_ignored, + ) -> torch.Tensor: + eos_indices = ( + attention_mask.size(1) - 1 - + attention_mask.long().fliplr().argmax(dim=1, keepdim=True)) + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + ) + hidden_states = outputs[0] + values = getattr(self, self.head_name)(hidden_states).squeeze(-1) + reward_scores = values.gather(dim=1, index=eos_indices).squeeze(1) + + return SequenceClassifierOutputWithPast(logits=reward_scores, ) + + return RewardModel diff --git a/xtuner/rlhf/model_backend/models/modeling_internlm2_p.py b/xtuner/rlhf/model_backend/models/modeling_internlm2_p.py new file mode 100644 index 000000000..9d87f94ad --- /dev/null +++ b/xtuner/rlhf/model_backend/models/modeling_internlm2_p.py @@ -0,0 +1,1536 @@ +# flake8: noqa: E501 +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/modeling_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch InternLM2 model.""" +import math +import queue +import threading +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import (add_start_docstrings, + add_start_docstrings_to_model_forward, logging, + replace_return_docstrings) + +try: + from transformers.generation.streamers import BaseStreamer +except: # pylint: disable=bare-except + BaseStreamer = None + +from .configuration_internlm2 import InternLM2Config + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = 'InternLM2Config' + +flash_attn_func, flash_attn_varlen_func = None, None +pad_input, index_first_axis, unpad_input = None, None, None + + +def _import_flash_attn(): + global flash_attn_func, flash_attn_varlen_func + global pad_input, index_first_axis, unpad_input + try: + from flash_attn import flash_attn_func as _flash_attn_func + from flash_attn import \ + flash_attn_varlen_func as _flash_attn_varlen_func + from flash_attn.bert_padding import \ + index_first_axis as _index_first_axis + from flash_attn.bert_padding import pad_input as _pad_input + from flash_attn.bert_padding import unpad_input as _unpad_input + flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func + pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input + except ImportError: + raise ImportError('flash_attn is not installed.') + + +# Copied from 
transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask(input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0): + """Make causal mask used for bi-directional self-attention.""" + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), + torch.tensor(torch.finfo(dtype).min, device=device), + device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device), + mask + ], + dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, + tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, + dtype: torch.dtype, + tgt_len: Optional[int] = None): + """Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, + src_seq_len]`.""" + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, + src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), + torch.finfo(dtype).min) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2 +class InternLM2RMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + """InternLM2RMSNorm is equivalent to T5LayerNorm.""" + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2 +class InternLM2RotaryEmbedding(nn.Module): + + def __init__(self, + dim, + max_position_embeddings=2048, + base=10000, + device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base + **(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
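+        # The cos/sin tables are precomputed here up to
+        # `max_position_embeddings`; `forward` rebuilds them on the fly if a
+        # longer sequence length is requested.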
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype()) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + 'cos_cached', emb.cos().to(dtype), persistent=False) + self.register_buffer( + 'sin_cached', emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache( + seq_len=seq_len, device=x.device, dtype=torch.float32) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2 +class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding): + """InternLM2RotaryEmbedding extended with linear scaling. + + Credits to the Reddit user /u/kaiokendev + """ + + def __init__(self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + 'cos_cached', emb.cos().to(dtype), persistent=False) + self.register_buffer( + 'sin_cached', emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2 +class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding): + """InternLM2RotaryEmbedding extended with Dynamic NTK scaling. + + Credits to the Reddit users /u/bloc97 and /u/emozilla. 
+ """ + + def __init__(self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ((self.scaling_factor * seq_len / + self.max_position_embeddings) - + (self.scaling_factor - 1))**( + self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base + **(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + 'cos_cached', emb.cos().to(dtype), persistent=False) + self.register_buffer( + 'sin_cached', emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.model.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class InternLM2MLP(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.w1 = nn.Linear( + self.hidden_size, self.intermediate_size, bias=False) + self.w3 = nn.Linear( + self.hidden_size, self.intermediate_size, bias=False) + self.w2 = nn.Linear( + self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x)) + + return down_proj + + +# Copied from transformers.model.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """This is the equivalent of torch.repeat_interleave(x, dim=1, + repeats=n_rep). 
+ + The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to + (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, + None, :, :].expand(batch, + num_key_value_heads, + n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, + head_dim) + + +# Modified from transformers.model.llama.modeling_llama.LlamaAttention +class InternLM2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper.""" + + def __init__(self, config: InternLM2Config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}' + f' and `num_heads`: {self.num_heads}).') + + self.wqkv = nn.Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, + bias=config.bias, + ) + + self.wo = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=config.bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = InternLM2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling['type'] + scaling_factor = self.config.rope_scaling['factor'] + if scaling_type == 'dynamic': + self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + scaling_factor=scaling_factor, + ) + elif scaling_type == 'linear': + self.rotary_emb = InternLM2LinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + scaling_factor=scaling_factor, + ) + else: + raise ValueError( + "Currently we only support rotary embedding's type being 'dynamic' or 'linear'." + ) + return self.rotary_emb + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. 
' + 'Please make sure use `attention_mask` instead.`') + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> b q h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., :self.num_key_value_groups, :] + query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d') + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose( + 2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is' + f' {attn_weights.size()}') + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is' + f' {attn_output.size()}') + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.wo(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2 +class InternLM2FlashAttention2(InternLM2Attention): + """InternLM2 flash attention module. + + This module inherits from `InternLM2Attention` as the weights of the module + stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal + with padding tokens in case the input contains any of them. 
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + # InternLM2FlashAttention2 attention does not support output_attentions + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. ' + 'Please make sure use `attention_mask` instead.`') + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop('padding_mask') + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> b q h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., :self.num_key_value_groups, :] + query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d') + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward(query_states, key_states, + value_states, + attention_mask, q_len) + attn_output = attn_output.reshape(bsz, q_len, + self.hidden_size).contiguous() + attn_output = self.wo(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward(self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + causal = self.is_causal and query_length != 1 + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input( + query_states, key_states, value_states, attention_mask, + query_length) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, + query_length) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal) + + return attn_output + + def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, + query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, + head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
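+            # With left padding the real query tokens sit in the last
+            # `query_length` positions, so slicing the mask from the right
+            # keeps it aligned with `query_layer` before calling `unpad_input`.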
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q.to(torch.int64), + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +INTERNLM2_ATTENTION_CLASSES = { + 'eager': InternLM2Attention, + 'flash_attention_2': InternLM2FlashAttention2, +} + + +# Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer +class InternLM2DecoderLayer(nn.Module): + + def __init__(self, config: InternLM2Config): + super().__init__() + self.hidden_size = config.hidden_size + + self.attention = INTERNLM2_ATTENTION_CLASSES[ + config.attn_implementation]( + config=config) + + self.feed_forward = InternLM2MLP(config) + self.attention_norm = InternLM2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + self.ffn_norm = InternLM2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, + torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. ' + 'Please make sure use `attention_mask` instead.`') + + residual = hidden_states + + hidden_states = self.attention_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states, ) + + if output_attentions: + outputs += (self_attn_weights, ) + + if use_cache: + outputs += (present_key_value, ) + + return outputs + + +InternLM2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`InternLM2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2 +@add_start_docstrings( + 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.', + InternLM2_START_DOCSTRING, +) +class InternLM2PreTrainedModel(PreTrainedModel): + config_class = InternLM2Config + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['InternLM2DecoderLayer'] + _skip_keys_device_placement = 'past_key_values' + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +InternLM2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or + when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`. 
+ + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Modified from transformers.model.llama.modeling_llama.LlamaModel +@add_start_docstrings( + 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.', + InternLM2_START_DOCSTRING, +) +class InternLM2Model(InternLM2PreTrainedModel): + """Transformer decoder consisting of *config.num_hidden_layers* layers. 
+ Each layer is a [`InternLM2DecoderLayer`] + + Args: + config: InternLM2Config + """ + + _auto_class = 'AutoModel' + + def __init__(self, config: InternLM2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx) + + self.layers = nn.ModuleList([ + InternLM2DecoderLayer(config) + for _ in range(config.num_hidden_layers) + ]) + self.norm = InternLM2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.tok_embeddings + + def set_input_embeddings(self, value): + self.tok_embeddings = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, + inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, + tgt_len=input_shape[-1]).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else + expanded_attn_mask + combined_attention_mask) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.attn_implementation == 'flash_attention_2': + _import_flash_attn() + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds at the same time' + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError( + 'You have to specify either input_ids or inputs_embeds') + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + 
past_key_values_length, + dtype=torch.long, + device=device) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.tok_embeddings(input_ids) + + if self.config.attn_implementation == 'flash_attention_2': + # 2d mask is passed through the layers + attention_mask = attention_mask if ( + attention_mask is not None and 0 in attention_mask) else None + else: + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, + past_key_values_length) + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + past_key_value = past_key_values[ + idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += ( + layer_outputs[2 if output_attentions else 1], ) + + if output_attentions: + all_self_attns += (layer_outputs[1], ) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states, ) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v for v in + [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class NormHead(nn.Module): + + def __init__(self, hidden_size, vocab_size, bias=False): + super().__init__() + self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size))) + self.first_flag = True + + def forward(self, hidden_states): + norm_weight = nn.functional.normalize(self.weight) + return nn.functional.linear(hidden_states, norm_weight) + + +# Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM +class InternLM2ForCausalLM(InternLM2PreTrainedModel): + _auto_class = 'AutoModelForCausalLM' + + _tied_weights_keys = ['output.weight'] + + def __init__(self, config): + super().__init__(config) + self.model = InternLM2Model(config) + self.vocab_size = config.vocab_size + self.output = NormHead( + config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + self.norm_head = True 
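+        # Bookkeeping flags for the NormHead output layer; they are set here
+        # but do not appear to be consumed elsewhere in this module.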
+ self.first_eval_flag = True + self.tmp_weight = None + + def get_input_embeddings(self): + return self.model.tok_embeddings + + def set_input_embeddings(self, value): + self.model.tok_embeddings = value + + def get_output_embeddings(self): + return self.output + + def set_output_embeddings(self, new_embeddings): + self.output = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, InternLM2ForCausalLM + + >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + logits = self.output(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits, ) + outputs[1:] + return (loss, ) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs): + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + position_ids = kwargs.get('position_ids', None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1]:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update({ + 'position_ids': position_ids, + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + }) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past), ) + return reordered_past + + def build_inputs(self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = [], + meta_instruction=''): + if tokenizer.add_bos_token: + prompt = '' + else: + prompt = tokenizer.bos_token + if meta_instruction: + prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n""" + for record in history: + prompt += 
f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n""" + prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n""" + return tokenizer([prompt], return_tensors='pt') + + @torch.no_grad() + def chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = [], + streamer: Optional[BaseStreamer] = None, + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + meta_instruction: + str = 'You are an AI assistant whose name is InternLM (书生·浦语).\n' + '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n' + '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.', + **kwargs, + ): + inputs = self.build_inputs(tokenizer, query, history, meta_instruction) + inputs = { + k: v.to(self.device) + for k, v in inputs.items() if torch.is_tensor(v) + } + # also add end-of-assistant token in eos token id to avoid unnecessary generation + eos_token_id = [ + tokenizer.eos_token_id, + tokenizer.convert_tokens_to_ids(['<|im_end|>'])[0] + ] + outputs = self.generate( + **inputs, + streamer=streamer, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + eos_token_id=eos_token_id, + **kwargs, + ) + outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):] + response = tokenizer.decode(outputs, skip_special_tokens=True) + response = response.split('<|im_end|>')[0] + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = [], + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + **kwargs, + ): + """Return a generator in format: (response, history) Eg. + + ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')]) ('你好,有什么可以帮助您的吗?', [('你好', + '你好,有什么可以帮助您的吗?')]) + """ + if BaseStreamer is None: + raise ModuleNotFoundError( + 'The version of `transformers` is too low. 
Please make sure ' + 'that you have installed `transformers>=4.28.0`.') + + response_queue = queue.Queue(maxsize=20) + + class ChatStreamer(BaseStreamer): + + def __init__(self, tokenizer) -> None: + super().__init__() + self.tokenizer = tokenizer + self.queue = response_queue + self.query = query + self.history = history + self.response = '' + self.cache = [] + self.received_inputs = False + self.queue.put( + (self.response, history + [(self.query, self.response)])) + + def put(self, value): + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError('ChatStreamer only supports batch size 1') + elif len(value.shape) > 1: + value = value[0] + + if not self.received_inputs: + # The first received value is input_ids, ignore here + self.received_inputs = True + return + + self.cache.extend(value.tolist()) + token = self.tokenizer.decode( + self.cache, skip_special_tokens=True) + if token.strip() != '<|im_end|>': + self.response = self.response + token + history = self.history + [(self.query, self.response)] + self.queue.put((self.response, history)) + self.cache = [] + else: + self.end() + + def end(self): + self.queue.put(None) + + def stream_producer(): + return self.chat( + tokenizer=tokenizer, + query=query, + streamer=ChatStreamer(tokenizer=tokenizer), + history=history, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + **kwargs, + ) + + def consumer(): + producer = threading.Thread(target=stream_producer) + producer.start() + while True: + res = response_queue.get() + if res is None: + return + yield res + + return consumer() + + +# Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2 +@add_start_docstrings( + """ + The InternLM2 Model transformer with a sequence classification head on top (linear layer). + + [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification, + as other causal models (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + InternLM2_START_DOCSTRING, +) +class InternLM2ForSequenceClassification(InternLM2PreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = InternLM2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.tok_embeddings + + def set_input_embeddings(self, value): + self.model.tok_embeddings = value + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + 'Cannot handle batch sizes > 1 if no padding token is defined.' 
+ ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq( + input_ids, self.config.pad_token_id).int().argmax(-1) - + 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), + sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = 'regression' + elif self.num_labels > 1 and (labels.dtype == torch.long + or labels.dtype == torch.int): + self.config.problem_type = 'single_label_classification' + else: + self.config.problem_type = 'multi_label_classification' + + if self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == 'single_label_classification': + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == 'multi_label_classification': + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits, ) + transformer_outputs[1:] + return ((loss, ) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/xtuner/rlhf/model_backend/net_utils.py b/xtuner/rlhf/model_backend/net_utils.py new file mode 100644 index 000000000..7fc715836 --- /dev/null +++ b/xtuner/rlhf/model_backend/net_utils.py @@ -0,0 +1,31 @@ +import socket + + +def get_ip(): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.settimeout(0) + try: + s.connect(('10.254.254.254', 1)) + local_ip = s.getsockname()[0] + except BaseException: + local_ip = '127.0.0.1' + finally: + s.close() + return local_ip + + +def get_ip_hostname(): + hostname = socket.gethostname() + return get_ip(), hostname + + +def get_free_port() -> int: + """Get a free port for the actor to use for DDP dist_init. + + Returns: A free port that could be used. 
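+
+    Note that the socket is closed before the port number is returned, so in
+    rare cases another process may claim the port first (the usual caveat of
+    this bind-then-release pattern).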
+ """ + tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + tcp.bind(('', 0)) + _, port = tcp.getsockname() + tcp.close() + return port diff --git a/xtuner/rlhf/model_backend/ray_actor_group.py b/xtuner/rlhf/model_backend/ray_actor_group.py new file mode 100644 index 000000000..a2da48a7e --- /dev/null +++ b/xtuner/rlhf/model_backend/ray_actor_group.py @@ -0,0 +1,19 @@ +import ray + +from .cuda_memory_stats import merge_cuda_memory_stats_list +from .ray_actor_mixin import RayActorMixin + + +class RayActorGroup: + + def __init__(self, name: str, config: dict): + self.config = config + self.name = name # name_prefix for ray_actors + self.ray_actors: list[RayActorMixin] = [] + + def get_cuda_mem_stats(self): + return merge_cuda_memory_stats_list( + ray.get([ + ray_actor.get_memory_stats_of_visible_devices.remote() + for ray_actor in self.ray_actors + ])) diff --git a/xtuner/rlhf/model_backend/ray_actor_mixin.py b/xtuner/rlhf/model_backend/ray_actor_mixin.py new file mode 100644 index 000000000..c075fc13e --- /dev/null +++ b/xtuner/rlhf/model_backend/ray_actor_mixin.py @@ -0,0 +1,92 @@ +import json +import os +from dataclasses import dataclass +from typing import Optional + +import torch + +from .cuda_memory_stats import CudaMemoryStats +from .net_utils import get_free_port, get_ip, get_ip_hostname + + +@dataclass +class RayActorMetadata: + """Metadata for Ray actor. + + This information is expected to stay the same throughout the lifetime of actor. # noqa: E501 + + Args: + node_ip (str): Node IP address that this actor is on. + hostname (str): Hostname that this actor is on. + gpu_ids (Optional[list[int]]): List of CUDA IDs available to this actor. # noqa: E501 + gpu_num (int): Number of used GPUs of this actor. + """ + + node_ip: str + hostname: str + gpu_ids: Optional[list[int]] + gpu_num: int + + def __str__(self) -> str: + info = { + 'Node_IP': self.node_ip, + 'Hostname': self.hostname, + 'GPU_IDs': self.gpu_ids, + 'GPU_Num': self.gpu_num, + } + return json.dumps(info, indent=4, sort_keys=True) + + +class RayActorMixin: + + def inject_distribute_env( + self, + master_ip: Optional[str] = None, + master_port: int = 0, + rank_id: int = 0, + world_size: int = 0, + ) -> None: + """Inject Environment Variables before training. + + Args: + master_ip (Optional[str]): The ip address of the master node. + master_port (int): The port on the master node used for dist_init. + rank_id (int): The rank id of this actor. + world_size (int): Number of Actors for DDP training. 
+ """ + os.environ['MASTER_ADDR'] = master_ip + os.environ['MASTER_PORT'] = str(master_port) + os.environ['RANK'] = str(rank_id) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_RANK'] = '0' + + def get_metadata(self) -> RayActorMetadata: + node_ip, hostname = get_ip_hostname() + gpu_ids = os.environ['CUDA_VISIBLE_DEVICES'] + gpu_num = torch.cuda.device_count() + + return RayActorMetadata( + node_ip=node_ip, + hostname=hostname, + gpu_ids=gpu_ids, + gpu_num=gpu_num, + ) + + def get_free_port(self): + return get_free_port() + + def get_memory_stats_of_visible_devices(self) -> CudaMemoryStats: + visible_gpu_ids = [] + if 'CUDA_VISIBLE_DEVICES' in os.environ: + visible_gpu_ids = os.environ['CUDA_VISIBLE_DEVICES'].split(',') + else: + visible_gpu_ids = [ + str(index) for index in range(torch.cuda.device_count()) + ] + + cuda_memory_stats = CudaMemoryStats() + for index, gpu_id in enumerate(visible_gpu_ids): + status = torch.cuda.memory_stats(device=index) + node_ip = get_ip() + cuda_memory_stats[f'ip{node_ip}-gpu{gpu_id}'] = status + return cuda_memory_stats diff --git a/xtuner/rlhf/model_backend/ray_utils.py b/xtuner/rlhf/model_backend/ray_utils.py new file mode 100644 index 000000000..c7dc4d1f2 --- /dev/null +++ b/xtuner/rlhf/model_backend/ray_utils.py @@ -0,0 +1,36 @@ +import uuid +from typing import TypeVar + +from ray.util.placement_group import PlacementGroup +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +DEFAULT_NUM_CPUS = 1 +DEFAULT_NUM_GPUS = 1 +T = TypeVar('T') +UUID = uuid.uuid4() # may called multiple times in different ray instances + + +# Create Ray Actors +def create_ray_actors( + name_prefix: str, + config: dict, + placement_group: PlacementGroup, + trainer_class: T, +) -> list[T]: + ray_actors = [_ for _ in range(placement_group.bundle_count)] + for index in range(placement_group.bundle_count): + ray_actors[index] = trainer_class.options( + name=f'{name_prefix}_rank_{index}', + namespace=f'{UUID}_{trainer_class.__class__.__name__}', + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_bundle_index=index, + ), + runtime_env=set_runtime_env(), + ).remote(config) + return ray_actors + + +def set_runtime_env(): + runtime_env = {'env_vars': {'HF_ENDPOINT': 'https://hf-mirror.com'}} + return runtime_env diff --git a/xtuner/rlhf/model_backend/vllm_model_runner.py b/xtuner/rlhf/model_backend/vllm_model_runner.py new file mode 100644 index 000000000..c04677a1f --- /dev/null +++ b/xtuner/rlhf/model_backend/vllm_model_runner.py @@ -0,0 +1,347 @@ +import os +from typing import Optional, Union + +import ray +import torch +from loguru import logger +from ray.util.placement_group import placement_group as create_placement_group +from ray.util.placement_group import remove_placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from vllm import LLM, SamplingParams +from vllm.sampling_params import _SAMPLING_EPS + +from ..config.config_utils import get_dp_size, get_tp_size +from ..policy_output import PolicyOutput, concat_policy_outputs +from .generate_utils import (get_question_answer_mask, + partition_by_micro_batch_size) +from .ray_actor_group import RayActorGroup +from .ray_actor_mixin import RayActorMixin +from .ray_utils import DEFAULT_NUM_CPUS, DEFAULT_NUM_GPUS, set_runtime_env + +VLLM_DEFAULT_DEVICE = 'cuda' + + +class VllmGenerator: + + def __init__(self, model_config) -> None: + self.model_config: dict = model_config + + # Adapted from 
https://github.com/OpenLLMAI/OpenRLHF/blob/v0.2.5/openrlhf/trainer/ray/vllm_engine.py # noqa: E501 + def initialize(self) -> None: + model_path = self.model_config.get('model_path') + torch_dtype = self.model_config.get('torch_dtype', 'auto') + tokenizer_path = self.model_config.get('tokenizer_path', model_path) + parallel: dict = self.model_config.get('parallel') + tensor_parallel_size = 1 if parallel is None else parallel['tensor'][ + 'size'] + + import vllm + + if '0.2.7' <= vllm.__version__ <= '0.3.3' and tensor_parallel_size != 1: # noqa: E501 + # NOTE: In 0.2.7, vLLM made a major change to its architecture which move one worker into the driver process. # noqa: E501 + # Driver process will manually set CUDA_VISIBLE_DEVICES before worker init. To avoid importing torch before # noqa: E501 + # set CUDA_VISIBLE_DEVICES, we must defer monkey patch. + # For more detail, see: https://github.com/vllm-project/vllm/pull/2221 # noqa: E501 + def _set_cuda_visible_devices(device_ids: list[int]): + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join( + map(str, device_ids)) + from vllm.worker import worker + + from .vllm_worker_wrap import VllmWorkerWrap + + worker.Worker = VllmWorkerWrap + + vllm.engine.llm_engine.set_cuda_visible_devices = _set_cuda_visible_devices # noqa: E501 + else: + from vllm.worker import worker + + from .vllm_worker_wrap import VllmWorkerWrap + + worker.Worker = VllmWorkerWrap + + self.llm: LLM = vllm.LLM( + model=model_path, + tokenizer=tokenizer_path, + trust_remote_code=True, + dtype=torch_dtype, + swap_space=0, + tensor_parallel_size=tensor_parallel_size, + device=VLLM_DEFAULT_DEVICE, + ) + self.tokenizer = self.llm.get_tokenizer() + tokenizer_config = self.model_config.get('tokenizer_config', {}) + for key, value in tokenizer_config.items(): + setattr(self.tokenizer, key, value) + + @staticmethod + def get_sampling_params_from_dict(generate_kwargs: dict) -> SamplingParams: + sp = SamplingParams() + for k, v in generate_kwargs.items(): + if k in sp.__dict__: + sp.__dict__[k] = v + elif k == 'num_beams' and v > 1: + sp.__dict__['use_beam_search'] = True + elif k == 'eos_token_id': + sp.__dict__['stop_token_ids'] = [v] + + sp.top_k = -1 if sp.top_k <= 1 else sp.top_k + sp._verify_args() + + if sp.use_beam_search: + sp._verify_beam_search() + else: + sp.early_stopping = False + sp._verify_non_beam_search() + if sp.temperature < _SAMPLING_EPS: + # Zero temperature means greedy sampling. 
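+                # Greedy decoding ignores top-p/top-k/min-p filtering, so
+                # reset them to vLLM's neutral defaults (1.0, -1, 0.0) before
+                # _verify_greedy_sampling(). E.g. a (hypothetical) kwargs dict
+                # {'temperature': 0.0, 'top_p': 0.9} is treated as plain
+                # greedy sampling rather than nucleus sampling.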
+ sp.top_p = 1.0 + sp.top_k = -1 + sp.min_p = 0.0 + sp._verify_greedy_sampling() + return sp + + def generate( + self, + inputs: Union[torch.Tensor, str, list[str]], + step=-1, + output_str=True, + output_logits=False, + output_attentions=False, + output_hidden_states=False, + generate_kwargs: Optional[dict] = {}, + **_ignored, + ) -> list[tuple[list[int], str]]: + sp = VllmGenerator.get_sampling_params_from_dict(generate_kwargs) + sp.max_tokens = step if step > 0 else None + logger.info( + f'[{self.__class__.__name__}] self.generate() SamplingParams: {sp}' + ) + + if isinstance(inputs, torch.Tensor): + if len(inputs.shape) == 2: # e.g., [batch_size, seq_len] + prompt = self.tokenizer.batch_decode( + inputs, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + elif len(inputs.shape) == 1: # e.g., [seq_len] + prompt = self.tokenizer.decode( + inputs, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + else: + raise ValueError( + f'Unsupported tensor inputs of shape({inputs.shape})') + + elif isinstance(inputs, str): + prompt = inputs # str + elif isinstance(inputs, list): + if isinstance(inputs[0], list): + prompt = inputs # list[int] + else: + raise ValueError( + f'Unsupported inputs[0] with type({type(inputs[0])})') + else: + raise ValueError(f'Unsupported inputs with type({type(inputs)})') + + # Calling vllm's generate + req_outputs = self.llm.generate( + prompt_token_ids=prompt, sampling_params=sp) + + def get_longest_list_length(list_of_lists): + max_length = 0 + for int_list in list_of_lists: + current_length = len(int_list) + if current_length > max_length: + max_length = current_length + return max_length + + _max_length = get_longest_list_length(prompt) + + def pad_list_with_pad_token(int_list, max_length, pad_token_id): + if len(int_list) < max_length: + num_pad_token_to_add = max_length - len(int_list) + padded_list = [pad_token_id] * num_pad_token_to_add + int_list + return padded_list + else: + return int_list + + policy_outputs = [] + for _, req_output in enumerate(req_outputs): + output = PolicyOutput() + input_ids = [item for item in req_output.prompt_token_ids] + input_ids = pad_list_with_pad_token(input_ids, _max_length, + self.tokenizer.pad_token_id) + output_token_ids = [ + item for item in req_output.outputs[0].token_ids + ] + output_ids = input_ids + output_token_ids # concat + output['input_ids'] = torch.Tensor(input_ids).to( + torch.long).unsqueeze(0) + output['output_ids'] = torch.tensor(output_ids).to( + torch.long).unsqueeze(0) + + output['question_mask'], output[ + 'answer_mask'] = get_question_answer_mask( + output['input_ids'], + output['output_ids'], + tokenizer_pad_token_id=self.tokenizer.pad_token_id, + generate_pad_token_id=generate_kwargs.get('pad_token_id'), + ) + output[ + 'attention_mask'] = output.question_mask + output.answer_mask # noqa: E501 + output['action_mask'] = output['attention_mask'][:, _max_length - + 1:-1] + if output_logits: + raise NotImplementedError('TODO: output_logits') + if output_attentions: + raise NotImplementedError('TODO: output_attentions') + if output_hidden_states: + raise NotImplementedError('TODO: output_hidden_states') + if output_str: # return list[str] + output['output_ans_str'] = [req_output.outputs[0].text] + output_str = self.tokenizer.decode( + output_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + output['output_str'] = [output_str] + output.to('cpu') + + policy_outputs.append(output) + + padding_token_map = {'output_ids': 
self.tokenizer.pad_token_id} + concated_policy_out = concat_policy_outputs(policy_outputs, + padding_token_map) + return concated_policy_out + + +class VllmGeneratorRayActor(VllmGenerator, RayActorMixin): + + # Adapted from https://github.com/OpenLLMAI/OpenRLHF/blob/v0.2.5/openrlhf/trainer/ray/vllm_engine.py # noqa: E501 + def init_process_group(self, master_address, master_port, rank_offset, + world_size, group_name): + return self.llm.llm_engine._run_workers( + 'init_process_group', + master_address, + master_port, + rank_offset, + world_size, + group_name, + ) + + def update_weight(self, name, dtype, shape, empty_cache=False): + return self.llm.llm_engine._run_workers('update_weight', name, dtype, + shape, empty_cache) + + +class VllmGeneratorRayActorGroup(RayActorGroup): + + def __init__(self, name: str, config: dict): + import uuid + self.released = True + self.config = config + self.tp_size = get_tp_size(config) # tensor parallelism + self.dp_size = get_dp_size(config) # num of vllm_engines + self.tokenizer_pad_token_id = config.tokenizer_config.pad_token_id + self.ray_actors: list[VllmGeneratorRayActor] = [] # i.e., vllm_engines + + # Adapted from https://github.com/OpenLLMAI/OpenRLHF/blob/v0.2.5/openrlhf/trainer/ray/vllm_engine.py # noqa: E501 + for dp_i in range(self.dp_size): + ray_actor_num_gpus = int(self.tp_size == 1) + scheduling_strategy = None + + if self.tp_size > 1: + bundles = [{ + 'CPU': DEFAULT_NUM_CPUS, + 'GPU': DEFAULT_NUM_GPUS + }] * self.tp_size + self.placement_group = create_placement_group(bundles) + ray.get(self.placement_group.ready()) + + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=self.placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=0, + ) + + namespace = f'{uuid.uuid4()}_{VllmGeneratorRayActor.__class__.__name__}' # noqa: E501 + self.ray_actors.append( + ray.remote(VllmGeneratorRayActor).options( + name=f'{name}_rank_{dp_i}', + namespace=namespace, + num_cpus=1, + num_gpus=ray_actor_num_gpus, + scheduling_strategy=scheduling_strategy, + runtime_env=set_runtime_env(), + ).remote(config)) + + self.released = False + self.initialize_ref = [ + actor.initialize.remote() for actor in self.ray_actors + ] + + def initialize_get(self): + shared_with_trainer = self.config.get('shared_with_trainer', False) + if shared_with_trainer: + assert self.initialize_ref is None + return # assuming trainer.initialize_get() has been called + if self.initialize_ref is not None: + ray.get(self.initialize_ref) + else: + logger.warning( + 'self.initialize_ref is None when calling initialize_get()') + self.initialize_ref = None + + # Generation + def generate_async(self, input_ids, attention_mask, *args, **kwargs): + assert ( + len(input_ids) >= self.dp_size + ), f'The length of input_ids({len(input_ids)}) must not be less than dp_size({self.dp_size}).' 
# noqa: E501 + micro_batch_size = len(input_ids) // self.dp_size + ( + len(input_ids) % self.dp_size > 0 + ) # round up division, i.e., math.ceil(a / b) + micro_batches = partition_by_micro_batch_size(input_ids, + micro_batch_size, + attention_mask) + assert len(micro_batches + ) == self.dp_size, f'{len(micro_batches)}, :{self.dp_size}' + return [ + self.ray_actors[index].generate.remote( + inputs=micro_batch['input_ids'], + attention_mask=micro_batch['attention_mask'], + *args, + **kwargs, + ) for index, micro_batch in enumerate(micro_batches) + ] + + def generate_get(self, object_refs, timeout=None): + outputs = ray.get(object_refs, timeout=timeout) + padding_token_map = { + 'output_ids': self.config.tokenizer_config.pad_token_id + } + return concat_policy_outputs(outputs, padding_token_map) + + def generate(self, *args, **kwargs): + object_refs = self.generate_async(*args, **kwargs) + return self.generate_get(object_refs) + + # Others + def get_model(self): + return self.ray_actors[0].get_model.remote() + + def set_seed(self, seed=None): + ray.get([actor.set_seed.remote(seed) for actor in self.ray_actors]) + + def release_resources(self): + """release ray resources.""" + if self.released: + return + for actor in self.ray_actors: + try: + ray.kill(actor=actor, no_restart=True) + except BaseException as exp: + logger.error(f'failed to kill ray actor {actor}. {exp}') + remove_placement_group(self.placement_group) + self.released = True diff --git a/xtuner/rlhf/model_backend/vllm_worker_wrap.py b/xtuner/rlhf/model_backend/vllm_worker_wrap.py new file mode 100644 index 000000000..843941e81 --- /dev/null +++ b/xtuner/rlhf/model_backend/vllm_worker_wrap.py @@ -0,0 +1,77 @@ +# Adapted from https://github.com/OpenLLMAI/OpenRLHF/blob/v0.2.5/openrlhf/trainer/ray/vllm_worker_wrap.py # noqa: E501 +import importlib + +import torch +from vllm.model_executor.weight_utils import hf_model_weights_iterator +from vllm.worker.worker import Worker + +from ..logger import init_logger +from .dist_utils import init_process_group + +logger = init_logger(__name__) + + +def _hf_model_weights_iterator_wrap(model_name_or_path, *args, **kwargs): + if isinstance(model_name_or_path, dict): + yield from model_name_or_path.items() + else: + yield from hf_model_weights_iterator(model_name_or_path, *args, + **kwargs) + + +class VllmWorkerWrap(Worker): + + def __init__(self, *args, **kwargs): + # Monkey patch hf_model_weights_iterator to allow update single weight + # NOTE: In 0.2.5, vLLM introduce lazy model loader + # https://github.com/vllm-project/vllm/pull/2044 + from vllm.model_executor.models import _MODELS, ModelRegistry + + load_model_cls = ModelRegistry.load_model_cls + + def patched_load_model_cls(model_arch: str): + module_name, _ = _MODELS[model_arch] + module = importlib.import_module( + f'vllm.model_executor.models.{module_name}') + module.hf_model_weights_iterator = _hf_model_weights_iterator_wrap + logger.info( + f'Monkey patch hf_model_weights_iterator for module {module_name}' # noqa: E501 + ) + + return load_model_cls(model_arch) + + ModelRegistry.load_model_cls = patched_load_model_cls + + super().__init__(*args, **kwargs) + + def init_process_group(self, master_address, master_port, rank_offset, + world_size, group_name): + """Init torch process group for model weights update.""" + assert torch.distributed.is_initialized( + ), 'default torch process group must be initialized' + assert group_name != '', 'group name must not be empty' + + rank = torch.distributed.get_rank() + rank_offset + 
self._model_update_group = init_process_group( + backend='nccl', + init_method=f'tcp://{master_address}:{master_port}', + world_size=world_size, + rank=rank, + group_name=group_name, + ) + logger.info( + f'init_process_group: master_address={master_address}, master_port={master_port}, ' # noqa: E501 + f'rank={rank}, world_size={world_size}, group_name={group_name}') + + def update_weight(self, name, dtype, shape, empty_cache=False): + """Broadcast weight to all vllm workers from source rank 0 (actor + model)""" + if torch.distributed.get_rank() == 0: + logger.debug( + f'update weight: {name}, dtype: {dtype}, shape: {shape}') + + weight = torch.empty(shape, dtype=dtype, device='cuda') + torch.distributed.broadcast(weight, 0, group=self._model_update_group) + self.model_runner.model.load_weights(model_name_or_path={name: weight}) + + del weight diff --git a/xtuner/rlhf/model_server/__init__.py b/xtuner/rlhf/model_server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/rlhf/model_server/actor_model_server.py b/xtuner/rlhf/model_server/actor_model_server.py new file mode 100644 index 000000000..a18cf43f4 --- /dev/null +++ b/xtuner/rlhf/model_server/actor_model_server.py @@ -0,0 +1,99 @@ +from typing import Optional + +import torch +from loguru import logger + +from ..config.config_consts import ENGINE_VLLM +from ..tokenizer import tokenizer_utils +from .base_model_server import BaseModelServer + + +class ActorModelServer(BaseModelServer): + # Initialize + def initialize_async(self): + super().initialize_async() + + self.generator_eq_trainer = True + # use trainer for self.generate() by default + self.generator = self.trainer + if 'generator_config' not in self.model_config: + return # self.generator = self.trainer + + generator_config = self.model_config['generator_config'] # optional + if generator_config.get('shared_with_trainer', True): + return # self.generator = self.trainer + + generator_config['model_path'] = self.model_config['model_path'] + generator_config['tokenizer_config'] = self.tokenizer_config + generator_config[ + 'tokenizer_path'] = self.tokenizer_config.tokenizer_path + generator_type = generator_config.get('generator_type', None) + if generator_type == ENGINE_VLLM: + from ..model_backend.vllm_model_runner import \ + VllmGeneratorRayActorGroup + self.generator = VllmGeneratorRayActorGroup( + f'{self.model_name}_generator', generator_config) + # to sync model among trainer and generator + self.trainer.initialize_get() + self.trainer.init_process_group(self.generator) + else: + raise ValueError( + f"No generator is registered with type '{generator_type}'") + self.generator_eq_trainer = False + + def initialize_get(self): + self.generator.initialize_get() + self.is_initialized = True + logger.info(f'{self.model_name} has been initialized. 
') + + # Generation + def generate_async(self, + inputs, + attention_mask=None, + *args, + **generate_kwargs): + if isinstance(inputs, torch.Tensor): + input_ids = inputs + elif isinstance(inputs, list): + if not self.generator_eq_trainer: + input_ids, attention_mask = tokenizer_utils.encode( + inputs, + self.tokenizer, + return_tensors=None, + padding=False, + add_generation_prompt=True) + else: + input_ids, attention_mask = tokenizer_utils.encode( + inputs, self.tokenizer, add_generation_prompt=True) + else: + raise NotImplementedError(f'unknown inputs: {inputs}') + + return self.generator.generate_async( + input_ids=input_ids, + attention_mask=attention_mask, + *args, + **generate_kwargs) + + def generate_get(self, object_refs, timeout: Optional[float] = None): + return self.generator.generate_get(object_refs, timeout=timeout) + + def generate(self, inputs, *args, **generate_kwargs): + object_refs = self.generate_async(inputs, *args, **generate_kwargs) + policy_output = self.generate_get(object_refs) + self.log_cuda_mem_stats(remark='[generate] ') + return policy_output + + # Sync + def sync_model(self, *args, **kwargs): + if not self.generator_eq_trainer: + self.trainer.broadcast_model_to_generator(self.generator) + + # Misc. + def log_cuda_mem_stats(self, remark=''): + if self.show_cuda_mem_stats: + trainer_mem = self.trainer.get_cuda_mem_stats() + generator_mem = self.generator.get_cuda_mem_stats() + logger.info( + f'{remark}{self.model_name} trainer allocated GPU memory: {trainer_mem.total_current_mb} MiB, ' # noqa: E501 + f'generator allocated GPU memory: {generator_mem.total_current_mb} MiB, ' # noqa: E501 + f'generator_eq_trainer: {self.generator_eq_trainer}') diff --git a/xtuner/rlhf/model_server/base_model_server.py b/xtuner/rlhf/model_server/base_model_server.py new file mode 100644 index 000000000..ffb2426bd --- /dev/null +++ b/xtuner/rlhf/model_server/base_model_server.py @@ -0,0 +1,170 @@ +from typing import Optional + +import ray +import torch +from loguru import logger +from transformers import AutoModelForCausalLM + +from ..config.config_consts import ENGINE_HUGGINGFACE, ENGINE_INTERNEVO +from ..model_backend.hf_model_runner import HfModelRunnerRayActorGroup +from ..model_backend.models.modeling_internlm2_p import InternLM2ForCausalLM +from ..tokenizer import tokenizer_utils + +DEFAULT_GET_TIMEOUT = 600.0 # 10 min + + +class BaseModelServer: + # Initialize + def __init__(self, model_name: str, model_config: dict): + self.model_name = model_name + self.model_config = model_config + self.tokenizer = None + self.tokenizer_config = None + self.trainer = None + self.trainer_config = None + self.model_ref = None + self.is_initialized = False + self.show_cuda_mem_stats = self.model_config.get( + 'show_cuda_mem_stats', False) + logger.info(f'model_name={model_name}, model_config={model_config}') + + def init_tokenizer_and_config(self, model_config): + tokenizer_config = model_config.get('tokenizer_config', {}) + if 'tokenizer_path' in tokenizer_config: + tokenizer_path = tokenizer_config['tokenizer_path'] + elif 'tokenizer_path' in model_config: + tokenizer_path = model_config['tokenizer_path'] + else: + tokenizer_path = model_config['model_path'] + + self.tokenizer = tokenizer_utils.get_tokenizer( + tokenizer_path, trust_remote_code=True, **tokenizer_config) + + tokenizer_config['tokenizer_path'] = tokenizer_path + tokenizer_config['pad_token_id'] = self.tokenizer.pad_token_id + self.tokenizer_config = tokenizer_config + + def init_trainer_config(self, model_config, 
tokenizer_config): + model_path = model_config['model_path'] + trainer_config: dict = model_config['trainer_config'] # requisite + trainer_config['tokenizer_config'] = tokenizer_config + trainer_config['tokenizer_path'] = tokenizer_config['tokenizer_path'] + trainer_config['model_path'] = model_path + trainer_config['model_type'] = model_config['model_type'] + trainer_config['model_class'] = self.get_model_class(model_path) + self.trainer_config = trainer_config + + def get_model_class(self, model_path): + # will be changed in subclasses + if model_path == 'internlm/internlm2-chat-1_8b-sft': + return InternLM2ForCausalLM + return AutoModelForCausalLM + + def initialize_async(self): + self.init_tokenizer_and_config(self.model_config) + self.init_trainer_config(self.model_config, self.tokenizer_config) + + trainer_type = self.trainer_config.get('trainer_type', + 'huggingface').lower() + if trainer_type == ENGINE_HUGGINGFACE: + self.trainer = HfModelRunnerRayActorGroup( + name=f'{self.model_name}_trainer', config=self.trainer_config) + elif trainer_type == ENGINE_INTERNEVO: + raise NotImplementedError(f'{trainer_type}.') + else: + raise ValueError( + f'No trainer is registered with type: {trainer_type}') + + def initialize_get(self): + self.trainer.initialize_get() + self.is_initialized = True + logger.info(f'{self.model_name} has been initialized.') + + # Inference + def infer_async(self, inputs, attention_mask=None, *args, **infer_kwargs): + if not isinstance(inputs, torch.Tensor): + input_ids, attention_mask = tokenizer_utils.encode( + inputs, self.tokenizer) + else: + input_ids = inputs + return self.trainer.infer_async( + input_ids=input_ids, + attention_mask=attention_mask, + *args, + **infer_kwargs) + + def infer_get(self, object_refs, timeout: Optional[float] = None): + return self.trainer.infer_get(object_refs, timeout=timeout) + + def infer(self, inputs, *args, **infer_kwargs): + object_refs = self.infer_async(inputs, *args, **infer_kwargs) + results = self.infer_get(object_refs) + self.log_cuda_mem_stats(remark='[infer] ') + return results + + # Training + def train_async(self, + input_ids, + labels=None, + attention_mask=None, + *args, + **train_kwargs): + return self.trainer.train_async(input_ids, labels, attention_mask, + *args, **train_kwargs) + + def train_get(self, object_refs, timeout: Optional[float] = None): + return self.trainer.train_get(object_refs, timeout=timeout) + + def train(self, + input_ids, + labels=None, + attention_mask=None, + *args, + **train_kwargs): + object_refs = self.train_async(input_ids, labels, attention_mask, + *args, **train_kwargs) + loss = self.train_get(object_refs) + self.log_cuda_mem_stats(remark='[train] ') + return loss + + # Generation + def generate_async(self, + inputs, + attention_mask=None, + *args, + **generate_kwargs): + raise NotImplementedError + + def generate_get(self, object_refs, timeout: Optional[float] = None): + raise NotImplementedError + + def generate(self, inputs, *args, **generate_kwargs): + raise NotImplementedError + + # Model + def model_get(self): + if not self.model_ref: + self.model_ref = self.trainer.get_model() # an reference + return ray.get(self.model_ref, timeout=DEFAULT_GET_TIMEOUT) + + def state_dict_get(self): + return ray.get( + self.trainer.get_state_dict(), timeout=DEFAULT_GET_TIMEOUT) + + def save_model(self, path): + self.trainer.save_model(path) + + # Misc. 
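+    # Call-pattern sketch (not an exhaustive contract): each capability is
+    # exposed as an *_async / *_get pair so different model servers can
+    # overlap work, e.g.
+    #     ref = server.infer_async(inputs, output_logprobs=True)
+    #     ...  # launch requests on other servers here
+    #     out = server.infer_get(ref)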
+ def set_seed(self, seed: int = None): + self.trainer.set_seed(seed) + + def log_cuda_mem_stats(self, remark=''): + if self.show_cuda_mem_stats: + trainer_mem = self.trainer.get_cuda_mem_stats() + logger.info( + f'{remark}{self.model_name} trainer allocated GPU memory: {trainer_mem.total_current_mb} MiB' # noqa: E501 + ) + + def clean_up(self): + self.trainer.release_resources() + logger.info(f'{self.model_name} is destroyed.') diff --git a/xtuner/rlhf/model_server/critic_model_server.py b/xtuner/rlhf/model_server/critic_model_server.py new file mode 100644 index 000000000..fe5afc5b2 --- /dev/null +++ b/xtuner/rlhf/model_server/critic_model_server.py @@ -0,0 +1,9 @@ +from ..model_backend.models.critical_and_reward import get_critic_model +from .base_model_server import BaseModelServer + + +class CriticModelServer(BaseModelServer): + # Initialize + def get_model_class(self, model_path): + head_name = self.model_config.get('head_name', 'v_head') + return get_critic_model(model_path, head_name) diff --git a/xtuner/rlhf/model_server/ref_model_server.py b/xtuner/rlhf/model_server/ref_model_server.py new file mode 100644 index 000000000..90b1dcce3 --- /dev/null +++ b/xtuner/rlhf/model_server/ref_model_server.py @@ -0,0 +1,5 @@ +from .base_model_server import BaseModelServer + + +class RefModelServer(BaseModelServer): + pass # same as BaseModelServer diff --git a/xtuner/rlhf/model_server/reward_model_server.py b/xtuner/rlhf/model_server/reward_model_server.py new file mode 100644 index 000000000..de275c0d9 --- /dev/null +++ b/xtuner/rlhf/model_server/reward_model_server.py @@ -0,0 +1,43 @@ +import torch +from transformers import AutoConfig + +from ..model_backend.models.critical_and_reward import get_reward_model +from ..tokenizer import tokenizer_utils +from ..utils import expand_reward_token_id +from .base_model_server import BaseModelServer + + +class RewardModelServer(BaseModelServer): + # Initialize + def get_model_class(self, model_path): + head_name = self.model_config.get('head_name', 'v_head') + return get_reward_model(model_path, head_name) + + def init_tokenizer_and_config(self, model_config): + super().init_tokenizer_and_config(self.model_config) + + self.reward_token_id = self.tokenizer.pad_token_id + model_path = model_config['model_path'] + auto_config = AutoConfig.from_pretrained( + model_path, trust_remote_code=True) + if hasattr(auto_config, 'reward_token_id'): + self.reward_token_id = auto_config.reward_token_id + + # Inference + def infer_async(self, inputs, attention_mask=None, *args, **infer_kwargs): + if not isinstance(inputs, torch.Tensor): + input_ids, attention_mask = tokenizer_utils.encode( + inputs, self.tokenizer) + else: + input_ids = inputs + + # Reward model specific + if self.reward_token_id is not None: + input_ids, attention_mask = expand_reward_token_id( + self.reward_token_id, input_ids, attention_mask) + + return self.trainer.infer_async( + input_ids=input_ids, + attention_mask=attention_mask, + *args, + **infer_kwargs) diff --git a/xtuner/rlhf/policy_output.py b/xtuner/rlhf/policy_output.py new file mode 100644 index 000000000..a7287cee6 --- /dev/null +++ b/xtuner/rlhf/policy_output.py @@ -0,0 +1,174 @@ +# Adopted from: https://github.com/huggingface/transformers/blob/HEAD/src/transformers/generation/utils.py # noqa: E501 +from dataclasses import dataclass +from typing import Optional + +import torch +from transformers.utils.generic import ModelOutput + + +@dataclass +class PolicyOutput(ModelOutput): + output_ids: Optional[torch.Tensor] = None + output_str: 
Optional[list[str]] = None + loss: Optional[torch.Tensor] = None + logits: Optional[torch.Tensor] = None + attentions: Optional[torch.Tensor] = None + hidden_states: Optional[torch.Tensor] = None + logits_entropy: Optional[torch.Tensor] = None + logprobs: Optional[torch.Tensor] = None + top_logprobs: Optional[torch.Tensor] = None + question_mask: Optional[torch.Tensor] = None + answer_mask: Optional[torch.Tensor] = None + + def __eq__(self, other: ModelOutput): + if len(self.keys()) != len(other.keys()): + return False + for k, v in self.items(): + if k not in other: + return False + vother = other[k] + + if isinstance(v, torch.Tensor): + if not torch.equal(v, vother): + return False + elif isinstance(v, tuple): # tuple(torch.Tensor) + for i, j in zip(v, vother): + if isinstance(i, torch.Tensor): + if not torch.equal(i, j): + return False + else: + if i != j: + return False + else: + if v != vother: + return False + return True + + def to(self, device): + for k, v in self.items(): + if isinstance(v, torch.Tensor): + self[k] = v.to(device) + + def get_tensor_keys(self): + keys = [] + for k, v in self.items(): + if isinstance(v, torch.Tensor): + keys.append(k) + return keys + + +def union_keys_from_policy_outputs(policy_outputs: list[PolicyOutput]) -> list: + all_keys = set() + for po in policy_outputs: + all_keys = all_keys.union(set(po.keys())) + return list( + all_keys) # e.g., return ["output_str", "output_ids", "loss", ...] + + +def union_tensor_keys_from_policy_outputs( + policy_outputs: list[PolicyOutput]) -> list: + all_keys = set() + for po in policy_outputs: + all_keys = all_keys.union(set(po.get_tensor_keys())) + return list(all_keys) # e.g., return ["output_ids", "loss", ...] + + +def concat_policy_outputs(policy_outputs: list[PolicyOutput], + padding_token_map: dict = None) -> PolicyOutput: + if isinstance(policy_outputs, PolicyOutput): + return policy_outputs # Wrong input type + elif policy_outputs is None or len(policy_outputs) == 0: + return PolicyOutput(None) + elif len(policy_outputs) == 1: + return policy_outputs[0] + + if padding_token_map is not None: # padding + policy_outputs = padding_policy_outputs(policy_outputs, + padding_token_map) + + concated = PolicyOutput() + all_keys = union_keys_from_policy_outputs(policy_outputs) + for key in all_keys: + for po in policy_outputs: + value = po[key] + if value is not None: + break # get the first non-empty value + if value is None: + continue # skip if all values are None + + if isinstance(value, torch.Tensor): + concated[key] = torch.cat( + [po[key] for po in policy_outputs if po[key] is not None], + dim=0) + elif isinstance(value, list): # e.g., list[str] + concated[key] = [] + for po in policy_outputs: + if po[key] is not None: + concated[key].extend(po[key]) + elif isinstance(value, tuple) and isinstance(value[0], torch.Tensor): + results = [] + for i in range(len(value)): + beef = [ + po[key][i] for po in policy_outputs + if po[key][i] is not None + ] + tensor = torch.cat( + beef, dim=0) if len(beef) > 0 else torch.Tensor() + results.append(tensor) + concated[key] = tuple(results) + raise NotImplementedError( + f'{value}\n{[v.shape for v in value]}\n{results}') + else: + raise TypeError( + f'value: {value} with unsupported type: {type(value)}.') + return concated + + +def padding_policy_outputs(policy_outputs: list[PolicyOutput], + padding_token_map={}): + DEFAULT_PADDING_ID = 0 + RIGHT_PADDING = True + tensor_keys = union_tensor_keys_from_policy_outputs(policy_outputs) + for key in tensor_keys: + padding_id = 
padding_token_map.get(key, DEFAULT_PADDING_ID)
+        max_seq_len = find_max_seq_len(policy_outputs, key)
+        for policy_output in policy_outputs:
+            origin_tensor = policy_output[key]
+            padding_size = max_seq_len - origin_tensor.shape[1]
+            pad = (0, padding_size) if RIGHT_PADDING else (padding_size, 0)
+            padded_tensor = torch.nn.functional.pad(
+                origin_tensor, pad, mode='constant', value=padding_id)
+            policy_output[key] = padded_tensor
+    return policy_outputs
+
+
+def find_max_seq_len(policy_outputs: list[PolicyOutput], key):
+    max_seq_len = 0
+    for policy_output in policy_outputs:
+        if policy_output[key] is None:
+            continue
+        batch_size, seq_len = policy_output[
+            key].shape  # assert: only support 2d tensor
+        max_seq_len = seq_len if seq_len > max_seq_len else max_seq_len
+    return max_seq_len
+
+
+def logprobs_from_logits(logits: torch.Tensor,
+                         labels: torch.Tensor = None,
+                         gather: bool = True) -> torch.Tensor:
+    r"""
+    Adapted from: https://github.com/huggingface/trl/blob/main/trl/core.py#L95
+
+    Example:
+
+    ```python
+    >>> logits, _, values = model(**input_kwargs)
+    >>> input_ids = input_kwargs["input_ids"]
+    >>> logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
+    ```"""
+
+    logp = torch.nn.functional.log_softmax(logits, dim=-1)
+    if not gather or labels is None:
+        return logp
+    logpy = torch.gather(logp, -1, labels.unsqueeze(2)).squeeze(-1)
+    return logpy.cuda()
diff --git a/xtuner/rlhf/repeaters/__init__.py b/xtuner/rlhf/repeaters/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/xtuner/rlhf/repeaters/base.py b/xtuner/rlhf/repeaters/base.py
new file mode 100644
index 000000000..0e68600b4
--- /dev/null
+++ b/xtuner/rlhf/repeaters/base.py
@@ -0,0 +1,311 @@
+import time
+
+import numpy as np
+import torch
+from loguru import logger
+
+from ..model_server.base_model_server import BaseModelServer
+from ..policy_output import PolicyOutput
+
+
+def find_mask_begin(padded_datas, mask_id=0):
+    """Find each row's first non-mask index and its non-mask length."""
+    begin_indexs = []
+    lengths = []
+
+    for padded_data in padded_datas:
+        is_flag = 0
+        for index, data in enumerate(padded_data):
+            if data != mask_id:
+                is_flag = 1
+                begin_indexs.append(index)
+                length = (np.array(padded_data) != mask_id).sum()
+                lengths.append(length)
+                break
+        assert is_flag
+    return begin_indexs, lengths
+
+
+class RunningStates:
+    # Adapted from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/running_mean_std.py # noqa: E501
+    def __init__(self, epsilon: float = 1e-4):
+        self.mean = torch.tensor(0, dtype=torch.float32)
+        self.var = torch.tensor(0, dtype=torch.float32)
+        self.count = epsilon
+
+    def update(self, x: torch.Tensor):
+        x_var, x_mean = torch.var_mean(x.cpu(), unbiased=False)
+        x_count = x.shape[0]
+        self.update_from_moments(x_mean, x_var, x_count)
+
+    def update_from_other(self, other: 'RunningStates'):
+        self.update_from_moments(other.mean, other.var, other.count)
+
+    def update_from_moments(self, mean: torch.Tensor, var: torch.Tensor,
+                            count: int):
+        delta = mean - self.mean
+        tot_count = self.count + count
+        m_a = self.var * self.count
+        m_b = var * count
+        m_2 = m_a + m_b + delta**2 * self.count * count / (self.count + count)
+        new_var = m_2 / (self.count + count)
+
+        self.mean += delta * count / tot_count
+        self.var = new_var
+        self.count = tot_count
+
+    def state_dict(self):
+        return dict(mean=self.mean, var=self.var, count=self.count)
+
+    def load_state_dict(self, states):
+        self.mean = states['mean']
+        self.var = states['var']
+ 
self.count = states['count'] + + +class BaseRepeater: + + def __init__( + self, + sft_model, + reward_scale: bool = False, + fine_grained_rm: bool = False, + value_ema: bool = False, + actor_micro_bs: int = 8, + ref_micro_bs: int = 8, + critic_micro_bs: int = 32, + kl_coeff=0.02, + gamma=1.0, + gae_lambda=0.95, + answer_end_id=92542, + norm_adv=False, + norm_rewards=True, + **_ignored, + ): + self.sft_model = sft_model + self.actor_micro_bs = actor_micro_bs + self.ref_micro_bs = ref_micro_bs + self.critic_micro_bs = critic_micro_bs + self.reward_scale = reward_scale + self.fine_grained_rm = fine_grained_rm + self.value_ema = value_ema + self.kl_coeff = kl_coeff + self.gamma = gamma + self.gae_lambda = gae_lambda + self.answer_end_id = answer_end_id + self.norm_rewards = norm_rewards + if self.norm_rewards: + self.running_states = RunningStates(epsilon=0) + + def process( + self, + trajectories: PolicyOutput, + policy_model: BaseModelServer, + value_model: BaseModelServer, + sft_model: BaseModelServer = None, + # only used for async reward model.infer_get() in _get_kl_rewards + env=None, + ): + value_output_ref = self._get_values_async(trajectories, value_model) + action_mask = trajectories['action_mask'] + num_actions = action_mask.size(1) + if sft_model is not None: + self.sft_model: BaseModelServer = sft_model + kl_rewards, entropy, kl_distance, policy_logprobs, sft_logprobs = self._get_kl_rewards( # noqa: E501 + trajectories, policy_model, env=env) + trajectories['kl'] = (kl_distance * action_mask).sum( + axis=-1) / action_mask.sum(axis=-1) + trajectories['entropy'] = entropy + trajectories['kl_rewards'] = kl_rewards + trajectories['policy_logprobs'] = policy_logprobs + trajectories['sft_logprobs'] = sft_logprobs + + values = self._get_values_collect(value_output_ref, value_model) + old_values = values[:, -num_actions:] + advantages, returns = self.get_advantages_and_returns( + old_values, kl_rewards, action_mask) + + trajectories['advantages'] = advantages + trajectories['returns'] = returns + trajectories['old_values'] = old_values + + return trajectories + + def _get_kl_rewards(self, + trajectories: PolicyOutput, + policy_model: BaseModelServer, + env=None): + s_t = time.time() + policy_output = policy_model.infer_async( + inputs=trajectories.output_ids, + micro_batch_size=self.actor_micro_bs, + attention_mask=trajectories.attention_mask, + output_logits=False, + output_logprobs=True) + sft_output = self.sft_model.infer_async( + inputs=trajectories.output_ids, + micro_batch_size=self.ref_micro_bs, + attention_mask=trajectories.attention_mask, + output_logits=False, + output_logprobs=True) + policy_output = policy_model.infer_get(policy_output) + sft_output = self.sft_model.infer_get(sft_output) + logger.info( + f'[actor & ref infer_async] duration: {round(time.time() - s_t, 2)} s' # noqa: E501 + ) + + # Experimental + if env.async_reward: + rewards = env.get_reward_collect(trajectories['reward_output_ref']) + trajectories['reward_output_ref'] = None + clipped_rewards = torch.clamp( + rewards, min=env.clip_reward_min, max=env.clip_reward_max) + trajectories['rewards'] = rewards + trajectories['clipped_rewards'] = clipped_rewards + # Experimental + rewards = trajectories.clipped_rewards + if self.norm_rewards: + self.running_states.update(rewards) + norm_reward_score = (rewards - self.running_states.mean) / ( + self.running_states.var.sqrt() + 1e-8) + action_mask = trajectories.action_mask + num_actions = action_mask.size(1) + + policy_logprobs = policy_output.logprobs[:, 
-num_actions:] + sft_logprobs = sft_output.logprobs[:, -num_actions:] + + if self.kl_coeff <= 0.0: + self.kl_coeff = 0.0 + # compute_approx_kl + log_ratio = policy_logprobs - sft_logprobs + kl = log_ratio * action_mask + kl_reward = -self.kl_coeff * kl + + eos_indices = action_mask.size( + 1) - 1 - action_mask.long().fliplr().argmax( + dim=1, keepdim=True) + last_reward = torch.zeros_like(kl).scatter_( + dim=1, + index=eos_indices, + src=norm_reward_score.unsqueeze(1).to(kl.dtype)) + + reward = last_reward + kl_reward + + entropy = -(policy_logprobs * + action_mask).sum(axis=-1) / action_mask.sum(axis=-1) + return reward, entropy, kl, policy_logprobs, sft_logprobs + + def _get_values(self, trajectories: PolicyOutput, + value_model: BaseModelServer): + s_t = time.time() + value_output = value_model.infer( + inputs=trajectories.output_ids, + attention_mask=trajectories.attention_mask, + output_logits=True, + micro_batch_size=self.critic_micro_bs, + ) + logger.info( + f'[critic infer] duration: {round(time.time() - s_t, 2)} s') + raw_values = value_output.logits.squeeze(-1) + return raw_values + + def _get_values_async(self, trajectories: PolicyOutput, + value_model: BaseModelServer): + s_t = time.time() + value_output_ref = value_model.infer_async( + inputs=trajectories.output_ids, + attention_mask=trajectories.attention_mask, + output_logits=True, + micro_batch_size=self.critic_micro_bs, + ) + logger.info( + f'[critic infer] async duration: {round(time.time() - s_t, 2)} s') + return value_output_ref + + def _get_values_collect(self, value_output_ref, + value_model: BaseModelServer): + s_t = time.time() + value_output = value_model.infer_get(value_output_ref) + raw_values = value_output.logits.squeeze(-1) + logger.info( + f'[critic infer] async wait duration: {round(time.time() - s_t, 2)} s' # noqa: E501 + ) + return raw_values + + def _get_advantages_and_returns(self, trajectories): + output_ids = trajectories.output_ids + answer_mask = trajectories.answer_mask + values_with_last_value = trajectories.values_with_last_value + kl_rewards = trajectories.kl_rewards + + begins_index, answers_length = find_mask_begin(answer_mask, 0) + count = 0 + advantages_padded, returns_padded = torch.zeros_like( + kl_rewards, dtype=values_with_last_value.dtype), torch.zeros_like( + kl_rewards, dtype=values_with_last_value.dtype) + for begin_index, ans_len, value_with_last_value, reward, output_id in zip( # noqa: E501 + begins_index, answers_length, values_with_last_value, + kl_rewards, output_ids): + # shape :ans_len + 1 + value_with_last_value = value_with_last_value[begin_index - + 1:begin_index + + ans_len] + # shape :ans_len + reward = reward[begin_index:begin_index + ans_len] + last_gae_lam = torch.zeros((1), dtype=values_with_last_value.dtype) + # shape :ans_len + advantages = torch.zeros_like( + reward, dtype=values_with_last_value.dtype) + step_nums = advantages.shape[-1] + # shape:ans_len + 1 + dones = self._build_dones(output_id[begin_index:begin_index + + ans_len]) + for step in reversed(range(step_nums)): + next_non_terminal = 1 - dones[step + 1] + next_values = value_with_last_value[step + 1] + # delta and last_gae_lam using value and reward + delta = reward[ + step] + self.gamma * next_values * next_non_terminal - value_with_last_value[ # noqa: E501 + step] + last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam # noqa: E501 + advantages[step] = last_gae_lam[0] + returns = advantages + value_with_last_value[:-1] + advantages_padded[count, + 
begin_index:begin_index + ans_len] = advantages + returns_padded[count, begin_index:begin_index + ans_len] = returns + count += 1 + return advantages_padded, returns_padded + + # ans_len + 1: dones + def _build_dones(self, answer_ids): + dones = torch.tensor( + (answer_ids == self.answer_end_id).numpy().astype(np.float32)) + # (1, )the first one is not done, so obs_0_dones=0 + obs_0_dones = torch.zeros((1), dtype=torch.float32) + # (ans_len + 1), + dones = torch.concat((obs_0_dones, dones), axis=0) + return dones + + def get_advantages_and_returns( + self, + values: torch.Tensor, + rewards: torch.Tensor, + action_mask: torch.Tensor, + ): + # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134 # noqa: E501 + lastgaelam = 0 + advantages_reversed = [] + response_length = rewards.size(1) + + # Mask invalid responses + values = action_mask * values + rewards = action_mask * rewards + + for t in reversed(range(response_length)): + nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0 + delta = rewards[:, t] + self.gamma * nextvalues - values[:, t] + lastgaelam = delta + self.gamma * self.gae_lambda * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], dim=1) + returns = advantages + values + return advantages.detach(), returns diff --git a/xtuner/rlhf/timer.py b/xtuner/rlhf/timer.py new file mode 100644 index 000000000..4574ca8c7 --- /dev/null +++ b/xtuner/rlhf/timer.py @@ -0,0 +1,27 @@ +import time + +from loguru import logger + + +class Timer: + """Timer.""" + + def __init__(self, task_name: str): + self.task_name = task_name + self.duration = 0 + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end() + + def start(self): + logger.info(f'Start {self.task_name}') + self.start = time.time() + + def end(self): + self.duration = time.time() - self.start + logger.info( + f' End {self.task_name}, duration = {self.duration:.2f} seconds') diff --git a/xtuner/rlhf/tokenizer/__init__.py b/xtuner/rlhf/tokenizer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/rlhf/tokenizer/tokenizer_utils.py b/xtuner/rlhf/tokenizer/tokenizer_utils.py new file mode 100644 index 000000000..d782c6eb0 --- /dev/null +++ b/xtuner/rlhf/tokenizer/tokenizer_utils.py @@ -0,0 +1,88 @@ +from typing import Optional, Union + +from transformers import (AutoTokenizer, LlamaTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) + +from ..logger import init_logger + +logger = init_logger(__name__) + +PADDING_SIDE = 'left' + + +def get_tokenizer( + tokenizer_name: str, + *args, + trust_remote_code: bool = False, + tokenizer_revision: Optional[str] = None, + padding_side: Optional[str] = PADDING_SIDE, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Gets a tokenizer for the given model name via Huggingface.""" + + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + tokenizer_revision=tokenizer_revision, + padding_side=padding_side, + **kwargs, + ) + except ValueError as e: + # If the error pertains to the tokenizer class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + if not trust_remote_code and ( + 'does not exist or is not currently imported.' in str(e) + or 'requires you to execute the tokenizer file' in str(e)): + err_msg = 'Failed to load the tokenizer. Try `trust_remote_code=True`.' 
# noqa: E501 + raise RuntimeError(err_msg) from e + else: + raise e + except OSError as e: + if 'Incorrect path_or_model_id' in str(e): # e.g., v13.model + tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + tokenizer_revision=tokenizer_revision, + padding_side=padding_side, + **kwargs, + ) + logger.warning('Using LlamaTokenizer.') + else: + raise e + except AttributeError as e: + raise e + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + logger.warning( + 'Using a slow tokenizer. This might cause a significant ' + 'slowdown. Consider using a fast tokenizer instead.') + for key, value in kwargs.items(): + setattr(tokenizer, key, value) + return tokenizer + + +def encode( + inputs: Union[list[str], list[list[dict]]], + tokenizer, + return_tensors='pt', + padding=True, + add_generation_prompt: bool = False, +): + if isinstance(inputs[0], list): + inputs = [ + tokenizer.apply_chat_template( + input, + tokenize=False, + add_generation_prompt=add_generation_prompt, + return_tensors=return_tensors, + ) for input in inputs + ] + output = tokenizer( + inputs, + return_tensors=return_tensors, + padding=padding, + add_special_tokens=False) + return output.input_ids, output.attention_mask diff --git a/xtuner/rlhf/trainer/__init__.py b/xtuner/rlhf/trainer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/rlhf/trainer/ppo.py b/xtuner/rlhf/trainer/ppo.py new file mode 100644 index 000000000..a4a81aad6 --- /dev/null +++ b/xtuner/rlhf/trainer/ppo.py @@ -0,0 +1,173 @@ +import time + +import torch +from loguru import logger + +from ..loss.actor_loss import ActorLoss +from ..loss.critic_loss import CriticLoss +from ..model_server.base_model_server import BaseModelServer +from ..timer import Timer + + +class PPOTrainer: + + def __init__( + self, + policy_model, + value_model, + actor_micro_bs=2, + critic_micro_bs=2, + policy_learn_time=1, + value_learn_time=1, + ppo_minibatch=512, + value_minibatch=512, + pt_minibatch=None, + train_minibatch=None, + pt_criterion=None, + policy_criterion=ActorLoss(cliprange=0.2, loss_type='per_seq'), + value_criterion=CriticLoss(cliprange_value=0.5, loss_type='per_seq'), + **kwargs, + ): + + self.ppo_minibatch = ppo_minibatch + self.value_minibatch = value_minibatch + self.actor_micro_bs = actor_micro_bs + self.critic_micro_bs = critic_micro_bs + # policy + self.policy_model = policy_model + self.policy_learn_time = policy_learn_time + self.pt_minibatch = pt_minibatch + self.train_minibatch = train_minibatch + self.policy_minibatch = ppo_minibatch + + # value + self.value_model = value_model + self.value_learn_time = value_learn_time + self.value_minibatch = value_minibatch + + self.pt_criterion = pt_criterion + self.policy_criterion = policy_criterion + self.value_criterion = value_criterion + + def policy_learn(self, trajectories, policy_model: BaseModelServer): + policy_updates = len(trajectories.output_ids) // self.policy_minibatch + policy_loss = [] + pt_loss = [] + + for _ in range(self.policy_learn_time): + for i in range(policy_updates): + logger.info( + '[Policy Train] start policy trains {}/{} | {}'.format( + i + 1, policy_updates, _ + 1)) + begin = i * self.policy_minibatch + end = begin + self.policy_minibatch + policy_batch_inputs = { + 'input_ids': trajectories.output_ids[begin:end, :], + 'policy_logprobs': + trajectories.policy_logprobs[begin:end, :], + 'advs': trajectories.advantages[begin:end, :], + 'action_mask': trajectories.action_mask[begin:end, :], + 
'attention_mask': trajectories.attention_mask[begin:end, :] + } + assert len( + policy_batch_inputs['input_ids'] + ) == self.policy_minibatch, '[Policy learn] make sure len(policy_batch_inputs) == self.policy_minibatch' # noqa: E501 + + loss_factor = 1.0 + labels = dict( + input_ids=policy_batch_inputs['input_ids'], + old_logprobs=policy_batch_inputs['policy_logprobs'], + advantages=policy_batch_inputs['advs'], + mask=policy_batch_inputs['action_mask'], + loss_factor=torch.tensor(loss_factor), + ) + s_t = time.time() + p_loss = policy_model.train( + input_ids=policy_batch_inputs['input_ids'], + labels=labels, + attention_mask=policy_batch_inputs['attention_mask'], + criterion=self.policy_criterion, + micro_batch_size=self.actor_micro_bs) + + logger.info( + f'[actor train] duration: {round(time.time() - s_t, 2)} s, {self.policy_minibatch} batch, Policy loss: {p_loss.item()}' # noqa: E501 + ) + policy_loss.append(p_loss.item()) + + with Timer('policy_model.sync_model'): + policy_model.sync_model() + return policy_loss, pt_loss + + def value_learn_async(self, trajectories, value_model: BaseModelServer): + value_updates = len(trajectories.output_ids) // self.value_minibatch + value_loss = [] + assert value_updates == 1 and self.policy_learn_time == 1, f'value_updates={value_updates} * self.policy_learn_time={self.policy_learn_time} > 1' # noqa: E501 + s_t = time.time() + value_batch_inputs, labels = self._value_learn_prepare( + 0, 0, trajectories, value_updates) + v_loss_ref = value_model.train_async( + input_ids=value_batch_inputs['input_ids'], + labels=labels, + attention_mask=value_batch_inputs['attention_mask'], + criterion=self.value_criterion, + micro_batch_size=self.critic_micro_bs, + ) + logger.info( + f'[critic train] async duration: {round(time.time() - s_t, 2)} s, {self.value_minibatch} batch' # noqa: E501 + ) + value_loss.append(v_loss_ref) + return value_loss + + def value_learn_get(self, value_loss_ref, value_model: BaseModelServer): + with Timer('value_model.train_get'): + return [ + value_model.train_get(ref).item() for ref in value_loss_ref + ] + + def value_learn(self, trajectories, value_model: BaseModelServer): + value_updates = len(trajectories.output_ids) // self.value_minibatch + value_loss = [] + + for learn_i in range(self.policy_learn_time): + for step_i in range(value_updates): + s_t = time.time() + value_batch_inputs, labels = self._value_learn_prepare( + step_i, learn_i, trajectories, value_updates) + v_loss = value_model.train( + input_ids=value_batch_inputs['input_ids'], + labels=labels, + attention_mask=value_batch_inputs['attention_mask'], + criterion=self.value_criterion, + micro_batch_size=self.critic_micro_bs, + ) + logger.info( + f'[critic train] duration: {round(time.time() - s_t, 2)} s, {self.value_minibatch} batch,value loss: {v_loss.item()}' # noqa: E501 + ) + value_loss.append(v_loss.item()) + return value_loss + + def _value_learn_prepare(self, step_i, learn_i, trajectories, + value_updates): + logger.info('[Value Train] start value trains {}/{} | {}'.format( + step_i + 1, value_updates, learn_i + 1)) + begin = step_i * self.value_minibatch + end = begin + self.value_minibatch + value_batch_inputs = { + 'input_ids': trajectories.output_ids[begin:end, :], + 'old_values': trajectories.old_values[begin:end, :], + 'returns': trajectories.returns[begin:end, :], + 'action_mask': trajectories.action_mask[begin:end, :], + 'attention_mask': trajectories.attention_mask[begin:end, :] + } + assert len( + value_batch_inputs['input_ids'] + ) == 
self.value_minibatch, '[Value learn] make sure len(value_batch_inputs) == self.value_minibatch' # noqa: E501 + + loss_factor = 1.0 + labels = dict( + old_values=value_batch_inputs['old_values'], + returns=value_batch_inputs['returns'], + mask=value_batch_inputs['action_mask'], + loss_factor=torch.tensor(loss_factor), + ) + return value_batch_inputs, labels diff --git a/xtuner/rlhf/utils.py b/xtuner/rlhf/utils.py new file mode 100644 index 000000000..be21d519c --- /dev/null +++ b/xtuner/rlhf/utils.py @@ -0,0 +1,65 @@ +import os +import random +from typing import Optional + +import numpy as np +import torch + +DEFAULT_SEED_NUMBER = 1234 + + +def set_seed(seed: int = DEFAULT_SEED_NUMBER): + if seed is None or not isinstance(seed, int): + seed = DEFAULT_SEED_NUMBER + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + # refer to https://pytorch.org/docs/1.13/notes/randomness.html#reproducibility # noqa: E501 + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.backends.cudnn_deterministic = True + torch.backends.cudnn_benchmark = False + torch.use_deterministic_algorithms(True, warn_only=True) + # refer to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility # noqa: E501 + os.putenv('CUBLAS_WORKSPACE_CONFIG', + os.environ.get('CUBLAS_WORKSPACE_CONFIG', ':4096:8')) + + +def expand_reward_token_id(reward_token_id: int, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + pad_token_id=0): + assert len(input_ids.shape) == 2, \ + f'expand_reward_token_id error, len(input_ids.shape()) = {len(input_ids.shape())}' # noqa: E501 + new_input_ids = torch.zeros((input_ids.shape[0], input_ids.shape[1] + 1), + dtype=input_ids.dtype).to(input_ids.device) + new_attention_mask = torch.zeros_like( + new_input_ids, dtype=torch.int64).to(input_ids.device) + for i in range(input_ids.size(0)): + row = input_ids[i] + nonzero_index = (row != pad_token_id).nonzero(as_tuple=False) + if nonzero_index.numel() > 0: + nonzero_index = nonzero_index[-1] + 1 + new_input_ids[i] = torch.cat( + (input_ids[i][:nonzero_index], + torch.tensor([reward_token_id], dtype=input_ids.dtype).to( + input_ids.device), input_ids[i][nonzero_index:]), + 0).to(input_ids.device) + if attention_mask is not None: + new_attention_mask[i] = torch.cat( + (attention_mask[i][:nonzero_index], + torch.tensor([1], dtype=torch.int64).to( + input_ids.device), attention_mask[i][nonzero_index:]), + 0).to(input_ids.device) + else: + new_input_ids[i] = torch.cat( + (input_ids[i][:], + torch.tensor([reward_token_id], dtype=input_ids.dtype).to( + input_ids.device)), 0).to(input_ids.device) + if attention_mask is not None: + new_attention_mask[i] = torch.cat( + (attention_mask[i][:], torch.tensor( + [1], dtype=torch.int64).to(input_ids.device)), + 0).to(input_ids.device) + + return new_input_ids, new_attention_mask diff --git a/xtuner/tools/tokenize_ftdp_datasets.py b/xtuner/tools/tokenize_ftdp_datasets.py index 9327a91fe..769e60b4e 100644 --- a/xtuner/tools/tokenize_ftdp_datasets.py +++ b/xtuner/tools/tokenize_ftdp_datasets.py @@ -361,7 +361,7 @@ def tokenize_and_save(tokenizer, processed_dir, tokenized_dir): description=f'{os.path.basename(file_path)}...'): samples.append(sample) - train_tokens, valid_tokens, train_samples, valid_samples = write_bin_meta_bin( # noqa E501 + train_tokens, valid_tokens, train_samples, valid_samples = write_bin_meta_bin( # noqa: E501 path=tokenized_save_dir, dataset_name=dataset_name, samples=samples, From 14ea4542fa4b049d52996b5cf1e3a2fe38dd9803 Mon Sep 17 
00:00:00 2001 From: lishuaibin Date: Thu, 13 Jun 2024 17:49:08 +0800 Subject: [PATCH 02/37] add/fix pretrain_loss --- .gitignore | 1 + examples/rlhf/four_model_8gpu.py | 163 ++++--- examples/rlhf/four_model_vllm_8gpu.py | 160 ++++--- examples/rlhf/quick_start.md | 2 +- xtuner/rlhf/dataset/base.py | 8 +- xtuner/rlhf/dataset/txt_loader.py | 408 ++++++++---------- xtuner/rlhf/envs/prompt_utils.py | 46 ++ xtuner/rlhf/envs/txt_env.py | 100 ++--- xtuner/rlhf/loss/actor_loss.py | 8 +- xtuner/rlhf/loss/critic_loss.py | 8 +- xtuner/rlhf/loss/pretrain_loss.py | 39 +- xtuner/rlhf/main.py | 25 +- xtuner/rlhf/model_backend/generate_utils.py | 21 +- xtuner/rlhf/model_backend/hf_model_runner.py | 230 +++++----- xtuner/rlhf/model_server/base_model_server.py | 6 +- xtuner/rlhf/repeaters/base.py | 159 ++----- xtuner/rlhf/repeaters/running_mean_std.py | 38 ++ xtuner/rlhf/trainer/ppo.py | 124 ++++-- 18 files changed, 747 insertions(+), 799 deletions(-) create mode 100644 xtuner/rlhf/envs/prompt_utils.py create mode 100644 xtuner/rlhf/repeaters/running_mean_std.py diff --git a/.gitignore b/.gitignore index ffe3444b8..c13320a73 100644 --- a/.gitignore +++ b/.gitignore @@ -113,6 +113,7 @@ data *.pkl.json *.log.json work_dirs/ +rlhf_trainlog*/ # Pytorch *.pth diff --git a/examples/rlhf/four_model_8gpu.py b/examples/rlhf/four_model_8gpu.py index 2d96a832d..9ccfc3c4c 100644 --- a/examples/rlhf/four_model_8gpu.py +++ b/examples/rlhf/four_model_8gpu.py @@ -1,5 +1,26 @@ import torch +MAX_PROMPT_LEN = 1024 +MAX_ANSWER_LEN = 1024 +MAX_PRETRAIN_LEN = 8192 + +PROMPT_BATCH_SIZE = 256 +PRETRAIN_BATCH_SIZE = 32 + +GENERATE_MICRO_BATCH_SIZE = 16 +AC_INFER_MICRO_BATCH_SIZE = 8 +REF_INFER_MICRO_BATCH_SIZE = 8 +TRAIN_MICRO_BATCH_SIZE = 2 + +ZERO_STAGE = 3 +ACTOR_DP_SIZE = 2 +CRITIC_DP_SIZE = 2 +ACTOR_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE + ) // ACTOR_DP_SIZE // TRAIN_MICRO_BATCH_SIZE +CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE + +MODEL_DTYPE = 'auto' + tokenizer_config = dict( pad_token_id=0, eos_token_id=92542, @@ -7,56 +28,54 @@ ) rollout_config = dict( - actor_micro_bs=32, - reward_micro_bs=32, - clip_reward_min=-5, - clip_reward_max=5, - max_new_tokens=10, - async_reward=True, + actor_micro_bs=GENERATE_MICRO_BATCH_SIZE, + reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, + max_new_tokens=MAX_ANSWER_LEN, + write_to_file=True, generate_kwargs={ 'do_sample': True, 'temperature': 1.0, 'top_k': 0, 'top_p': 0.9, - 'pad_token_id': 0, - 'eos_token_id': 92542, - 'early_stopping': True, - 'num_beams': 1, 'min_new_tokens': 1, - }) + 'num_beams': 1, + 'early_stopping': True, + 'eos_token_id': 92542, + 'pad_token_id': 0, + }, +) repeater_config = dict( - actor_micro_bs=8, - ref_micro_bs=8, - critic_micro_bs=32, - reward_scale=False, - fine_grained_rm=False, - value_ema=False, + actor_micro_bs=AC_INFER_MICRO_BATCH_SIZE, + critic_micro_bs=AC_INFER_MICRO_BATCH_SIZE, + ref_micro_bs=REF_INFER_MICRO_BATCH_SIZE, kl_coeff=0.01, gamma=1.0, gae_lambda=0.99, + clip_reward_min=-5, + clip_reward_max=5, answer_end_id=92542, norm_rewards=True, ) + train_config = dict( - ppo_minibatch=64, - value_minibatch=64, - actor_micro_bs=2, - critic_micro_bs=2, - pretrain_step=0, - save_interval=800, + actor_micro_bs=TRAIN_MICRO_BATCH_SIZE, + critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, + ppo_loss_weight=1.0, + pretrain_loss_weight=0.5, + pretrain_step=20, + save_interval=40, ) -critic_model_path = 'internlm/internlm2-chat-1_8b-sft' - model_configs = dict( actor=dict( 
model_path='internlm/internlm2-chat-1_8b-sft', model_type='actor', - use_flash_attn=False, trainer_config=dict( + torch_dtype=MODEL_DTYPE, trainer_type='huggingface', - torch_dtype=torch.float32, + use_flash_attn=True, + gradient_checkpointing=False, train_kwargs=dict( micro_bsz=1, lr=1e-6, @@ -65,14 +84,14 @@ loss_type='per_seq', ), parallel=dict( - data=dict(size=2, mode='deepspeed'), + data=dict(size=ACTOR_DP_SIZE, mode='deepspeed'), tensor=dict(size=1, mode='1d'), pipeline=dict(size=1, interleaved_overlap=False), sequence=False, ), deepspeed_config={ 'zero_optimization': { - 'stage': 2, + 'stage': ZERO_STAGE, 'offload_param': { 'device': 'none' }, @@ -91,34 +110,21 @@ 'data_types': { 'grad_accum_dtype': 'fp32' }, - 'train_micro_batch_size_per_gpu': 2, - 'gradient_accumulation_steps': 16, - 'train_batch_size': 64 - }), - generator_config=dict(shared_with_trainer=True, ), - ), - reference=dict( - model_path='internlm/internlm2-chat-1_8b-sft', - model_type='reference', - use_flash_attn=False, - trainer_config=dict( - torch_dtype=torch.float32, - trainer_type='huggingface', - parallel=dict( - data=dict(size=2, mode='ddp'), - tensor=dict(size=1, mode='1d'), - pipeline=dict(size=1, interleaved_overlap=False), - sequence=False, - ), + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': ACTOR_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE, + }, ), + generator_config=dict(shared_with_trainer=True, ), ), critic=dict( - model_path=critic_model_path, + model_path=None, model_type='critic', - use_flash_attn=False, trainer_config=dict( - torch_dtype='auto', + torch_dtype=MODEL_DTYPE, trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=False, train_kwargs=dict( micro_bsz=1, lr=5e-6, @@ -127,14 +133,14 @@ loss_type='per_seq', ), parallel=dict( - data=dict(size=2, mode='deepspeed'), + data=dict(size=CRITIC_DP_SIZE, mode='deepspeed'), tensor=dict(size=1, mode='1d'), pipeline=dict(size=1, interleaved_overlap=False), sequence=False, ), deepspeed_config={ 'zero_optimization': { - 'stage': 2, + 'stage': ZERO_STAGE, 'offload_param': { 'device': 'none' }, @@ -152,20 +158,36 @@ 'data_types': { 'grad_accum_dtype': 'fp32' }, - 'train_micro_batch_size_per_gpu': 2, - 'gradient_accumulation_steps': 16, - 'train_batch_size': 64 - }), + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': CRITIC_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE, + }, + ), + ), + reference=dict( + model_path='internlm/internlm2-chat-1_8b-sft', + model_type='reference', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), ), reward=dict( - model_path=critic_model_path, + model_path=None, model_type='reward', - use_flash_attn=False, trainer_config=dict( + torch_dtype=MODEL_DTYPE, trainer_type='huggingface', - torch_dtype='auto', + use_flash_attn=True, parallel=dict( - data=dict(size=2, mode='ddp'), + data=dict(size=1, mode='ddp'), tensor=dict(size=1, mode='1d'), pipeline=dict(size=1, interleaved_overlap=False), sequence=False, @@ -175,14 +197,23 @@ ) dataset_config = { - 'num_samples_each_epoch': - 64, - 'max_seq_len': - 1024, + 'prompt_samples_each_epoch': + PROMPT_BATCH_SIZE, + 'max_prompt_len': + MAX_PROMPT_LEN, + 'pretrain_samples_each_epoch': + PRETRAIN_BATCH_SIZE, + 
'max_pretrain_len': + MAX_PRETRAIN_LEN, 'random_seed': 1024, - 'ppo_datas': [ + "sample_strategy": "in_data", + "ratio_within_datasets": False, + 'prompt_datasets': [ 'Anthropic/hh-rlhf/helpful-base::1.0', 'Anthropic/hh-rlhf/harmless-base::0.5', ], + 'pretrain_datasets': [ + 'Anthropic/hh-rlhf/helpful-base::1.0', + ], } diff --git a/examples/rlhf/four_model_vllm_8gpu.py b/examples/rlhf/four_model_vllm_8gpu.py index 654f57691..9d8ea67fe 100644 --- a/examples/rlhf/four_model_vllm_8gpu.py +++ b/examples/rlhf/four_model_vllm_8gpu.py @@ -1,5 +1,26 @@ import torch +MAX_PROMPT_LEN = 1024 +MAX_ANSWER_LEN = 1024 +MAX_PRETRAIN_LEN = 8192 + +PROMPT_BATCH_SIZE = 256 +PRETRAIN_BATCH_SIZE = 32 + +GENERATE_MICRO_BATCH_SIZE = 16 +AC_INFER_MICRO_BATCH_SIZE = 8 +REF_INFER_MICRO_BATCH_SIZE = 8 +TRAIN_MICRO_BATCH_SIZE = 2 + +ZERO_STAGE = 3 +ACTOR_DP_SIZE = 2 +CRITIC_DP_SIZE = 2 +ACTOR_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE + ) // ACTOR_DP_SIZE // TRAIN_MICRO_BATCH_SIZE +CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE + +MODEL_DTYPE = 'auto' + tokenizer_config = dict( pad_token_id=0, eos_token_id=92542, @@ -7,55 +28,54 @@ ) rollout_config = dict( - actor_micro_bs=32, - reward_micro_bs=32, - clip_reward_min=-5, - clip_reward_max=5, - max_new_tokens=10, - async_reward=True, + actor_micro_bs=GENERATE_MICRO_BATCH_SIZE, + reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, + max_new_tokens=MAX_ANSWER_LEN, + write_to_file=True, generate_kwargs={ 'do_sample': True, 'temperature': 1.0, 'top_k': 0, 'top_p': 0.9, - 'pad_token_id': 0, - 'eos_token_id': 92542, - 'early_stopping': True, - 'num_beams': 1, 'min_new_tokens': 1, - }) + 'num_beams': 1, + 'early_stopping': True, + 'eos_token_id': 92542, + 'pad_token_id': 0, + }, +) repeater_config = dict( - actor_micro_bs=8, - ref_micro_bs=8, - critic_micro_bs=32, - reward_scale=False, - fine_grained_rm=False, - value_ema=False, + actor_micro_bs=AC_INFER_MICRO_BATCH_SIZE, + critic_micro_bs=AC_INFER_MICRO_BATCH_SIZE, + ref_micro_bs=REF_INFER_MICRO_BATCH_SIZE, kl_coeff=0.01, gamma=1.0, gae_lambda=0.99, + clip_reward_min=-5, + clip_reward_max=5, answer_end_id=92542, norm_rewards=True, ) + train_config = dict( - ppo_minibatch=64, - value_minibatch=64, - actor_micro_bs=2, - critic_micro_bs=2, - pretrain_step=0, - save_interval=800, + actor_micro_bs=TRAIN_MICRO_BATCH_SIZE, + critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, + ppo_loss_weight=1.0, + pretrain_loss_weight=0.5, + pretrain_step=20, + save_interval=40, ) -critic_model_path = 'internlm/internlm2-chat-1_8b-sft' model_configs = dict( actor=dict( model_path='internlm/internlm2-chat-1_8b-sft', model_type='actor', - use_flash_attn=False, trainer_config=dict( + torch_dtype=MODEL_DTYPE, trainer_type='huggingface', - torch_dtype=torch.float32, + use_flash_attn=True, + gradient_checkpointing=False, train_kwargs=dict( micro_bsz=1, lr=1e-6, @@ -64,14 +84,14 @@ loss_type='per_seq', ), parallel=dict( - data=dict(size=2, mode='deepspeed'), + data=dict(size=ACTOR_DP_SIZE, mode='deepspeed'), tensor=dict(size=1, mode='1d'), pipeline=dict(size=1, interleaved_overlap=False), sequence=False, ), deepspeed_config={ 'zero_optimization': { - 'stage': 2, + 'stage': ZERO_STAGE, 'offload_param': { 'device': 'none' }, @@ -90,10 +110,11 @@ 'data_types': { 'grad_accum_dtype': 'fp32' }, - 'train_micro_batch_size_per_gpu': 2, - 'gradient_accumulation_steps': 16, - 'train_batch_size': 64 - }), + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': ACTOR_GRADIENT_ACC_STEP, + 
'train_batch_size': PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE, + }, + ), generator_config=dict( shared_with_trainer=False, generator_type='vllm', @@ -105,28 +126,14 @@ ), ), ), - reference=dict( - model_path='internlm/internlm2-chat-1_8b-sft', - model_type='reference', - use_flash_attn=False, - trainer_config=dict( - torch_dtype=torch.float32, - trainer_type='huggingface', - parallel=dict( - data=dict(size=1, mode='ddp'), - tensor=dict(size=1, mode='1d'), - pipeline=dict(size=1, interleaved_overlap=False), - sequence=False, - ), - ), - ), critic=dict( - model_path=critic_model_path, + model_path=None, model_type='critic', - use_flash_attn=False, trainer_config=dict( - torch_dtype='auto', + torch_dtype=MODEL_DTYPE, trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=False, train_kwargs=dict( micro_bsz=1, lr=5e-6, @@ -135,14 +142,14 @@ loss_type='per_seq', ), parallel=dict( - data=dict(size=2, mode='deepspeed'), + data=dict(size=CRITIC_DP_SIZE, mode='deepspeed'), tensor=dict(size=1, mode='1d'), pipeline=dict(size=1, interleaved_overlap=False), sequence=False, ), deepspeed_config={ 'zero_optimization': { - 'stage': 2, + 'stage': ZERO_STAGE, 'offload_param': { 'device': 'none' }, @@ -160,18 +167,34 @@ 'data_types': { 'grad_accum_dtype': 'fp32' }, - 'train_micro_batch_size_per_gpu': 2, - 'gradient_accumulation_steps': 16, - 'train_batch_size': 64 - }), + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': CRITIC_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE, + }, + ), + ), + reference=dict( + model_path='internlm/internlm2-chat-1_8b-sft', + model_type='reference', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), ), reward=dict( - model_path=critic_model_path, + model_path=None, model_type='reward', - use_flash_attn=False, trainer_config=dict( + torch_dtype=MODEL_DTYPE, trainer_type='huggingface', - torch_dtype='auto', + use_flash_attn=True, parallel=dict( data=dict(size=1, mode='ddp'), tensor=dict(size=1, mode='1d'), @@ -183,14 +206,23 @@ ) dataset_config = { - 'num_samples_each_epoch': - 64, - 'max_seq_len': - 1024, + 'prompt_samples_each_epoch': + PROMPT_BATCH_SIZE, + 'max_prompt_len': + MAX_PROMPT_LEN, + 'pretrain_samples_each_epoch': + PRETRAIN_BATCH_SIZE, + 'max_pretrain_len': + MAX_PRETRAIN_LEN, 'random_seed': 1024, - 'ppo_datas': [ + # "sample_strategy": "in_data", + # "ratio_within_datasets": False, + 'prompt_datasets': [ 'Anthropic/hh-rlhf/helpful-base::1.0', 'Anthropic/hh-rlhf/harmless-base::0.5', ], + 'pretrain_datasets': [ + 'Anthropic/hh-rlhf/helpful-base::1.0', + ], } diff --git a/examples/rlhf/quick_start.md b/examples/rlhf/quick_start.md index 823b08ad2..8cc5cb494 100644 --- a/examples/rlhf/quick_start.md +++ b/examples/rlhf/quick_start.md @@ -10,7 +10,7 @@ pip install torch==2.1.2+cu118 torchvision --index-url https://download.pytorch. 
git clone https://github.com/2581543189/xtuner.git cd xtuner git checkout rlhf -pip install .[rlhf] +pip install '.[rlhf]' ``` ### step2: 使用单引擎(huggingface)启动 rlhf 任务 diff --git a/xtuner/rlhf/dataset/base.py b/xtuner/rlhf/dataset/base.py index b64107d8f..9f9a5cb69 100644 --- a/xtuner/rlhf/dataset/base.py +++ b/xtuner/rlhf/dataset/base.py @@ -153,7 +153,7 @@ def __init__(self, sub_dataset_type='file', tokenizer=None, random_seed=1024, - ratio_within_datas=True): + ratio_within_datasets=True): self._task_group = [] for _task in task_groups: file_path, extra_info = _task.split('::')[0], _task.split('::')[1] @@ -194,9 +194,9 @@ def __init__(self, else: raise NotImplementedError('Cannot support filelist now.') self.random_seed = random_seed - self.ratio_within_datas = ratio_within_datas + self.ratio_within_datasets = ratio_within_datasets - if self.ratio_within_datas: + if self.ratio_within_datasets: sum_prob = sum([task['prob'] for task in self._task_group]) for task in self._task_group: task['prob'] = task['prob'] / sum_prob @@ -220,7 +220,7 @@ def _get_subset_by_ratio(self, dataset: Dataset, ratio: float, seed: int): def __iter__(self): """sample data one task by probs.""" - if self.ratio_within_datas: + if self.ratio_within_datasets: rng = random.Random(self.random_seed) probs = [task['prob'] for task in self._task_group] # Initialize task iterator diff --git a/xtuner/rlhf/dataset/txt_loader.py b/xtuner/rlhf/dataset/txt_loader.py index 5f524206f..cc20a54da 100644 --- a/xtuner/rlhf/dataset/txt_loader.py +++ b/xtuner/rlhf/dataset/txt_loader.py @@ -1,325 +1,267 @@ -"""Finetuning dataset.""" +""" Finetuning dataset. """ import random -from dataclasses import dataclass from typing import List - import numpy as np -from torch.utils.data import DataLoader, IterableDataset, RandomSampler - -from .base import InfiniteDataset, MultiSourceDatset +from dataclasses import dataclass +from torch.utils.data import IterableDataset, DataLoader, RandomSampler +from .base import MultiSourceDatset, InfiniteDataset @dataclass class Message: message: List[dict] - sys_meta: str = 'default' - rm_meta: str = 'default' + sys_meta: str = "default" + rm_meta: str = "default" token_ids: List[int] = None - mes_type: str = 'ppo' + mes_type: str = "prompt" class TxtMessageDataset(IterableDataset): - """Create sequences from dataset. - + """ Create sequences from dataset. 
Args: - sample_strategy (str) ["in_batch", "in_data"]: - "in_batch": - sample data by ratio for every single training batch - "in_data": - merge all data by ratio first and then sample training batch + sample_strategy (str) ["in_batch", "in_data"]: "in_batch": sample data by ratio for every single training batch + "in_data": merge all data by ratio first and then sample training batch """ - def __init__(self, - ppo_datas: list[str] = None, - pt_datas: list[str] = None, + prompt_datasets: list[str] = None, + pretrain_datasets: list[str] = None, tokenizer=None, - max_seq_len: int = 4096, - num_samples_each_epoch: int = 64, - pt_data_samples: int = 0, + max_prompt_len: int = 4096, + max_pretrain_len: int = 4096, + prompt_samples_each_epoch: int = 64, + pretrain_samples_each_epoch: int = 0, random_seed: int = 110, - sample_strategy: str = 'in_batch', - ratio_within_datas: bool = True, - **kwargs): - - assert sample_strategy in [ - 'in_batch', 'in_data' - ], f"sample_strategy should in ['in_batch', 'in_data'], but got {sample_strategy}" # noqa: E501 + sample_strategy: str = "in_batch", + ratio_within_datasets: bool = True, + **kwargs + ): + assert sample_strategy in ["in_batch", "in_data"], f"sample_strategy should in ['in_batch', 'in_data'], but got {sample_strategy}" self.sample_strategy = sample_strategy - assert ppo_datas is not None, '[Data error] Specify your data task config' # noqa: E501 + assert prompt_datasets is not None, "[Data error] Specify your data task config" self.tokenizer = tokenizer - assert self.tokenizer.chat_template is not None, 'Make sure tokenizer has chat_template.' # noqa: E501 + assert self.tokenizer.chat_template is not None, "Make sure tokenizer has chat_template." - self.ppo_message_dataset = MultiSourceDatset( - task_groups=ppo_datas, - sub_dataset_type='file', - tokenizer=self.tokenizer, - ratio_within_datas=ratio_within_datas) - if pt_data_samples is not None and pt_data_samples != 0: - assert pt_datas is not None, f'[PT DATA error] samples num {pt_data_samples}, while pt_datas is None' # noqa: E501 - self.pt_message_dataset = MultiSourceDatset( - task_groups=pt_datas, - sub_dataset_type='file', - tokenizer=self.tokenizer, - ratio_within_datas=ratio_within_datas) - self.pt_data_per_epoch = pt_data_samples - self.ppo_data_per_epoch = num_samples_each_epoch - self.pt_data_per_epoch # noqa: E501 + self.prompt_message_dataset = MultiSourceDatset(task_groups=prompt_datasets, + sub_dataset_type="file", + tokenizer=self.tokenizer, + ratio_within_datasets=ratio_within_datasets + ) + if pretrain_samples_each_epoch is not None and pretrain_samples_each_epoch > 0: + assert pretrain_datasets is not None, f"[PT DATA error] samples num {pretrain_samples_each_epoch}, while pretrain_datasets is None" + self.pt_message_dataset = MultiSourceDatset(task_groups=pretrain_datasets, + sub_dataset_type="file", + tokenizer=self.tokenizer, + ratio_within_datasets=ratio_within_datasets + ) + self.pretrain_samples_each_epoch = pretrain_samples_each_epoch else: self.pt_message_dataset = None - self.pt_data_per_epoch = 0 - self.ppo_data_per_epoch = num_samples_each_epoch - - self.max_seq_len = max_seq_len - self.num_samples_each_epoch = num_samples_each_epoch + self.pretrain_samples_each_epoch = 0 + self.prompt_samples_each_epoch = prompt_samples_each_epoch + self.max_prompt_len = max_prompt_len + self.max_pretrain_len = max_pretrain_len + self.num_samples_each_epoch = self.pretrain_samples_each_epoch + self.prompt_samples_each_epoch + self.random_seed = random_seed self.rng = 
random.Random(self.random_seed) np.random.seed(self.random_seed) random.seed(self.random_seed) - if self.sample_strategy == 'in_batch': + if self.sample_strategy == "in_batch": self._init_in_batch() - elif self.sample_strategy == 'in_data': + elif self.sample_strategy == "in_data": self._init_in_data() else: - raise NotImplementedError( - f"sample_strategy should in ['in_batch', 'in_data'], but got {sample_strategy}" # noqa: E501 - ) + raise NotImplementedError(f"sample_strategy should in ['in_batch', 'in_data'], but got {sample_strategy}") self.epoch_index = 0 def _init_in_data(self): - print( - '========================= Init in data sampler =========================' # noqa: E501 - ) - if self.pt_data_per_epoch != 0: - assert hasattr(self.pt_message_dataset, 'all_dataset') + print(f"========================= Init in data sampler =========================") + if self.pretrain_samples_each_epoch != 0: + assert hasattr(self.pt_message_dataset, "all_dataset") pt_sampler = RandomSampler(self.pt_message_dataset.all_dataset) - self.pt_dataloader = iter( - DataLoader( - self.pt_message_dataset.all_dataset, - collate_fn=lambda x: x, - sampler=pt_sampler, - batch_size=self.pt_data_per_epoch)) - print( - f'[PT data] pretrain data per epoch: {self.pt_data_per_epoch}') - - assert hasattr(self.ppo_message_dataset, 'all_dataset') - prompt_sampler = RandomSampler(self.ppo_message_dataset.all_dataset) - self.prompt_dataloader = iter( - DataLoader( - self.ppo_message_dataset.all_dataset, - collate_fn=lambda x: x, - sampler=prompt_sampler, - batch_size=self.ppo_data_per_epoch)) + self.pt_dataloader = iter(DataLoader( + self.pt_message_dataset.all_dataset, collate_fn=lambda x: x, sampler=pt_sampler, batch_size=self.pretrain_samples_each_epoch + )) + print(f"[PT data] pretrain data per epoch: {self.pretrain_samples_each_epoch}") - print(f'[PPO data] ppo data per epoch: {self.ppo_data_per_epoch}') - print( - f'[Txt] Training dataset initialized, random seed {self.random_seed}.\n' # noqa: E501 - ) + assert hasattr(self.prompt_message_dataset, "all_dataset") + prompt_sampler = RandomSampler(self.prompt_message_dataset.all_dataset) + self.prompt_dataloader = iter(DataLoader( + self.prompt_message_dataset.all_dataset, collate_fn=lambda x: x, sampler=prompt_sampler, batch_size=self.prompt_samples_each_epoch + )) + print(f"[Prompt data] prompt data per epoch: {self.prompt_samples_each_epoch}") + print(f"[Txt] Training dataset initialized, random seed {self.random_seed}.\n") + def yield_in_data(self): - print( - '========================= yield data from data sampler =========================' # noqa: E501 - ) + print(f"========================= yield data from data sampler =========================") batch_sequence = [] - ppo_sequence, pt_sequence = [], [] - if self.pt_data_per_epoch != 0: - pt_batch_messages = next(self.pt_dataloader) - for index, message in enumerate(pt_batch_messages): - sequence = self._postprocess_sequence(message, mes_type='pt') + prompt_sequence, pretrain_sequence = [], [] + if self.pretrain_samples_each_epoch != 0: + pretrain_batch_messages = next(self.pt_dataloader) + for index, message in enumerate(pretrain_batch_messages): + sequence = self._postprocess_sequence(message, mes_type="pretrain") if sequence is not None: - assert sequence.mes_type == 'pt', f'Data type should be pt, but get {sequence.mes_type}' # noqa: E501 - pt_sequence.append(sequence) - if len(pt_sequence) == self.pt_data_per_epoch: + assert sequence.mes_type == 'pretrain', f"Data type should be pretrain, but get 
{sequence.mes_type}" + pretrain_sequence.append(sequence) + if len(pretrain_sequence) == self.pretrain_samples_each_epoch: break - assert len( - pt_sequence - ) == self.pt_data_per_epoch, f'{len(pt_sequence)} != {self.pt_data_per_epoch}' # noqa: E501 + assert len(pretrain_sequence) == self.pretrain_samples_each_epoch, f"{len(pretrain_sequence)} != {self.pretrain_samples_each_epoch}" - ppo_batch_messages = next(self.prompt_dataloader) - for index, message in enumerate(ppo_batch_messages): - sequence = self._postprocess_sequence(message, mes_type='ppo') + prompt_batch_messages = next(self.prompt_dataloader) + for index, message in enumerate(prompt_batch_messages): + if message is None: + continue + sequence = self._postprocess_sequence(message, mes_type="prompt") if sequence is not None: - assert sequence.mes_type == 'ppo', f'Data type should be ppo. but get {sequence.mes_type}' # noqa: E501 - ppo_sequence.append(sequence) - if len(ppo_sequence) == self.ppo_data_per_epoch: + assert sequence.mes_type == 'prompt', f"Data type should be prompt. but get {sequence.mes_type}" + prompt_sequence.append(sequence) + if len(prompt_sequence) == self.prompt_samples_each_epoch: break - if len(ppo_sequence) < self.ppo_data_per_epoch: - missed = self.ppo_data_per_epoch - len(ppo_sequence) - print( - f'[Warning] {missed} dirty data, use {missed} data from sampled data...' # noqa: E501 - ) + # TODO, len(prompt_sequence) < self.prompt_samples_each_epoch, random sample from chosen data + if len(prompt_sequence) < self.prompt_samples_each_epoch: + missed = self.prompt_samples_each_epoch - len(prompt_sequence) + print(f"[Warning] {missed} dirty data, use {missed} data from sampled data...") for i in range(missed): - ppo_sequence.append(ppo_sequence[i]) + prompt_sequence.append(prompt_sequence[i]) - assert len( - ppo_sequence - ) == self.ppo_data_per_epoch, f'{len(ppo_sequence)} == {self.ppo_data_per_epoch}' # noqa: E501 + assert len(prompt_sequence) == self.prompt_samples_each_epoch, f"{len(prompt_sequence)} == {self.prompt_samples_each_epoch}" - print( - f'prepare TxtMessageDataset done: {len(ppo_sequence)} ppo & {len(pt_sequence)} pretrain, for epoch {self.epoch_index}.' 
# noqa: E501 - ) - batch_sequence = ppo_sequence + pt_sequence - assert len( - batch_sequence - ) == self.num_samples_each_epoch, '[Epoch {self.epoch_index}] Wrong data len' # noqa: E501 + print(f"prepare TxtMessageDataset done: {len(prompt_sequence)} prompt & {len(pretrain_sequence)} pretrain, for epoch {self.epoch_index}.") + batch_sequence = prompt_sequence + pretrain_sequence + assert len(batch_sequence) == self.num_samples_each_epoch, "[Epoch {self.epoch_index}] Wrong data len" return batch_sequence def _init_in_batch(self): - print( - '========================= Init in batch sampler =========================' # noqa: E501 - ) + print(f"========================= Init in batch sampler =========================") samples_cnts = [] pt_data_len = 0 - if self.pt_data_per_epoch != 0: + if self.pretrain_samples_each_epoch != 0: for task in self.pt_message_dataset._task_group: - task['target_num_each_epoch'] = int( - task['prob'] * self.pt_data_per_epoch + 0.5) + 1 - inner_dataset = InfiniteDataset(task['dataset'], self.rng) - task['iterator'] = iter(inner_dataset) - samples_cnts.append(task['target_num_each_epoch']) - print( - f"[PT data] {task['filepath']}: task prob: {task['prob']}, " # noqa: E501 - f'ori number of messages: {len(inner_dataset.data)}, ' - f"target_num_each_epoch: {task['target_num_each_epoch']}" - ) # noqa: E501 + task["target_num_each_epoch"] = int(task["prob"] * self.pretrain_samples_each_epoch + 0.5) + 1 + inner_dataset = InfiniteDataset(task["dataset"], self.rng) + task["iterator"] = iter(inner_dataset) + samples_cnts.append(task["target_num_each_epoch"]) + print(f"[Pretrain data] {task['filepath']}: task prob: {task['prob']}, " + f"ori number of messages: {len(inner_dataset.data)}, " + f"target_num_each_epoch: {task['target_num_each_epoch']}") pt_data_len = sum(samples_cnts) - assert pt_data_len >= self.pt_data_per_epoch, f'Make sure there are enough pretrain data, {pt_data_len} >= {self.pt_data_per_epoch}' # noqa: E501 - print( - f'[PT data] pretrain data per epoch: {self.pt_data_per_epoch}, sampled {pt_data_len}' # noqa: E501 - ) - for task in self.ppo_message_dataset._task_group: - task['target_num_each_epoch'] = int( - task['prob'] * self.ppo_data_per_epoch + 0.5) + 1 - inner_dataset = InfiniteDataset(task['dataset'], self.rng) - task['iterator'] = iter(inner_dataset) - samples_cnts.append(task['target_num_each_epoch']) - print(f"{task['filepath']}: task prob: {task['prob']}, " - f'ori number of messages: {len(inner_dataset.data)}, ' - f"target_num_each_epoch: {task['target_num_each_epoch']}") - assert ( - sum(samples_cnts) - pt_data_len - ) >= self.ppo_data_per_epoch, 'Make sure there are enough ppo datas' - print( - f'[PPO data] ppo data per epoch: {self.ppo_data_per_epoch}, sampled: {sum(samples_cnts) - pt_data_len}' # noqa: E501 - ) + # TODO + assert pt_data_len >= self.pretrain_samples_each_epoch, f"Make sure there are enough pretrain datas, {pt_data_len} >= {self.pretrain_samples_each_epoch}" + print(f"[PT data] pretrain data per epoch: {self.pretrain_samples_each_epoch}, sampled {pt_data_len}") - if sum(samples_cnts) <= self.num_samples_each_epoch: - print( - f'[Txt loader] Warning!!! 
sample nums {sum(samples_cnts)} <= samples {self.num_samples_each_epoch}' # noqa: E501 - ) - print( - f'[Txt] Training dataset initialized, random seed {self.random_seed}.\n' # noqa: E501 - ) + for task in self.prompt_message_dataset._task_group: + task["target_num_each_epoch"] = int(task["prob"] * self.prompt_samples_each_epoch + 0.5) + 1 + inner_dataset = InfiniteDataset(task["dataset"], self.rng) + task["iterator"] = iter(inner_dataset) + samples_cnts.append(task["target_num_each_epoch"]) + print(f"{task['filepath']}: task prob: {task['prob']}, " + f"ori number of messages: {len(inner_dataset.data)}, " + f"target_num_each_epoch: {task['target_num_each_epoch']}") + assert (sum(samples_cnts) - pt_data_len) >= self.prompt_samples_each_epoch, "Make sure there are enough prompt datas" + print(f"[Prompt data] prompt data per epoch: {self.prompt_samples_each_epoch}, sampled: {sum(samples_cnts) - pt_data_len}") + assert sum(samples_cnts) >= self.num_samples_each_epoch, "[Dataset init] sample num error" + # if sum(samples_cnts) <= self.num_samples_each_epoch: + # print(f"[Txt loader] Warning!!! sample nums {sum(samples_cnts)} <= samples {self.num_samples_each_epoch}") + print(f"[Txt] Training dataset initialized, random seed {self.random_seed}.\n") + def yield_in_batch(self): - print( - '========================= yield data from batch sampler =========================' # noqa: E501 - ) + print(f"========================= yield data from batch sampler =========================") batch_sequence = [] - ppo_sequence, pt_sequence = [], [] + prompt_sequence, pretrain_sequence = [], [] # epoch_rng only use in this epoch. epoch_rng = np.random.RandomState(self.epoch_index) # prepare epoch data - if self.pt_data_per_epoch != 0: - pt_batch_messages = [] + # print(f"prepare TxtMessageDataset for epoch {self.epoch_index}...") + if self.pretrain_samples_each_epoch != 0 : + pretrain_batch_messages = [] for task in self.pt_message_dataset._task_group: messages = [] - for _ in range(task['target_num_each_epoch']): - messages.append(next(task['iterator'])) - print( - f"[PT] prepare {len(messages)} data from {task['filepath']}" # noqa: E501 - ) + for _ in range(task["target_num_each_epoch"]): + messages.append(next(task["iterator"])) + print(f"[Pretrain] prepare {len(messages)} data from {task['filepath']}") epoch_rng.shuffle(messages) - pt_batch_messages.extend(messages) - epoch_rng.shuffle(pt_batch_messages) - for index, message in enumerate(pt_batch_messages): - sequence = self._postprocess_sequence(message, mes_type='pt') + pretrain_batch_messages.extend(messages) + # if len(pretrain_batch_messages) == self.pretrain_samples_each_epoch: + # break + epoch_rng.shuffle(pretrain_batch_messages) + for index, message in enumerate(pretrain_batch_messages): + sequence = self._postprocess_sequence(message, mes_type="pretrain") if sequence is not None: - assert sequence.mes_type == 'pt', f'Data type should be pt, but get {sequence.mes_type}' # noqa: E501 - pt_sequence.append(sequence) - if len(pt_sequence) == self.pt_data_per_epoch: + assert sequence.mes_type == 'pretrain', f"Data type should be pretrain, but get {sequence.mes_type}" + pretrain_sequence.append(sequence) + if len(pretrain_sequence) == self.pretrain_samples_each_epoch: break - assert len( - pt_sequence - ) == self.pt_data_per_epoch, f'{len(pt_sequence)} != {self.pt_data_per_epoch}' # noqa: E501 + assert len(pretrain_sequence) == self.pretrain_samples_each_epoch, f"{len(pretrain_sequence)} != {self.pretrain_samples_each_epoch}" - ppo_batch_messages = [] - 
for task in self.ppo_message_dataset._task_group: + prompt_batch_messages = [] + for task in self.prompt_message_dataset._task_group: messages = [] - for _ in range(task['target_num_each_epoch']): - messages.append(next(task['iterator'])) - print( - f"[PPO] prepare {len(messages)} data from {task['filepath']}") + for _ in range(task["target_num_each_epoch"]): + messages.append(next(task["iterator"])) + print(f"[Prompt] prepare {len(messages)} data from {task['filepath']}") epoch_rng.shuffle(messages) - ppo_batch_messages.extend(messages) - epoch_rng.shuffle(ppo_batch_messages) - for index, message in enumerate(ppo_batch_messages): - sequence = self._postprocess_sequence(message, mes_type='ppo') + prompt_batch_messages.extend(messages) + epoch_rng.shuffle(prompt_batch_messages) + for index, message in enumerate(prompt_batch_messages): + sequence = self._postprocess_sequence(message, mes_type="prompt") if sequence is not None: - assert sequence.mes_type == 'ppo', f'Data type should be ppo. but get {sequence.mes_type}' # noqa: E501 - ppo_sequence.append(sequence) - if len(ppo_sequence) == self.ppo_data_per_epoch: + assert sequence.mes_type == 'prompt', f"Data type should be prompt. but get {sequence.mes_type}" + prompt_sequence.append(sequence) + if len(prompt_sequence) == self.prompt_samples_each_epoch: break - assert len( - ppo_sequence - ) == self.ppo_data_per_epoch, f'{len(ppo_sequence)} == {self.ppo_data_per_epoch}' # noqa: E501 + assert len(prompt_sequence) == self.prompt_samples_each_epoch, f"{len(prompt_sequence)} == {self.prompt_samples_each_epoch}" - print( - f'prepare TxtMessageDataset done: {len(ppo_sequence)} ppo & {len(pt_sequence)} pretrain, for epoch {self.epoch_index}.' # noqa: E501 - ) - batch_sequence = ppo_sequence + pt_sequence - assert len( - batch_sequence - ) == self.num_samples_each_epoch, '[Epoch {self.epoch_index}] Wrong data len' # noqa: E501 + print(f"prepare TxtMessageDataset done: {len(prompt_sequence)} prompt & {len(pretrain_sequence)} pretrain, for epoch {self.epoch_index}.") + batch_sequence = prompt_sequence + pretrain_sequence + assert len(batch_sequence) == self.num_samples_each_epoch, "[Epoch {self.epoch_index}] Wrong data len" return batch_sequence def __iter__(self): while True: - if self.sample_strategy == 'in_batch': + if self.sample_strategy == "in_batch": yield self.yield_in_batch() - elif self.sample_strategy == 'in_data': + elif self.sample_strategy == "in_data": yield self.yield_in_data() self.epoch_index += 1 - def _postprocess_sequence(self, message, mes_type='ppo'): + def _postprocess_sequence(self, message, mes_type=None): """Post process sequence: tokenization & truncation.""" message_data = message['data'] new_meaasage_data = [] - if mes_type == 'ppo': + if mes_type == "prompt": for _ in reversed(range(len(message_data))): - if message_data[_]['role'] == 'user': + if message_data[_]["role"] == "user": new_meaasage_data = message_data[:_ + 1] break - assert new_meaasage_data[-1][ - 'role'] == 'user', f'ppo data last role must user, {new_meaasage_data}' # noqa: E501 - token_ids = self.tokenizer.apply_chat_template( - new_meaasage_data, - tokenize=True, - add_generation_prompt=True, - return_tensors='pt') - elif mes_type == 'pt': + assert new_meaasage_data[-1]["role"] == "user", f"prompt data last role must user, {new_meaasage_data}" + token_ids = self.tokenizer.apply_chat_template(new_meaasage_data, tokenize=True, add_generation_prompt=True, return_tensors="pt") + if token_ids.shape[-1] <= 4 or token_ids.shape[-1] > self.max_prompt_len: + # 
TODO truncation?? + # raise RuntimeError(f"token_ids is too long: {token_ids.shape[-1]}") + print(f"[TXT Loader] Warning, {mes_type} message {message} is too short or long, skipped...") + return None + elif mes_type == "pretrain": for _ in reversed(range(len(message_data))): - if message_data[_]['role'] == 'assistant': + if message_data[_]["role"] == "assistant": new_meaasage_data = message_data[:_ + 1] break - assert new_meaasage_data[-1][ - 'role'] == 'assistant', f'pretrain data last role must assistant, {new_meaasage_data}' # noqa: E501 - token_ids = self.tokenizer.apply_chat_template( - new_meaasage_data, - tokenize=True, - add_generation_prompt=False, - return_tensors='pt') - if token_ids.shape[-1] <= 4 or token_ids.shape[-1] > self.max_seq_len: - print( - f'[TXT Loader] Warning, {mes_type} message {message} is too short or long, skipped...' # noqa: E501 - ) - return None - return Message( - message=new_meaasage_data, - token_ids=token_ids, - sys_meta=message['sys_meta'], - rm_meta=message['rm_meta'], - mes_type=mes_type) + assert new_meaasage_data[-1]["role"] == "assistant", f"pretrain data last role must assistant, {new_meaasage_data}" + token_ids = self.tokenizer.apply_chat_template(new_meaasage_data, tokenize=True, add_generation_prompt=False, return_tensors="pt") + + if token_ids.shape[-1] <= 4 or token_ids.shape[-1] > self.max_pretrain_len: + # TODO truncation?? + # raise RuntimeError(f"token_ids is too long: {token_ids.shape[-1]}") + print(f"[TXT Loader] Warning, {mes_type} message {message} is too short or long, skipped...") + return None + return Message(message=new_meaasage_data, + token_ids=token_ids, + sys_meta=message['sys_meta'], + rm_meta=message['rm_meta'], + mes_type=mes_type) diff --git a/xtuner/rlhf/envs/prompt_utils.py b/xtuner/rlhf/envs/prompt_utils.py new file mode 100644 index 000000000..54dfab3cf --- /dev/null +++ b/xtuner/rlhf/envs/prompt_utils.py @@ -0,0 +1,46 @@ + + +META_PROMPT = { + 'default': + '', + 'helpful': + 'Consider the helpfulness of the given sentence. ' + 'A helpful sentence is one that provides value, offers insight, or assists the user in some way. ' # noqa: E501 + 'It may solve a problem, answer a question, provide relevant information, or offer guidance. ' # noqa: E501 + "It should directly address the user's needs or inquiries and is ideally actionable, " # noqa: E501 + 'meaning the user can use the information to make decisions or take steps towards their goal. ' # noqa: E501 + 'The score should be higher for sentences that are more beneficial and informative to the user. ' # noqa: E501 + 'Here is the conversation:\n\n', + 'harmless': + 'Evaluate the harmlessness of the provided sentence. ' + 'A harmless sentence is one that does not offend, insult, or cause distress to the user, ' # noqa: E501 + "and does not mislead or give false information. It respects the user's feelings, beliefs, " # noqa: E501 + 'and values, and does not provoke conflict or negative emotions. ' + 'Additionally, a harmless sentence should not incite harmful actions or behaviors. ' # noqa: E501 + 'It should uphold the principles of respect, honesty, and fairness. ' + 'Sentences that are more considerate, respectful, and accurate should receive higher scores on this scale. ' # noqa: E501 + 'Here is the conversation:\n\n', + 'hhh': + 'Evaluate the given sentence based on its helpfulness, harmlessness, and honesty. ' # noqa: E501 + 'A helpful sentence provides value, solves problems, or aids the user in an actionable manner. 
' # noqa: E501 + "A harmless sentence respects the user's feelings and beliefs, avoids causing distress, " # noqa: E501 + 'and does not mislead or incite harmful behaviors. An honest sentence delivers reliable and true information, ' # noqa: E501 + 'presents facts objectively, and demonstrates integrity and authenticity. Higher scores should be assigned ' # noqa: E501 + 'to sentences that embody these characteristics more strongly. ' + 'Here is the conversation:\n\n', + 'summarization': + 'As a language model performing a summarization task, your goal is to generate a summary that ' # noqa: E501 + 'accurately, succinctly, and coherently encapsulates the key details of the source text. Ensure relevance to ' # noqa: E501 + 'the original material, completeness of main points, and logical structure. Maintain conciseness and high ' # noqa: E501 + 'linguistic standards. Ensure only the summary is outputted, refraining from adding extraneous comments or ' # noqa: E501 + 'remarks. Here is the original material:\n\n', + 'reddit': + 'Imagine you are a knowledgeable and friendly Reddit user. ' + 'A fellow Redditor has just shared a post seeking feedback, advice, or input. ' # noqa: E501 + 'Please read the post and provide a thoughtful, informative, and respectful response, ' # noqa: E501 + 'just as if you were replying on the platform. Here is the post:\n\n', + 'latex': + 'When mathematical content appears in the conversation, please use latex format to express the mathematical content. Here is the conversation:\n\n', # noqa: E501 + 'math_ci': + "Integrate step-by-step reasoning and Python code to solve math problems using the following guidelines:\n- Just write jupyter code to solve the problem without giving your thought;\n- Present the final result in LaTeX using a '\\boxed\\{{}}' without any units. \n", # noqa: E501 +} diff --git a/xtuner/rlhf/envs/txt_env.py b/xtuner/rlhf/envs/txt_env.py index 1a01d0741..7490151b8 100644 --- a/xtuner/rlhf/envs/txt_env.py +++ b/xtuner/rlhf/envs/txt_env.py @@ -6,51 +6,7 @@ from torch.utils.data import IterableDataset from ..model_server.base_model_server import BaseModelServer - -META_PROMPT = { - 'default': - '', - 'helpful': - 'Consider the helpfulness of the given sentence. ' - 'A helpful sentence is one that provides value, offers insight, or assists the user in some way. ' # noqa: E501 - 'It may solve a problem, answer a question, provide relevant information, or offer guidance. ' # noqa: E501 - "It should directly address the user's needs or inquiries and is ideally actionable, " # noqa: E501 - 'meaning the user can use the information to make decisions or take steps towards their goal. ' # noqa: E501 - 'The score should be higher for sentences that are more beneficial and informative to the user. ' # noqa: E501 - 'Here is the conversation:\n\n', - 'harmless': - 'Evaluate the harmlessness of the provided sentence. ' - 'A harmless sentence is one that does not offend, insult, or cause distress to the user, ' # noqa: E501 - "and does not mislead or give false information. It respects the user's feelings, beliefs, " # noqa: E501 - 'and values, and does not provoke conflict or negative emotions. ' - 'Additionally, a harmless sentence should not incite harmful actions or behaviors. ' # noqa: E501 - 'It should uphold the principles of respect, honesty, and fairness. ' - 'Sentences that are more considerate, respectful, and accurate should receive higher scores on this scale. 
' # noqa: E501 - 'Here is the conversation:\n\n', - 'hhh': - 'Evaluate the given sentence based on its helpfulness, harmlessness, and honesty. ' # noqa: E501 - 'A helpful sentence provides value, solves problems, or aids the user in an actionable manner. ' # noqa: E501 - "A harmless sentence respects the user's feelings and beliefs, avoids causing distress, " # noqa: E501 - 'and does not mislead or incite harmful behaviors. An honest sentence delivers reliable and true information, ' # noqa: E501 - 'presents facts objectively, and demonstrates integrity and authenticity. Higher scores should be assigned ' # noqa: E501 - 'to sentences that embody these characteristics more strongly. ' - 'Here is the conversation:\n\n', - 'summarization': - 'As a language model performing a summarization task, your goal is to generate a summary that ' # noqa: E501 - 'accurately, succinctly, and coherently encapsulates the key details of the source text. Ensure relevance to ' # noqa: E501 - 'the original material, completeness of main points, and logical structure. Maintain conciseness and high ' # noqa: E501 - 'linguistic standards. Ensure only the summary is outputted, refraining from adding extraneous comments or ' # noqa: E501 - 'remarks. Here is the original material:\n\n', - 'reddit': - 'Imagine you are a knowledgeable and friendly Reddit user. ' - 'A fellow Redditor has just shared a post seeking feedback, advice, or input. ' # noqa: E501 - 'Please read the post and provide a thoughtful, informative, and respectful response, ' # noqa: E501 - 'just as if you were replying on the platform. Here is the post:\n\n', - 'latex': - 'When mathematical content appears in the conversation, please use latex format to express the mathematical content. Here is the conversation:\n\n', # noqa: E501 - 'math_ci': - "Integrate step-by-step reasoning and Python code to solve math problems using the following guidelines:\n- Just write jupyter code to solve the problem without giving your thought;\n- Present the final result in LaTeX using a '\\boxed\\{{}}' without any units. 
\n", # noqa: E501 -} +from .prompt_utils import META_PROMPT class TxtEnv: @@ -62,8 +18,6 @@ def __init__( max_new_tokens: int = 1024, actor_micro_bs: int = 32, reward_micro_bs: int = 32, - clip_reward_min: int = -5, - clip_reward_max: int = 5, reward_function: BaseModelServer = None, async_reward: bool = True, generate_kwargs: dict = None, @@ -80,15 +34,13 @@ def __init__( self.max_new_tokens = max_new_tokens self.actor_micro_bs = actor_micro_bs self.reward_micro_bs = reward_micro_bs - self.clip_reward_min = clip_reward_min - self.clip_reward_max = clip_reward_max self.async_reward = async_reward self.generate_kwargs: dict = generate_kwargs def rollout(self, policy_model: BaseModelServer, display=False): sample_data = deepcopy(next(self.dataloader)) - ppo_input_messages = [] - pt_input_messages = [] + prompt_input_messages = [] + pretrain_input_messages = [] for data in sample_data: if data.sys_meta != 'default': message = deepcopy([{ @@ -97,23 +49,23 @@ def rollout(self, policy_model: BaseModelServer, display=False): }] + data.message) else: message = deepcopy(data.message) - if data.mes_type == 'ppo': - ppo_input_messages.append(message) - elif data.mes_type == 'pt': - pt_input_messages.append(message) + if data.mes_type == 'prompt': + prompt_input_messages.append(message) + elif data.mes_type == 'pretrain': + pretrain_input_messages.append(message) else: raise TypeError(f'Wrong message type {data.mes_type}') - # ppo data + # prompt data s_t = time.time() - print(f'[For Generate]: {ppo_input_messages[0]}') + print(f'[For Generate]: {prompt_input_messages[0]}') trajectories = policy_model.generate( - inputs=ppo_input_messages, + inputs=prompt_input_messages, micro_batch_size=self.actor_micro_bs, step=self.max_new_tokens, output_str=True, generate_kwargs=self.generate_kwargs) logger.info( - f'[actor generate] duration: {round(time.time() - s_t, 2)} s, len(inputs): {len(ppo_input_messages)} ' # noqa: E501 + f'[actor generate] duration: {round(time.time() - s_t, 2)} s, len(inputs): {len(prompt_input_messages)} ' # noqa: E501 ) if self.async_reward: @@ -122,25 +74,23 @@ def rollout(self, policy_model: BaseModelServer, display=False): trajectories['reward_output_ref'] = reward_output_ref else: rewards = self.get_reward(sample_data, trajectories) - clipped_rewards = torch.clamp( - rewards, min=self.clip_reward_min, max=self.clip_reward_max) trajectories['rewards'] = rewards - trajectories['clipped_rewards'] = clipped_rewards # pretrain data - if len(pt_input_messages) > 0: - pt_inputs = [ - policy_model.tokenizer.apply_chat_template( - mes, - tokenize=False, - add_generation_prompt=False, - return_tensors='pt') for mes in pt_input_messages - ] - trajectories.pt_data = policy_model.tokenizer( - pt_inputs, return_tensors='pt', padding=True) + if len(pretrain_input_messages) > 0: + from ..tokenizer import tokenizer_utils + pretrain_input_ids, pretrain_attention_mask = tokenizer_utils.encode( + pretrain_input_messages, policy_model.tokenizer) + pretrain_labels = torch.nn.functional.pad(pretrain_input_ids[:, 1:], (0, 1), mode="constant", value=-100) + + trajectories.pretrain_data = {"input_ids": pretrain_input_ids, + "labels": pretrain_labels, + "attention_mask": pretrain_attention_mask} print( - f'[TxtEnv & {policy_model.__class__.__name__}] gets {len(pt_input_messages)} pretrain episodes.' # noqa: E501 + f'[TxtEnv & {policy_model.__class__.__name__}] gets {len(pretrain_input_messages)} pretrain episodes.' 
# noqa: E501 ) + else: + trajectories.pretrain_data = None return trajectories @@ -149,6 +99,8 @@ def get_reward_async(self, sample_data, policyout): s_t = time.time() rm_input_messages = [] for i in range(len(sample_data)): + if sample_data[i].mes_type != "prompt": + continue if sample_data[i].rm_meta != 'default': cur_rm_data = [{ 'role': 'system', @@ -190,6 +142,8 @@ def get_reward(self, sample_data, policyout): s_t = time.time() rm_input_messages = [] for i in range(len(sample_data)): + if sample_data[i].mes_type != "prompt": + continue if sample_data[i].rm_meta != 'default': cur_rm_data = [{ 'role': 'system', diff --git a/xtuner/rlhf/loss/actor_loss.py b/xtuner/rlhf/loss/actor_loss.py index cd81c97db..e5c05cc01 100644 --- a/xtuner/rlhf/loss/actor_loss.py +++ b/xtuner/rlhf/loss/actor_loss.py @@ -53,12 +53,12 @@ def forward(self, logits: torch.Tensor, labels: dict[str, Any]): Tensor: Return the final loss """ assert logits.ndim == 3 - mask = labels['mask'] # (micro_bsz, seqlen) + mask = labels['mask'] assert logits.shape[0] == labels['input_ids'].shape[0] - input_ids = labels['input_ids'] # (micro_bsz, seqlen) - old_logprobs = labels['old_logprobs'] # (micro_bsz, seqlen) - advantages = labels['advantages'] # (micro_bsz, seqlen) + input_ids = labels['input_ids'] + old_logprobs = labels['old_logprobs'] + advantages = labels['advantages'] loss_factor = labels['loss_factor'] logpy = logprobs_from_logits( diff --git a/xtuner/rlhf/loss/critic_loss.py b/xtuner/rlhf/loss/critic_loss.py index 877c21c28..3ad4e2db6 100644 --- a/xtuner/rlhf/loss/critic_loss.py +++ b/xtuner/rlhf/loss/critic_loss.py @@ -7,7 +7,7 @@ class CriticLoss(torch.nn.Module): """Loss function for critic model.""" def __init__(self, - cliprange_value: float = 100, + cliprange_value: float = 0.5, loss_type: str = 'per_seq'): super().__init__() self.cliprange_value = cliprange_value @@ -53,12 +53,12 @@ def forward(self, values: torch.Tensor, labels: dict[str, Any]): Tensor: Return the final loss """ assert values.ndim == 2 - mask = labels['mask'] # (micro_bsz, seqlen) + mask = labels['mask'] num_actions = mask.size(1) values = values[:, -num_actions:] - old_values = labels['old_values'] # (micro_bsz, seqlen) - returns = labels['returns'] # (micro_bsz, seqlen) + old_values = labels['old_values'] + returns = labels['returns'] loss_factor = labels['loss_factor'] loss = self.critic_loss_fn( values=values, diff --git a/xtuner/rlhf/loss/pretrain_loss.py b/xtuner/rlhf/loss/pretrain_loss.py index fe08d2a0b..6356291d0 100644 --- a/xtuner/rlhf/loss/pretrain_loss.py +++ b/xtuner/rlhf/loss/pretrain_loss.py @@ -1,36 +1,20 @@ import torch from loguru import logger -try: - from flash_attn.losses.cross_entropy import \ - CrossEntropyLoss as FlashCrossEntropyLoss -except ImportError: - pass - -# Adapted from: https://gitlab.pjlab.org.cn/openmmlab/bigmodel/rl3m/-/blob/main/rl3m/layers/loss.py#L37 # noqa: E501 -class FlashGPTLMLoss(torch.nn.Module): +class PretrainLoss(torch.nn.Module): """Loss function for flash GPT Language Model.""" - def __init__(self, parallel_output=True, label_smoothing=0): + def __init__(self, label_smoothing=0): super().__init__() if label_smoothing is not None and label_smoothing != 0: logger.warning(f'Use label_smoothing: {label_smoothing}') self.label_smoothing = label_smoothing - if parallel_output: - # The loss in this place is bound to the gather_output initialized by VocabParallelClassifier1D # noqa: E501 - self.loss_fn = FlashCrossEntropyLoss( - reduction='mean', - inplace_backward=True, - process_group=None, - 
label_smoothing=label_smoothing, - ) - else: - # Here, the output will gather output is set in the model, so use ordinary loss # noqa: E501 - self.loss_fn = torch.nn.CrossEntropyLoss( - reduction='mean', label_smoothing=label_smoothing) + # Here, the output will gather output is set in the model, so use ordinary loss # noqa: E501 + self.loss_fn = torch.nn.CrossEntropyLoss( + reduction='mean', label_smoothing=label_smoothing) def forward(self, *args): if len(args) == 3: @@ -50,16 +34,3 @@ def forward(self, *args): return loss - -# Adapted from: https://gitlab.pjlab.org.cn/openmmlab/bigmodel/rl3m/-/blob/main/rl3m/layers/loss.py#L37 # noqa: E501 -class PretrainLoss(FlashGPTLMLoss): - """Modified from pretrain/sft loss, but with a loss factor term to balance - with ppo policy loss.""" - - def __init__(self, *args, loss_factor=1.0, **kwargs): - super().__init__(*args, **kwargs) - self.loss_factor = loss_factor - - def forward(self, *args, **kwargs): - loss = super().forward(*args, **kwargs) - return loss * self.loss_factor diff --git a/xtuner/rlhf/main.py b/xtuner/rlhf/main.py index cf3812fcc..ea64cc36a 100644 --- a/xtuner/rlhf/main.py +++ b/xtuner/rlhf/main.py @@ -22,7 +22,7 @@ def parse_args(): '--config', help='config file name or path.', type=str, - default='examples/rlhf/four_model_8gpu.py') + default='examples/rlhf/four_model_vllm_8gpu.py') parser.add_argument( '-w', '--work_dir', @@ -50,17 +50,15 @@ def validate_config(config: Config): assert args.config is not None, 'config should not be None' work_dir = args.work_dir if work_dir is None: - work_dir = os.getcwd() + work_dir = os.getcwd() + '/rlhf_trainlog_' + time.strftime( + '%Y-%m-%d-%H:%M:%S') work_dir = os.path.abspath(work_dir) logger.info(f'using work_dir: {work_dir}') os.makedirs(work_dir, exist_ok=True) logger.add( - f'{work_dir}/train.log', + f'{work_dir}/train_rlhf.log', filter=lambda record: record['extra'].get('name') == 'train') - logger.add( - f'{work_dir}/rollout.log', - filter=lambda record: record['extra'].get('name') == 'rollout') logger_train = logger.bind(name='train') configs_path = args.config @@ -131,7 +129,7 @@ def validate_config(config: Config): # # for value & policy learn value_loss_ref = ppo.value_learn_async(trajectories, critic_model) - ppo_loss = 0.0 + ppo_loss, pt_loss = None, None if pretrain_step <= 0: ppo_loss, pt_loss = ppo.policy_learn(trajectories, actor_model) logger_train.info( @@ -145,8 +143,14 @@ def validate_config(config: Config): pretrain_step -= 1 if config['rollout_config'].get('write_to_file', True): - with open(f'{work_dir}/rollout.log', 'a') as file: - file.write(f'generates: {trajectories.output_str}') + if not os.path.exists(f'{work_dir}/rollouts'): + os.makedirs(f'{work_dir}/rollouts') + with open(f'{work_dir}/rollouts/step{step}_rollout.log', + 'a') as file: + for output_s, r in zip(trajectories.output_str, + trajectories.rewards): + file.write(output_s + '\n' + 'Reward: ' + str(r.item()) + + '\n' + '=' * 30 + '\n') summaries = dict( reward_mean=trajectories.rewards.mean().item(), reward_std=trajectories.rewards.std().item(), @@ -158,9 +162,10 @@ def validate_config(config: Config): entropy=trajectories.entropy.mean().item(), step=step, policy_loss=ppo_loss, + pretrain_loss=pt_loss, critic_loss=value_loss, ) - with open(f'{work_dir}/train.log.jsonl', 'a') as f: + with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: f.write(json.dumps(summaries) + '\n') step += 1 diff --git a/xtuner/rlhf/model_backend/generate_utils.py b/xtuner/rlhf/model_backend/generate_utils.py index 
e88995d28..15adb5640 100644 --- a/xtuner/rlhf/model_backend/generate_utils.py +++ b/xtuner/rlhf/model_backend/generate_utils.py @@ -36,6 +36,7 @@ def partition_by_micro_batch_size( input_ids: Union[list[str], torch.Tensor, list[int]], micro_batch_size: int, attention_mask: torch.Tensor = None, + position_ids: torch.Tensor = None, labels: Optional[Union[list[torch.Tensor], torch.Tensor, dict[str, torch.Tensor]]] = None, ) -> list[dict[str, torch.Tensor]]: @@ -46,6 +47,7 @@ def partition_by_micro_batch_size( micro_batch = {} micro_batch['input_ids'] = input_ids micro_batch['attention_mask'] = attention_mask + micro_batch['position_ids'] = position_ids micro_batch['labels'] = labels micro_batches.append(micro_batch) return micro_batches @@ -64,6 +66,9 @@ def partition_by_micro_batch_size( attention_mask_split = ( torch.split(attention_mask, micro_batch_size, dim=0) if attention_mask is not None else [None for _ in range(num_splits)]) + position_ids_split = ( + torch.split(position_ids, micro_batch_size, dim=0) + if position_ids is not None else [None for _ in range(num_splits)]) labels_split = ( partition_label_by_micro_batch_size(labels, micro_batch_size, num_splits) @@ -72,6 +77,7 @@ def partition_by_micro_batch_size( micro_batch = {} micro_batch['input_ids'] = input_ids_split[i] micro_batch['attention_mask'] = attention_mask_split[i] + micro_batch['position_ids'] = position_ids_split[i] micro_batch['labels'] = labels_split[i] micro_batches.append(micro_batch) return micro_batches @@ -108,33 +114,34 @@ def partition_list_by_micro_batch_size( micro_batch_size: list[int], labels: list[torch.Tensor], attention_mask: Optional[list[torch.Tensor]] = None, - loss_weights: Optional[list[float]] = None, + position_ids: Optional[list[torch.Tensor]] = None, ) -> list[dict]: length = len(input_ids) batch_size = input_ids[0].shape[0] num_splits = int(batch_size // micro_batch_size[0]) + ( batch_size % micro_batch_size[0] > 0) micro_batches = [[{} for i in range(length)] for _ in range(num_splits)] - if loss_weights is None: - loss_weights = [None for _ in range(length)] if attention_mask is None: attention_mask = [None for _ in range(length)] + if position_ids == None: + position_ids = [None for _ in range(length)] for i in range(length): sub_input_ids = input_ids[i] sub_attention_mask = attention_mask[i] + sub_position_ids = position_ids[i] sub_labels = labels[i] - sub_loss_weights = loss_weights[i] sub_micro_batches = partition_by_micro_batch_size( - sub_input_ids, micro_batch_size[i], sub_attention_mask, sub_labels) + sub_input_ids, micro_batch_size[i], sub_attention_mask, + sub_position_ids, sub_labels) for micro_batch_index, sub_micro_batch in enumerate(sub_micro_batches): micro_batches[micro_batch_index][i]['input_ids'] = sub_micro_batch[ 'input_ids'] micro_batches[micro_batch_index][i][ 'attention_mask'] = sub_micro_batch['attention_mask'] + micro_batches[micro_batch_index][i][ + 'position_ids'] = sub_micro_batch['position_ids'] micro_batches[micro_batch_index][i]['labels'] = sub_micro_batch[ 'labels'] - micro_batches[micro_batch_index][i][ - 'loss_weights'] = sub_loss_weights return micro_batches diff --git a/xtuner/rlhf/model_backend/hf_model_runner.py b/xtuner/rlhf/model_backend/hf_model_runner.py index ca6a826f5..b873a3497 100644 --- a/xtuner/rlhf/model_backend/hf_model_runner.py +++ b/xtuner/rlhf/model_backend/hf_model_runner.py @@ -146,71 +146,30 @@ def initialize(self): f'[{self.model_type}] __init__() done with optimizer {self.optimizer.optimizer}.' 
# noqa: E501 ) - # Training - def compute_loss_and_backward( - self, - input_ids: Union[list[torch.Tensor], torch.Tensor], - labels: Optional[Union[list[torch.Tensor], torch.Tensor, - dict[str, torch.Tensor]]] = None, - attention_mask: Optional[Union[list[torch.Tensor], - torch.Tensor]] = None, - criterion: Optional[Union[list[_Loss], _Loss]] = None, - loss_weights: Optional[list[float]] = None, - gradient_accumulation_steps=1, - **_ignored, - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - """ - criterion: _Loss class, e.g., torch.nn.CrossEntropyLoss() - """ - if isinstance(input_ids, torch.Tensor): # returns torch.Tensor - # rarely, since self.train() changes all input_ids to [input_ids] - loss = self.compute_loss(input_ids, labels, attention_mask, - criterion) - self.accelerator.backward(loss) - return loss - - elif type(input_ids) == list: # returns list[torch.Tensor] - # multiple inputs grouped to compute loss, see: - # https://stackoverflow.com/questions/53994625/how-can-i-process-multi-loss-in-pytorch - assert ( - len(input_ids) == len(labels) == len(criterion) == - len(attention_mask) == len(loss_weights) - ), f'{len(input_ids)} {len(labels)} {len(criterion)} {len(attention_mask)} {len(loss_weights)} must equal' # noqa: E501 - loss_list = [0 for _ in range(len(input_ids))] - loss_weights = [ - x / float(len(loss_weights)) for x in loss_weights - ] # to 1 - - loss_sum = 0 - for i in range(len(input_ids)): - with self.accelerator.autocast(): - loss = self.compute_loss(input_ids[i], labels[i], - attention_mask[i], criterion[i]) - loss_sum += loss * loss_weights[i] - loss_list[i] = loss - self.accelerator.backward(loss_sum) - return loss_list - - else: - raise NotImplementedError(f'unknown input {input_ids}') - def compute_loss( self, input_ids: torch.Tensor, labels: Optional[Union[torch.Tensor, dict[str, torch.Tensor]]] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, criterion: Optional[_Loss] = None, loss_weight: Optional[float] = None, **_ignored, ) -> torch.Tensor: input_ids = input_ids.to(self.device) labels = input_ids.clone() if labels is None else labels - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) + if attention_mask is not None: + if position_ids is None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) batch = { - 'input_ids': input_ids, - 'attention_mask': attention_mask, - 'position_ids': position_ids.to(self.device) + 'input_ids': + input_ids, + 'attention_mask': + attention_mask.to(self.device) + if attention_mask is not None else None, + 'position_ids': + position_ids.to(self.device) if position_ids is not None else None } self.model.train() @@ -226,19 +185,12 @@ def compute_loss( # OPT. 
B) Use preset loss functions, e.g., torch.nn.CrossEntropyLoss() # noqa: E501 # Adopted from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1199 # noqa: E501 logits: torch.Tensor = self.model(**batch, use_cache=False).logits - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - shift_logits = shift_logits.view(-1, self.vocab_size) - shift_labels = shift_labels.view(-1) - shift_labels = shift_labels.to( - shift_logits.device) # enable model para - # loss_fct = criterion() - loss = criterion(shift_logits, shift_labels) + labels = labels.to(self.device) + loss = criterion(logits, labels) elif isinstance(labels, dict): # OPT. C) Use customized loss function, see loss/actor_loss.py logits: torch.Tensor = self.model( **batch, use_cache=False, return_dict=True).logits - # loss_fct = criterion() for k, v in labels.items(): labels[k] = v.to(self.device) loss = criterion(logits, labels) @@ -266,6 +218,7 @@ def train( dict[str, torch.Tensor]]] = None, attention_mask: Optional[Union[list[torch.Tensor], torch.Tensor]] = None, + position_ids: Optional[Union[list[torch.Tensor], torch.Tensor]] = None, criterion: Optional[Union[list[_Loss], _Loss]] = None, loss_weights: Optional[Union[list[float], float]] = None, step_interval: int = 1, @@ -280,58 +233,66 @@ def train( input_ids = [input_ids] labels = [labels] attention_mask = [attention_mask] + position_ids = [position_ids] criterion = [criterion] - loss_weights = [1] if loss_weights is None else [loss_weights] - micro_batch_size = None if micro_batch_size is None else [ - micro_batch_size - ] - return_list = False - - if micro_batch_size is None: - for i in range(len(input_ids)): - self.info_rank0( - f'[{self.model_type}] train input_ids[{i}] shape[{input_ids[i].shape}]' # noqa: E501 - ) - origin_loss = self.compute_loss_and_backward( - input_ids, labels, attention_mask, criterion, loss_weights) + loss_weights = [loss_weights] + micro_batch_size = [micro_batch_size] else: - assert isinstance(input_ids, list) - micro_batches = partition_list_by_micro_batch_size( - input_ids, micro_batch_size, labels, attention_mask, - loss_weights) - origin_loss_list_mb = [] - for index, micro_batch in enumerate(micro_batches): - input_ids_mb = [] - attention_mask_mb = [] - labels_mb = [] - loss_weights_mb = [] - for i in range(len(micro_batch)): - input_ids_mb.append(micro_batch[i]['input_ids'].to( - self.device)) - attention_mask_mb.append( - micro_batch[i]['attention_mask'].to(self.device)) - labels_mb.append(micro_batch[i]['labels']) - loss_weights_mb.append(micro_batch[i]['loss_weights']) - if index == 0: - for i in range(len(input_ids_mb)): - self.info_rank0( - f'[{self.model_type}] will train input_ids_mb[{i}] shape[{input_ids_mb[i].shape}] * {len(micro_batches)} times' # noqa: E501 - ) - origin_loss_mb = self.compute_loss_and_backward( - input_ids_mb, - labels_mb, - attention_mask_mb, - criterion, - loss_weights_mb, - gradient_accumulation_steps=len(micro_batches), + if attention_mask is None: + attention_mask = [None for _ in range(len(input_ids))] + if position_ids is None: + position_ids = [None for _ in range(len(input_ids))] + if criterion is None: + criterion = [None for _ in range(len(input_ids))] + if loss_weights is None: + loss_weights = [None for _ in range(len(input_ids))] + if micro_batch_size is None: + micro_batch_size = [None for _ in range(len(input_ids))] + + assert isinstance(input_ids, list) + + loss_list = [[] for _ in range(len(input_ids))] 
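When position_ids are not supplied, compute_loss above falls back to deriving them from the attention mask, so left-padded positions get a dummy value of 1 and real tokens are numbered from 0. A tiny worked example:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])     # one left-padded sequence
position_ids = attention_mask.long().cumsum(-1) - 1  # tensor([[-1, -1, 0, 1, 2]])
position_ids.masked_fill_(attention_mask == 0, 1)    # tensor([[ 1,  1, 0, 1, 2]])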
+ for index in range(len(input_ids)): + mb_size_entry = micro_batch_size[index] + if mb_size_entry is None: + micro_batches: list[dict[str, torch.Tensor]] = [] + micro_batches.append({ + 'input_ids': input_ids[index], + 'attention_mask': attention_mask[index], + 'position_ids': position_ids[index], + 'labels': labels[index] + }) + else: + micro_batches = partition_by_micro_batch_size( + input_ids=input_ids[index], + micro_batch_size=micro_batch_size[index], + attention_mask=attention_mask[index], + position_ids=position_ids[index], + labels=labels[index], + ) + loss_entry = [] + for mb_index, micro_batch in enumerate(micro_batches): + if mb_index == 0: + self.info_rank0( + f"[{self.model_type}] will train input_ids[{mb_index}] shape[{micro_batch['input_ids'].shape}] * {len(micro_batches)} times" # noqa: E501 + ) + # compute loss and backward + loss = self.compute_loss( + input_ids=micro_batch['input_ids'], + labels=micro_batch['labels'], + attention_mask=micro_batch['attention_mask'], + position_ids=micro_batch['position_ids'], + criterion=criterion[index], + loss_weight=loss_weights[index], ) - origin_loss_list_mb.append(origin_loss_mb) + self.accelerator.backward(loss) + loss_entry.append(loss) if debug: set_seed(1234) - origin_loss = merge_loss_list(origin_loss_list_mb) + loss_list[index] = sum(loss_entry) / len(loss_entry) self.parameter_update(step_interval) - return origin_loss if return_list else origin_loss[0] + return loss_list if len(loss_list) > 1 else loss_list[0] # Inference @torch.no_grad() @@ -740,18 +701,21 @@ def initialize_get(self): self.initialize_ref = None # Training - def train_async(self, input_ids, labels, attention_mask, *args, **kwargs): + def train_async(self, input_ids, labels, attention_mask, position_ids, + *args, **kwargs): if isinstance(input_ids, torch.Tensor): micro_batch_size = input_ids.shape[0] // self.dp_size + ( input_ids.shape[0] % self.dp_size > 0 ) # round up division, i.e., math.ceil(a / b) micro_batches = partition_by_micro_batch_size( - input_ids, micro_batch_size, attention_mask, labels) + input_ids, micro_batch_size, attention_mask, position_ids, + labels) assert len(micro_batches) == self.dp_size return [ self.ray_actors[index].train.remote( input_ids=micro_batch['input_ids'], attention_mask=micro_batch['attention_mask'], + position_ids=micro_batch['position_ids'], labels=micro_batch['labels'], *args, **kwargs, @@ -762,39 +726,47 @@ def train_async(self, input_ids, labels, attention_mask, *args, **kwargs): assert isinstance(input_ids[0], torch.Tensor) micro_batch_size = [i for i in range(len(input_ids))] for index, input_id in enumerate(input_ids): - micro_batch_size[ - index] = input_id[index].shape[0] // self.dp_size + ( - input_id[index].shape[0] % self.dp_size > 0 - ) # round up division, i.e., math.ceil(a / b) + micro_batch_size[index] = input_id.shape[0] // self.dp_size + ( + input_id.shape[0] % self.dp_size > 0 + ) # round up division, i.e., math.ceil(a / b) micro_batches = partition_list_by_micro_batch_size( - input_ids, self.dp_size, attention_mask, labels) + input_ids=input_ids, + micro_batch_size=micro_batch_size, + labels=labels, + attention_mask=attention_mask, + position_ids=position_ids, + ) + assert len(micro_batches) == self.dp_size object_refs = [] for index, micro_batch in enumerate(micro_batches): input_ids_mb = [] attention_mask_mb = [] + position_ids_mb = [] labels_mb = [] - loss_weights_mb = [] - assert len(micro_batch) == self.dp_size for i in range(len(micro_batch)): input_ids_mb.append(micro_batch[i]['input_ids']) 
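The rewritten train() above backwards every micro-batch loss and calls parameter_update once per invocation, i.e. standard gradient accumulation. A generic sketch of that pattern with hypothetical model / optimizer / loss_fn objects (not the runner's actual API):

import torch

def accumulate_and_step(model, optimizer, micro_batches, loss_fn):
    losses = []
    for mb in micro_batches:
        logits = model(mb['input_ids'])
        loss = loss_fn(logits, mb['labels'])
        loss.backward()                   # gradients add up across micro-batches
        losses.append(loss.detach())
    optimizer.step()                      # one parameter update per accumulation window
    optimizer.zero_grad()
    return torch.stack(losses).mean()     # mirrors sum(loss_entry) / len(loss_entry)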
attention_mask_mb.append(micro_batch[i]['attention_mask']) + position_ids_mb.append(micro_batch[i]['position_ids']) labels_mb.append(micro_batch[i]['labels']) - loss_weights_mb.append(micro_batch[i]['loss_weights']) - - object_ref = self.ray_actors[index].train.remote( - inputs=input_ids_mb, - attention_mask=attention_mask_mb, - labels=labels_mb, - loss_weights=loss_weights_mb, - *args, - **kwargs, - ) - object_refs.append(object_ref) - return object_ref + object_ref = self.ray_actors[index].train.remote( + input_ids=input_ids_mb, + attention_mask=attention_mask_mb, + position_ids=position_ids_mb, + labels=labels_mb, + *args, + **kwargs, + ) + object_refs.append(object_ref) + return object_refs def train_get(self, object_refs, timeout=None): losses = ray.get(object_refs, timeout=timeout) - return sum(losses) / len(losses) + if isinstance(losses[0], list): + p_loss = [sub_loss[0] for sub_loss in losses] + pt_loss = [sub_loss[1] for sub_loss in losses] + return [sum(p_loss) / len(p_loss), sum(pt_loss) / len(pt_loss)] + else: + return sum(losses) / len(losses) def train(self, *args, **kwargs): object_refs = self.train_async(*args, **kwargs) diff --git a/xtuner/rlhf/model_server/base_model_server.py b/xtuner/rlhf/model_server/base_model_server.py index ffb2426bd..884482f4e 100644 --- a/xtuner/rlhf/model_server/base_model_server.py +++ b/xtuner/rlhf/model_server/base_model_server.py @@ -107,9 +107,10 @@ def train_async(self, input_ids, labels=None, attention_mask=None, + position_ids=None, *args, **train_kwargs): - return self.trainer.train_async(input_ids, labels, attention_mask, + return self.trainer.train_async(input_ids, labels, attention_mask, position_ids, *args, **train_kwargs) def train_get(self, object_refs, timeout: Optional[float] = None): @@ -119,9 +120,10 @@ def train(self, input_ids, labels=None, attention_mask=None, + position_ids=None, *args, **train_kwargs): - object_refs = self.train_async(input_ids, labels, attention_mask, + object_refs = self.train_async(input_ids, labels, attention_mask, position_ids, *args, **train_kwargs) loss = self.train_get(object_refs) self.log_cuda_mem_stats(remark='[train] ') diff --git a/xtuner/rlhf/repeaters/base.py b/xtuner/rlhf/repeaters/base.py index 0e68600b4..156fe3e02 100644 --- a/xtuner/rlhf/repeaters/base.py +++ b/xtuner/rlhf/repeaters/base.py @@ -1,66 +1,11 @@ import time -import numpy as np import torch from loguru import logger from ..model_server.base_model_server import BaseModelServer from ..policy_output import PolicyOutput - - -def find_mask_begin(padded_datas, mask_id=0): - """finding the mask id begin index and it's length.""" - begin_indexs = [] - lengths = [] - - for padded_data in padded_datas: - is_flag = 0 - for index, data in enumerate(padded_data): - if data != mask_id: - is_flag = 1 - begin_indexs.append(index) - length = (np.array(padded_data) != mask_id).sum() - lengths.append(length) - break - assert is_flag - return begin_indexs, lengths - - -class RunningStates: - # adopt from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/running_mean_std.py # noqa: E501 - def __init__(self, epsilon: float = 1e-4): - self.mean = torch.tensor(0, dtype=torch.float32) - self.var = torch.tensor(0, dtype=torch.float32) - self.count = epsilon - - def update(self, x: torch.Tensor): - x_var, x_mean = torch.var_mean(x.cpu(), unbiased=False) - x_count = x.shape[0] - self.update_from_moments(x_mean, x_var, x_count) - - def update_from_other(self, other: 'RunningStates'): - self.update_from_moments(other.mean, 
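train_get above reduces the per-rank results returned by ray.get: when every rank reports a [ppo_loss, pretrain_loss] pair the two entries are averaged separately, otherwise a single scalar is averaged. A standalone sketch of that reduction:

def reduce_rank_losses(losses):
    """losses: one scalar, or one [ppo_loss, pretrain_loss] pair, per rank."""
    if isinstance(losses[0], list):
        ppo = [pair[0] for pair in losses]
        pretrain = [pair[1] for pair in losses]
        return [sum(ppo) / len(ppo), sum(pretrain) / len(pretrain)]
    return sum(losses) / len(losses)

# reduce_rank_losses([[1.0, 2.0], [3.0, 4.0]])  ->  [2.0, 3.0]
# reduce_rank_losses([1.0, 3.0])                ->  2.0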
other.var, other.count) - - def update_from_moments(self, mean: torch.Tensor, var: torch.Tensor, - count: int): - delta = mean - self.mean - tot_count = self.count + count - m_a = self.var * self.count - m_b = var * count - m_2 = m_a + m_b + delta**2 * self.count * count / (self.count + count) - new_var = m_2 / (self.count + count) - - self.mean += delta * count / tot_count - self.var = new_var - self.count = tot_count - - def state_dict(self): - return dict(mean=self.mean, var=self.var, count=self.count) - - def load_state_dict(self, states): - self.mean = states['mean'] - self.var = states['var'] - self.count = states['count'] +from .running_mean_std import RunningStates class BaseRepeater: @@ -68,31 +13,30 @@ class BaseRepeater: def __init__( self, sft_model, - reward_scale: bool = False, - fine_grained_rm: bool = False, - value_ema: bool = False, actor_micro_bs: int = 8, ref_micro_bs: int = 8, critic_micro_bs: int = 32, kl_coeff=0.02, gamma=1.0, gae_lambda=0.95, - answer_end_id=92542, norm_adv=False, + clip_reward_min: int = -5, + clip_reward_max: int = 5, norm_rewards=True, + reward_scale: bool = False, + fine_grained_rm: bool = False, **_ignored, ): self.sft_model = sft_model self.actor_micro_bs = actor_micro_bs self.ref_micro_bs = ref_micro_bs self.critic_micro_bs = critic_micro_bs - self.reward_scale = reward_scale - self.fine_grained_rm = fine_grained_rm - self.value_ema = value_ema self.kl_coeff = kl_coeff self.gamma = gamma self.gae_lambda = gae_lambda - self.answer_end_id = answer_end_id + # rewards + self.clip_reward_min = clip_reward_min + self.clip_reward_max = clip_reward_max self.norm_rewards = norm_rewards if self.norm_rewards: self.running_states = RunningStates(epsilon=0) @@ -158,15 +102,16 @@ def _get_kl_rewards(self, if env.async_reward: rewards = env.get_reward_collect(trajectories['reward_output_ref']) trajectories['reward_output_ref'] = None - clipped_rewards = torch.clamp( - rewards, min=env.clip_reward_min, max=env.clip_reward_max) trajectories['rewards'] = rewards - trajectories['clipped_rewards'] = clipped_rewards # Experimental - rewards = trajectories.clipped_rewards + + clipped_rewards = torch.clamp( + rewards, min=self.clip_reward_min, max=self.clip_reward_max) + trajectories['clipped_rewards'] = clipped_rewards + if self.norm_rewards: - self.running_states.update(rewards) - norm_reward_score = (rewards - self.running_states.mean) / ( + self.running_states.update(clipped_rewards) + norm_reward_score = (clipped_rewards - self.running_states.mean) / ( self.running_states.var.sqrt() + 1e-8) action_mask = trajectories.action_mask num_actions = action_mask.size(1) @@ -232,60 +177,6 @@ def _get_values_collect(self, value_output_ref, ) return raw_values - def _get_advantages_and_returns(self, trajectories): - output_ids = trajectories.output_ids - answer_mask = trajectories.answer_mask - values_with_last_value = trajectories.values_with_last_value - kl_rewards = trajectories.kl_rewards - - begins_index, answers_length = find_mask_begin(answer_mask, 0) - count = 0 - advantages_padded, returns_padded = torch.zeros_like( - kl_rewards, dtype=values_with_last_value.dtype), torch.zeros_like( - kl_rewards, dtype=values_with_last_value.dtype) - for begin_index, ans_len, value_with_last_value, reward, output_id in zip( # noqa: E501 - begins_index, answers_length, values_with_last_value, - kl_rewards, output_ids): - # shape :ans_len + 1 - value_with_last_value = value_with_last_value[begin_index - - 1:begin_index + - ans_len] - # shape :ans_len - reward = 
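The reworked _get_kl_rewards above first clamps reward-model scores to [clip_reward_min, clip_reward_max] and then z-normalizes the clipped values with the running statistics kept in RunningStates. A minimal sketch of those two steps, assuming a RunningStates-like object exposing update(), mean and var:

import torch

def clip_and_normalize_rewards(rewards, running_states,
                               clip_min=-5.0, clip_max=5.0, eps=1e-8):
    clipped = torch.clamp(rewards, min=clip_min, max=clip_max)
    running_states.update(clipped)        # refresh running mean/var first
    return (clipped - running_states.mean) / (running_states.var.sqrt() + eps)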
reward[begin_index:begin_index + ans_len] - last_gae_lam = torch.zeros((1), dtype=values_with_last_value.dtype) - # shape :ans_len - advantages = torch.zeros_like( - reward, dtype=values_with_last_value.dtype) - step_nums = advantages.shape[-1] - # shape:ans_len + 1 - dones = self._build_dones(output_id[begin_index:begin_index + - ans_len]) - for step in reversed(range(step_nums)): - next_non_terminal = 1 - dones[step + 1] - next_values = value_with_last_value[step + 1] - # delta and last_gae_lam using value and reward - delta = reward[ - step] + self.gamma * next_values * next_non_terminal - value_with_last_value[ # noqa: E501 - step] - last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam # noqa: E501 - advantages[step] = last_gae_lam[0] - returns = advantages + value_with_last_value[:-1] - advantages_padded[count, - begin_index:begin_index + ans_len] = advantages - returns_padded[count, begin_index:begin_index + ans_len] = returns - count += 1 - return advantages_padded, returns_padded - - # ans_len + 1: dones - def _build_dones(self, answer_ids): - dones = torch.tensor( - (answer_ids == self.answer_end_id).numpy().astype(np.float32)) - # (1, )the first one is not done, so obs_0_dones=0 - obs_0_dones = torch.zeros((1), dtype=torch.float32) - # (ans_len + 1), - dones = torch.concat((obs_0_dones, dones), axis=0) - return dones - def get_advantages_and_returns( self, values: torch.Tensor, @@ -293,6 +184,24 @@ def get_advantages_and_returns( action_mask: torch.Tensor, ): # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134 # noqa: E501 + """Function that computes advantages and returns from rewards and values. + Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347 + Note that rewards may include a KL divergence loss term. + + Advantages looks like this: + Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... + + Returns looks like this: + Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... + + Args: + values: Tensor of shape (batch_size, response_size) + rewards: Tensor of shape (batch_size, response_size) + response_length: Length of the response sequence + use_whitening: Whether to use whitening (ie. normalize advantages) or not + """ lastgaelam = 0 advantages_reversed = [] response_length = rewards.size(1) @@ -303,6 +212,8 @@ def get_advantages_and_returns( for t in reversed(range(response_length)): nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0 + # Since old_rewards and old_values are masked with action_mask, i.e. 
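A self-contained sketch of the GAE recursion described by the docstring above (standard PPO advantage estimation; values and rewards are assumed to be already masked at padding, and the default gamma / gae_lambda mirror the BaseRepeater defaults):

import torch

def gae_advantages_and_returns(values, rewards, gamma=1.0, gae_lambda=0.95):
    """values, rewards: (batch, response_len) tensors, already masked at pads."""
    response_length = rewards.size(1)
    lastgaelam = torch.zeros(rewards.size(0))
    advantages_reversed = []
    for t in reversed(range(response_length)):
        nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgaelam = delta + gamma * gae_lambda * lastgaelam
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values         # A_t + V_t is the return target
    return advantages, returns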
they have + # 0's at pad tokens, delta will be 0 if current t is at a pad token, so will lastgaelam delta = rewards[:, t] + self.gamma * nextvalues - values[:, t] lastgaelam = delta + self.gamma * self.gae_lambda * lastgaelam advantages_reversed.append(lastgaelam) diff --git a/xtuner/rlhf/repeaters/running_mean_std.py b/xtuner/rlhf/repeaters/running_mean_std.py new file mode 100644 index 000000000..e8b3e2763 --- /dev/null +++ b/xtuner/rlhf/repeaters/running_mean_std.py @@ -0,0 +1,38 @@ +import torch + + +class RunningStates: + # adopt from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/running_mean_std.py # noqa: E501 + def __init__(self, epsilon: float = 1e-4): + self.mean = torch.tensor(0, dtype=torch.float32) + self.var = torch.tensor(0, dtype=torch.float32) + self.count = epsilon + + def update(self, x: torch.Tensor): + x_var, x_mean = torch.var_mean(x.cpu(), unbiased=False) + x_count = x.shape[0] + self.update_from_moments(x_mean, x_var, x_count) + + def update_from_other(self, other: 'RunningStates'): + self.update_from_moments(other.mean, other.var, other.count) + + def update_from_moments(self, mean: torch.Tensor, var: torch.Tensor, + count: int): + delta = mean - self.mean + tot_count = self.count + count + m_a = self.var * self.count + m_b = var * count + m_2 = m_a + m_b + delta**2 * self.count * count / (self.count + count) + new_var = m_2 / (self.count + count) + + self.mean += delta * count / tot_count + self.var = new_var + self.count = tot_count + + def state_dict(self): + return dict(mean=self.mean, var=self.var, count=self.count) + + def load_state_dict(self, states): + self.mean = states['mean'] + self.var = states['var'] + self.count = states['count'] diff --git a/xtuner/rlhf/trainer/ppo.py b/xtuner/rlhf/trainer/ppo.py index a4a81aad6..ccfbeb61f 100644 --- a/xtuner/rlhf/trainer/ppo.py +++ b/xtuner/rlhf/trainer/ppo.py @@ -5,6 +5,7 @@ from ..loss.actor_loss import ActorLoss from ..loss.critic_loss import CriticLoss +from ..loss.pretrain_loss import PretrainLoss from ..model_server.base_model_server import BaseModelServer from ..timer import Timer @@ -13,92 +14,125 @@ class PPOTrainer: def __init__( self, - policy_model, - value_model, actor_micro_bs=2, critic_micro_bs=2, policy_learn_time=1, value_learn_time=1, - ppo_minibatch=512, - value_minibatch=512, - pt_minibatch=None, - train_minibatch=None, - pt_criterion=None, + policy_minibatch=None, + value_minibatch=None, + ppo_loss_weight=1.0, + pretrain_loss_weight=0.5, + pretrain_criterion=PretrainLoss(label_smoothing=0), policy_criterion=ActorLoss(cliprange=0.2, loss_type='per_seq'), value_criterion=CriticLoss(cliprange_value=0.5, loss_type='per_seq'), **kwargs, ): - self.ppo_minibatch = ppo_minibatch - self.value_minibatch = value_minibatch self.actor_micro_bs = actor_micro_bs self.critic_micro_bs = critic_micro_bs # policy - self.policy_model = policy_model self.policy_learn_time = policy_learn_time - self.pt_minibatch = pt_minibatch - self.train_minibatch = train_minibatch - self.policy_minibatch = ppo_minibatch + self.policy_minibatch = policy_minibatch # value - self.value_model = value_model self.value_learn_time = value_learn_time self.value_minibatch = value_minibatch - self.pt_criterion = pt_criterion + self.ppo_loss_weight = ppo_loss_weight + self.pretrain_loss_weight = pretrain_loss_weight + self.pretrain_criterion = pretrain_criterion self.policy_criterion = policy_criterion self.value_criterion = value_criterion def policy_learn(self, trajectories, policy_model: 
BaseModelServer): + if self.policy_minibatch is None: + self.policy_minibatch = len(trajectories.output_ids) policy_updates = len(trajectories.output_ids) // self.policy_minibatch - policy_loss = [] - pt_loss = [] + ppo_loss = [] + pretrain_loss = [] for _ in range(self.policy_learn_time): for i in range(policy_updates): logger.info( '[Policy Train] start policy trains {}/{} | {}'.format( i + 1, policy_updates, _ + 1)) + # prompt train data begin = i * self.policy_minibatch end = begin + self.policy_minibatch - policy_batch_inputs = { - 'input_ids': trajectories.output_ids[begin:end, :], - 'policy_logprobs': - trajectories.policy_logprobs[begin:end, :], - 'advs': trajectories.advantages[begin:end, :], - 'action_mask': trajectories.action_mask[begin:end, :], - 'attention_mask': trajectories.attention_mask[begin:end, :] - } + + train_input_ids = [ + trajectories.output_ids[begin:end, :], + ] + train_attention_mask = [ + trajectories.attention_mask[begin:end, :], + ] + train_criterion = [ + self.policy_criterion, + ] + loss_weights = [ + self.ppo_loss_weight, + ] + micro_batch_size = [ + self.actor_micro_bs, + ] assert len( - policy_batch_inputs['input_ids'] + trajectories.output_ids[begin:end, :] ) == self.policy_minibatch, '[Policy learn] make sure len(policy_batch_inputs) == self.policy_minibatch' # noqa: E501 loss_factor = 1.0 - labels = dict( - input_ids=policy_batch_inputs['input_ids'], - old_logprobs=policy_batch_inputs['policy_logprobs'], - advantages=policy_batch_inputs['advs'], - mask=policy_batch_inputs['action_mask'], - loss_factor=torch.tensor(loss_factor), - ) + train_lables = [ + dict( + input_ids=trajectories.output_ids[begin:end, :], + old_logprobs=trajectories.policy_logprobs[ + begin:end, :], + advantages=trajectories.advantages[begin:end, :], + mask=trajectories.action_mask[begin:end, :], + loss_factor=torch.tensor(loss_factor), + ), + ] + # pretrain data + if trajectories.pretrain_data is not None: + logger.info( + f'[Policy Train] policy train with pretrain data {trajectories.pretrain_data["input_ids"].shape}' + ) + train_input_ids.append( + trajectories.pretrain_data['input_ids']) + train_lables.append(trajectories.pretrain_data['labels']) + # train_position_ids.append(trajectories.pretrain_data["position_ids"]) + train_attention_mask.append( + trajectories.pretrain_data['attention_mask']) + train_criterion.append(self.pretrain_criterion) + loss_weights.append(self.pretrain_loss_weight) + micro_batch_size.append(self.actor_micro_bs) + s_t = time.time() p_loss = policy_model.train( - input_ids=policy_batch_inputs['input_ids'], - labels=labels, - attention_mask=policy_batch_inputs['attention_mask'], - criterion=self.policy_criterion, - micro_batch_size=self.actor_micro_bs) - - logger.info( - f'[actor train] duration: {round(time.time() - s_t, 2)} s, {self.policy_minibatch} batch, Policy loss: {p_loss.item()}' # noqa: E501 - ) - policy_loss.append(p_loss.item()) + input_ids=train_input_ids, + labels=train_lables, + attention_mask=train_attention_mask, + # position_ids=train_position_ids, + criterion=train_criterion, + loss_weights=loss_weights, + micro_batch_size=micro_batch_size) + if isinstance(p_loss, list): + ppo_loss.append(p_loss[0].item()) + pretrain_loss.append(p_loss[1].item()) + logger.info( + f'[Policy Train] duration: {round(time.time() - s_t, 2)} s, prompt data: {train_input_ids[0].shape}, ppo loss: {p_loss[0].item()}; pretrain data: {train_input_ids[1].shape}, pretrain loss: {p_loss[1].item()}' + ) + else: + ppo_loss.append(p_loss.item()) + logger.info( + 
f'[Policy Train] duration: {round(time.time() - s_t, 2)} s, prompt data: {train_input_ids[0].shape}, ppo loss: {p_loss.item()}' + ) with Timer('policy_model.sync_model'): policy_model.sync_model() - return policy_loss, pt_loss + return ppo_loss, pretrain_loss def value_learn_async(self, trajectories, value_model: BaseModelServer): + if self.value_minibatch is None: + self.value_minibatch = len(trajectories.output_ids) value_updates = len(trajectories.output_ids) // self.value_minibatch value_loss = [] assert value_updates == 1 and self.policy_learn_time == 1, f'value_updates={value_updates} * self.policy_learn_time={self.policy_learn_time} > 1' # noqa: E501 @@ -125,6 +159,8 @@ def value_learn_get(self, value_loss_ref, value_model: BaseModelServer): ] def value_learn(self, trajectories, value_model: BaseModelServer): + if self.value_minibatch is None: + self.value_minibatch = len(trajectories.output_ids) value_updates = len(trajectories.output_ids) // self.value_minibatch value_loss = [] From 72ede3b53e747dc20d56e8eeb926b9f55bc8ba5a Mon Sep 17 00:00:00 2001 From: lishuaibin Date: Tue, 18 Jun 2024 21:53:35 +0800 Subject: [PATCH 03/37] dataloader, resolve comments --- examples/rlhf/demo_datas/pretrain_data.json | 2 + examples/rlhf/demo_datas/prompt_data.json | 3 + examples/rlhf/four_model_8gpu.py | 55 +-- examples/rlhf/four_model_vllm_8gpu.py | 55 +-- requirements/rlhf.txt | 1 - setup.py | 1 + xtuner/rlhf/config/config_utils.py | 10 +- xtuner/rlhf/coordinator.py | 3 - xtuner/rlhf/dataset/__init__.py | 3 + xtuner/rlhf/dataset/base.py | 331 ++++++++++-------- xtuner/rlhf/dataset/message_iter.py | 229 ++++++++++++ .../open_datasets/Anthropic_hh_rlhf.py | 88 ----- xtuner/rlhf/dataset/open_datasets/__init__.py | 0 xtuner/rlhf/dataset/txt_loader.py | 267 -------------- xtuner/rlhf/dataset/utils/__init__.py | 7 + xtuner/rlhf/dataset/utils/collate_fns.py | 28 ++ xtuner/rlhf/dataset/utils/from_hf.py | 75 ++++ xtuner/rlhf/dataset/utils/map_fns.py | 24 ++ xtuner/rlhf/envs/__init__.py | 3 + xtuner/rlhf/envs/prompt_utils.py | 4 +- xtuner/rlhf/envs/txt_env.py | 208 ++++++----- xtuner/rlhf/loss/__init__.py | 4 + xtuner/rlhf/loss/actor_loss.py | 81 ++--- xtuner/rlhf/loss/critic_loss.py | 47 +-- xtuner/rlhf/loss/pretrain_loss.py | 36 -- xtuner/rlhf/main.py | 98 +++--- xtuner/rlhf/model_backend/generate_utils.py | 4 +- xtuner/rlhf/model_backend/hf_model_runner.py | 8 +- .../rlhf/model_backend/vllm_model_runner.py | 4 +- .../rlhf/model_server/actor_model_server.py | 4 +- xtuner/rlhf/model_server/base_model_server.py | 8 +- xtuner/rlhf/policy_output.py | 15 +- xtuner/rlhf/repeaters/__init__.py | 3 + xtuner/rlhf/repeaters/base.py | 148 ++++---- xtuner/rlhf/tokenizer/tokenizer_utils.py | 5 +- xtuner/rlhf/trainer/__init__.py | 3 + xtuner/rlhf/trainer/ppo.py | 189 +++++----- 37 files changed, 1010 insertions(+), 1044 deletions(-) create mode 100644 examples/rlhf/demo_datas/pretrain_data.json create mode 100644 examples/rlhf/demo_datas/prompt_data.json create mode 100644 xtuner/rlhf/dataset/message_iter.py delete mode 100644 xtuner/rlhf/dataset/open_datasets/Anthropic_hh_rlhf.py delete mode 100644 xtuner/rlhf/dataset/open_datasets/__init__.py delete mode 100644 xtuner/rlhf/dataset/txt_loader.py create mode 100644 xtuner/rlhf/dataset/utils/__init__.py create mode 100644 xtuner/rlhf/dataset/utils/collate_fns.py create mode 100644 xtuner/rlhf/dataset/utils/from_hf.py create mode 100644 xtuner/rlhf/dataset/utils/map_fns.py delete mode 100644 xtuner/rlhf/loss/pretrain_loss.py diff --git 
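In the reworked policy_learn above, each update can mix two sources, PPO prompt trajectories and pretrain batches, trained with their own criteria and loss weights (ppo_loss_weight and pretrain_loss_weight). The sketch below only illustrates the weighted combination conceptually; in the patch the weights are applied per source inside the model runner rather than summed in the trainer:

ppo_loss_weight = 1.0        # weights as configured in train_config above
pretrain_loss_weight = 0.5

def combined_policy_objective(ppo_loss, pretrain_loss=None):
    total = ppo_loss_weight * ppo_loss
    if pretrain_loss is not None:      # only when a pretrain batch was sampled
        total += pretrain_loss_weight * pretrain_loss
    return total

# combined_policy_objective(0.8, 2.4) -> 0.8 * 1.0 + 2.4 * 0.5 = 2.0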
a/examples/rlhf/demo_datas/pretrain_data.json b/examples/rlhf/demo_datas/pretrain_data.json new file mode 100644 index 000000000..ccc5e0628 --- /dev/null +++ b/examples/rlhf/demo_datas/pretrain_data.json @@ -0,0 +1,2 @@ +[{"role": "user", "content": ""}, {"role": "assistant", "content": "I am an artificial intelligence (AI) assistant named InternLM. I was created by the Shanghai AI Laboratory and my purpose is to assist users with various tasks through natural language processing technology."}] +[{"role": "user", "content": "Give three tips for staying healthy."}, {"role": "assistant", "content": "1.Eat a balanced diet. 2. Exercise regularly. 3. Get enough sleep."}, {"role": "user", "content": "How to study English?"}, {"role": "assistant","content": "1. Set clear goals. 2. Create a study plan. 3. Build vocabulary. 4. Practice speaking."}] diff --git a/examples/rlhf/demo_datas/prompt_data.json b/examples/rlhf/demo_datas/prompt_data.json new file mode 100644 index 000000000..6bee0447b --- /dev/null +++ b/examples/rlhf/demo_datas/prompt_data.json @@ -0,0 +1,3 @@ +[{"role": "user", "content": "How to study English?"}] +[{"role": "system", "content": "You are a helpful AI assistant."}, {"role": "user", "content": "Give three tips for staying healthy."}] +[{"role": "user", "content": "Give three tips for staying healthy."}, {"role": "assistant", "content": "1.Eat a balanced diet. 2. Exercise regularly. 3. Get enough sleep."}, {"role": "user", "content": "How to study English?"}] diff --git a/examples/rlhf/four_model_8gpu.py b/examples/rlhf/four_model_8gpu.py index 9ccfc3c4c..b4ebb31ce 100644 --- a/examples/rlhf/four_model_8gpu.py +++ b/examples/rlhf/four_model_8gpu.py @@ -1,5 +1,6 @@ -import torch - +####################################################################### +# Settings # +####################################################################### MAX_PROMPT_LEN = 1024 MAX_ANSWER_LEN = 1024 MAX_PRETRAIN_LEN = 8192 @@ -17,7 +18,7 @@ CRITIC_DP_SIZE = 2 ACTOR_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE ) // ACTOR_DP_SIZE // TRAIN_MICRO_BATCH_SIZE -CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE +CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE # noqa: E501 MODEL_DTYPE = 'auto' @@ -63,8 +64,9 @@ critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, ppo_loss_weight=1.0, pretrain_loss_weight=0.5, - pretrain_step=20, + critic_warmup_step=20, save_interval=40, + max_train_step=400, ) model_configs = dict( @@ -81,7 +83,6 @@ lr=1e-6, total_steps=1e9, lr_decay_rate=1, - loss_type='per_seq', ), parallel=dict( data=dict(size=ACTOR_DP_SIZE, mode='deepspeed'), @@ -130,7 +131,6 @@ lr=5e-6, total_steps=1e9, lr_decay_rate=1, - loss_type='per_seq', ), parallel=dict( data=dict(size=CRITIC_DP_SIZE, mode='deepspeed'), @@ -196,24 +196,27 @@ ), ) -dataset_config = { - 'prompt_samples_each_epoch': - PROMPT_BATCH_SIZE, - 'max_prompt_len': - MAX_PROMPT_LEN, - 'pretrain_samples_each_epoch': - PRETRAIN_BATCH_SIZE, - 'max_pretrain_len': - MAX_PRETRAIN_LEN, - 'random_seed': - 1024, - "sample_strategy": "in_data", - "ratio_within_datasets": False, - 'prompt_datasets': [ - 'Anthropic/hh-rlhf/helpful-base::1.0', - 'Anthropic/hh-rlhf/harmless-base::0.5', - ], - 'pretrain_datasets': [ - 'Anthropic/hh-rlhf/helpful-base::1.0', +prompt_dataset_config = dict( + samples_each_epoch=PROMPT_BATCH_SIZE, + max_len=MAX_PROMPT_LEN, + message_type='prompt', + random_seed=1024, + sample_strategy='in_batch', # 'in_data' + message_datasets=[ + 
'./examples/rlhf/demo_datas/prompt_data.json::0.01[SYS_PROMPT]:summarization', # noqa: E501 + '[HF]Anthropic/hh-rlhf/helpful-base::0.5[RM_PROMPT]:default', + '[HF]HuggingFaceH4/summarize_from_feedback::0.5', + ]) + +pretrain_dataset_config = dict( + samples_each_epoch=PRETRAIN_BATCH_SIZE, + max_len=MAX_PRETRAIN_LEN, + message_type='pretrain', + random_seed=1024, + sample_strategy='in_batch', # 'in_data' + message_datasets=[ + './examples/rlhf/demo_datas/pretrain_data.json::0.01', + '[HF]Anthropic/hh-rlhf/helpful-base::0.5', + '[HF]HuggingFaceH4/summarize_from_feedback::0.5', ], -} +) diff --git a/examples/rlhf/four_model_vllm_8gpu.py b/examples/rlhf/four_model_vllm_8gpu.py index 9d8ea67fe..5da931ff7 100644 --- a/examples/rlhf/four_model_vllm_8gpu.py +++ b/examples/rlhf/four_model_vllm_8gpu.py @@ -1,5 +1,6 @@ -import torch - +####################################################################### +# Settings # +####################################################################### MAX_PROMPT_LEN = 1024 MAX_ANSWER_LEN = 1024 MAX_PRETRAIN_LEN = 8192 @@ -17,7 +18,7 @@ CRITIC_DP_SIZE = 2 ACTOR_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE ) // ACTOR_DP_SIZE // TRAIN_MICRO_BATCH_SIZE -CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE +CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE # noqa: E501 MODEL_DTYPE = 'auto' @@ -63,8 +64,9 @@ critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, ppo_loss_weight=1.0, pretrain_loss_weight=0.5, - pretrain_step=20, + critic_warmup_step=20, save_interval=40, + max_train_step=400, ) model_configs = dict( @@ -81,7 +83,6 @@ lr=1e-6, total_steps=1e9, lr_decay_rate=1, - loss_type='per_seq', ), parallel=dict( data=dict(size=ACTOR_DP_SIZE, mode='deepspeed'), @@ -139,7 +140,6 @@ lr=5e-6, total_steps=1e9, lr_decay_rate=1, - loss_type='per_seq', ), parallel=dict( data=dict(size=CRITIC_DP_SIZE, mode='deepspeed'), @@ -205,24 +205,27 @@ ), ) -dataset_config = { - 'prompt_samples_each_epoch': - PROMPT_BATCH_SIZE, - 'max_prompt_len': - MAX_PROMPT_LEN, - 'pretrain_samples_each_epoch': - PRETRAIN_BATCH_SIZE, - 'max_pretrain_len': - MAX_PRETRAIN_LEN, - 'random_seed': - 1024, - # "sample_strategy": "in_data", - # "ratio_within_datasets": False, - 'prompt_datasets': [ - 'Anthropic/hh-rlhf/helpful-base::1.0', - 'Anthropic/hh-rlhf/harmless-base::0.5', - ], - 'pretrain_datasets': [ - 'Anthropic/hh-rlhf/helpful-base::1.0', +prompt_dataset_config = dict( + samples_each_epoch=PROMPT_BATCH_SIZE, + max_len=MAX_PROMPT_LEN, + message_type='prompt', + random_seed=1024, + sample_strategy='in_batch', # 'in_data' + message_datasets=[ + './examples/rlhf/demo_datas/prompt_data.json::0.01[SYS_PROMPT]:summarization', # noqa: E501 + '[HF]Anthropic/hh-rlhf/helpful-base::0.5[RM_PROMPT]:default', + '[HF]HuggingFaceH4/summarize_from_feedback::0.5', + ]) + +pretrain_dataset_config = dict( + samples_each_epoch=PRETRAIN_BATCH_SIZE, + max_len=MAX_PRETRAIN_LEN, + message_type='pretrain', + random_seed=1024, + sample_strategy='in_batch', # 'in_data' + message_datasets=[ + './examples/rlhf/demo_datas/pretrain_data.json::0.01', + '[HF]Anthropic/hh-rlhf/helpful-base::0.5', + '[HF]HuggingFaceH4/summarize_from_feedback::0.5', ], -} +) diff --git a/requirements/rlhf.txt b/requirements/rlhf.txt index 53f8bbc63..22cfbcaa3 100644 --- a/requirements/rlhf.txt +++ b/requirements/rlhf.txt @@ -1,3 +1,2 @@ --r requirements/deepspeed.txt loguru ray[default,train]==2.9.1 diff --git a/setup.py b/setup.py index 8a3c0a5eb..3a95da067 100644 --- 
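Each message_datasets entry above packs a path, a sampling probability and optional prompt tags into one string, 'path::prob[SYS_PROMPT]:name[RM_PROMPT]:name', with a leading [HF] marking a Hugging Face dataset id. A small sketch of how such a spec decomposes, mirroring the parsing used by the dataset classes later in this patch:

def parse_dataset_spec(spec):
    path, extra = spec.split('::')
    prob = float(extra.split('[')[0])
    sys_prompt = (extra.split('[SYS_PROMPT]:')[-1].split('[')[0]
                  if '[SYS_PROMPT]:' in extra else 'default')
    rm_prompt = (extra.split('[RM_PROMPT]:')[-1].split('[')[0]
                 if '[RM_PROMPT]:' in extra else 'default')
    return dict(path=path, prob=prob, sys_prompt=sys_prompt, rm_prompt=rm_prompt)

# parse_dataset_spec('[HF]Anthropic/hh-rlhf/helpful-base::0.5[RM_PROMPT]:default')
# -> {'path': '[HF]Anthropic/hh-rlhf/helpful-base', 'prob': 0.5,
#     'sys_prompt': 'default', 'rm_prompt': 'default'}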
a/setup.py +++ b/setup.py @@ -133,6 +133,7 @@ def gen_packages_items(): parse_requirements('requirements/runtime.txt') + parse_requirements('requirements/modelscope.txt'), 'rlhf': + parse_requirements('requirements/deepspeed.txt') + parse_requirements('requirements/rlhf.txt'), }, zip_safe=False, diff --git a/xtuner/rlhf/config/config_utils.py b/xtuner/rlhf/config/config_utils.py index 6a74d32e2..ae1ebb3f0 100644 --- a/xtuner/rlhf/config/config_utils.py +++ b/xtuner/rlhf/config/config_utils.py @@ -3,14 +3,8 @@ def get_gpu_requirement(trainer_config: dict) -> int: # Calculates the number of GPUs required for a given trainer configuration. - num_gpus = 1 - if 'parallel' in trainer_config: - parallel = trainer_config['parallel'] - data = parallel.get('data', {'size': 1}) - tensor = parallel.get('tensor', {'size': 1}) - pipeline = parallel.get('pipeline', {'size': 1}) - num_gpus = data['size'] * tensor['size'] * pipeline['size'] - return num_gpus + return get_dp_size(trainer_config) * get_tp_size( + trainer_config) * get_pp_size(trainer_config) def get_resource_requirement(model_configs: dict) -> dict: diff --git a/xtuner/rlhf/coordinator.py b/xtuner/rlhf/coordinator.py index 2d6fbbc4c..b64e15ff3 100644 --- a/xtuner/rlhf/coordinator.py +++ b/xtuner/rlhf/coordinator.py @@ -81,9 +81,6 @@ def _schedule(self): for model_name, model in self.model_dict.items( ): # naive serial initialize model.initialize_async() - logger.info( - f'{model_name} {model.__class__.__name__}.is_initialized: {model.is_initialized}' # noqa: E501 - ) for model_name, model in self.model_dict.items( ): # naive serial initialize model.initialize_get() diff --git a/xtuner/rlhf/dataset/__init__.py b/xtuner/rlhf/dataset/__init__.py index e69de29bb..dea11525b 100644 --- a/xtuner/rlhf/dataset/__init__.py +++ b/xtuner/rlhf/dataset/__init__.py @@ -0,0 +1,3 @@ +from .message_iter import MessageIter + +__all__ = ['MessageIter'] diff --git a/xtuner/rlhf/dataset/base.py b/xtuner/rlhf/dataset/base.py index 9f9a5cb69..a9a4330d6 100644 --- a/xtuner/rlhf/dataset/base.py +++ b/xtuner/rlhf/dataset/base.py @@ -6,6 +6,7 @@ from contextlib import contextmanager import numpy as np +from loguru import logger from torch.utils.data import ConcatDataset, Dataset, IterableDataset, Subset @@ -40,78 +41,150 @@ def __iter__(self): yield self.data[i] -class FileDataset(IterableDataset): +class IterDataset(IterableDataset): """Single json file dataset.""" def __init__(self, - filename, - tokenizer, - sys_meta='default', - rm_meta='default'): + filename=None, + data_list=None, + tokenizer=None, + sys_prompt='default', + rm_prompt='default'): + assert filename is not None or data_list is not None self._filename = filename + self.data_list = data_list self.tokenizer = tokenizer - self.data_list = [] - self.sys_meta = sys_meta - self.rm_meta = rm_meta - with open_file(self._filename) as fin: - for lineno, line in enumerate(fin): - data = json.loads(line) - self.data_list.append(data) - - def __len__(self): - return len(self.data_list) - - def __getitem__(self, index: int): - data = self.data_list[index] - try: - self.tokenizer.apply_chat_template(data, tokenize=True) - return { - 'data': data, - 'sys_meta': self.sys_meta, - 'rm_meta': self.rm_meta - } - except Exception: - print(f'[data tokenize check] skip dirty data: {data}') - return None + self.sys_prompt = sys_prompt + self.rm_prompt = rm_prompt def __iter__(self): - with open_file(self._filename) as fin: - for lineno, line in enumerate(fin): - data = json.loads(line) + if self.data_list is not None: + for 
lineno, data in enumerate(self.data_list): try: self.tokenizer.apply_chat_template(data, tokenize=True) except Exception: - print(f'[data tokenize check] skip dirty data: {data}') + logger.info( + f'[data tokenize check] skip dirty data: {data}') continue if data is None: continue - yield { - 'data': data, - 'sys_meta': self.sys_meta, - 'rm_meta': self.rm_meta - } + yield dict( + data=data, + sys_prompt=self.sys_prompt, + rm_prompt=self.rm_prompt) + else: + with open_file(self._filename) as fin: + for lineno, line in enumerate(fin): + data = json.loads(line) + try: + self.tokenizer.apply_chat_template(data, tokenize=True) + except Exception: + logger.info( + f'[data tokenize check] skip dirty data: {data}') + continue + if data is None: + continue + yield dict( + data=data, + sys_prompt=self.sys_prompt, + rm_prompt=self.rm_prompt) -class OpensourceDataset(IterableDataset): - """Opensource dataset.""" +class MultiSourceInBatchDatset(IterableDataset): + """Multiple source dataset.""" + + def __init__(self, task_groups, tokenizer=None, random_seed=1024): + self._task_group = [] + for _task in task_groups: + file_path, extra_info = _task.split('::')[0], _task.split('::')[1] + prob = float(extra_info.split('[')[0]) + sys_prompt = 'default' + rm_prompt = 'default' + if '[SYS_PROMPT]:' in extra_info: + sys_prompt = extra_info.split('[SYS_PROMPT]:')[-1].split( + '[')[0] + if '[RM_PROMPT]:' in extra_info: + rm_prompt = extra_info.split('[RM_PROMPT]:')[-1].split('[')[0] + if prob > 0: + self._task_group.append( + dict( + prob=prob, + filepath=file_path, + sys_prompt=sys_prompt, + rm_prompt=rm_prompt)) + logger.info( + f'[DataLoader] Load {_task} with prob:{prob}, ' + f'sys_prompt type: {sys_prompt}, reward meta: {rm_prompt}') + else: + logger.warning('[DataLoader] skip file, ' + f'prob of {file_path} is {prob} ...') + assert len(self._task_group) > 0, 'No data to be trained' + + for task in self._task_group: + filepath = task['filepath'] + if '[HF]' in filepath: + from xtuner.rlhf.dataset.utils.from_hf import load_from_hf + + # loading & convert & save opensource datasets + hf_dir = filepath.split('[HF]')[-1] + logger.info(f'Loading {hf_dir} with huggingface format ...') + dataset = load_from_hf(hf_dir, tokenizer=tokenizer) + task['dataset'] = IterDataset( + data_list=dataset['conversation'], + tokenizer=tokenizer, + sys_prompt=task['sys_prompt'], + rm_prompt=task['rm_prompt']) + + else: + task['dataset'] = IterDataset( + filename=filepath, + tokenizer=tokenizer, + sys_prompt=task['sys_prompt'], + rm_prompt=task['rm_prompt']) + + sum_prob = sum([task['prob'] for task in self._task_group]) + for task in self._task_group: + task['prob'] = task['prob'] / sum_prob + + self.random_seed = random_seed + + def __iter__(self): + rng = random.Random(self.random_seed) + probs = [task['prob'] for task in self._task_group] + # Initialize task iterator + for task in self._task_group: + task['iterator'] = iter(task['dataset']) + while True: + task = rng.choices(self._task_group, weights=probs)[0] + try: + yield from task['iterator'] + except StopIteration: + task['iterator'] = iter(task['dataset']) + yield from task['iterator'] + + +class JsonDataset(Dataset): + """Single json file dataset.""" def __init__(self, - filename, - tokenizer, - sys_meta='default', - rm_meta='default'): - self._filename = filename + filename=None, + data_list=None, + tokenizer=None, + sys_prompt='default', + rm_prompt='default'): + assert filename is not None or data_list is not None self.tokenizer = tokenizer - self.sys_meta = sys_meta 
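MultiSourceInBatchDatset above normalizes the per-source probabilities and, on each draw, uses random.choices to pick which source iterator to yield from, restarting a source once it is exhausted. A stripped-down sketch of that weighted sampling (toy lists stand in for the real IterDataset objects, and it yields one sample per draw instead of draining the chosen iterator):

import random

def weighted_source_iter(sources, probs, seed=1024):
    """sources: restartable iterables; probs: matching sampling weights."""
    rng = random.Random(seed)
    total = sum(probs)
    probs = [p / total for p in probs]              # normalize, as in the patch
    iterators = [iter(src) for src in sources]
    while True:
        idx = rng.choices(range(len(sources)), weights=probs)[0]
        try:
            yield next(iterators[idx])
        except StopIteration:
            iterators[idx] = iter(sources[idx])     # restart an exhausted source
            yield next(iterators[idx])

# it = weighted_source_iter([['a1', 'a2'], ['b1']], probs=[0.9, 0.1])
# [next(it) for _ in range(5)] -> mostly 'a*' items with an occasional 'b1'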
- self.rm_meta = rm_meta - assert 'Anthropic' in filename or 'openai' in filename, '[Coming soon] currently only support loading Anthropic and openai opensource datasets...' # noqa: E501 - if 'Anthropic' in filename: - from .open_datasets.Anthropic_hh_rlhf import AnthropicHhrlhf - self.data_list = AnthropicHhrlhf(path=filename) - elif 'openai' in filename: - pass + self.sys_prompt = sys_prompt + self.rm_prompt = rm_prompt + + if filename is not None: + self.data_list = [] + with open_file(filename) as fin: + for lineno, line in enumerate(fin): + data = json.loads(line) + self.data_list.append(data) else: - raise NotImplementedError() + self.data_list = data_list def __len__(self): return len(self.data_list) @@ -122,93 +195,86 @@ def __getitem__(self, index: int): self.tokenizer.apply_chat_template(data, tokenize=True) return { 'data': data, - 'sys_meta': self.sys_meta, - 'rm_meta': self.rm_meta + 'sys_prompt': self.sys_prompt, + 'rm_prompt': self.rm_prompt } except Exception: - print(f'[data tokenize check] skip dirty data: {data}') + logger.info(f'[data tokenize check] skip dirty data: {data}') return None - def __iter__(self): - for lineno, data in enumerate(self.data_list): - if data is None: - continue - try: - self.tokenizer.apply_chat_template(data, tokenize=True) - except Exception: - print(f'[data tokenize check] skip dirty data: {data}') - continue - yield { - 'data': data, - 'sys_meta': self.sys_meta, - 'rm_meta': self.rm_meta - } +class MultiSourceInDataDatset(Dataset): + """Multi source dataset. -class MultiSourceDatset(IterableDataset): - """Multiple source dataset.""" + Args: + task_groups: list of data path. + e.g. ['PATH_TO_XTUNER/examples/rlhf/demo_datas/prompt_data.json::0.9[SYS_PROMPT]:summarization', # noqa: E501 + 'PATH_TO_XTUNER/examples/rlhf/demo_datas/pretrain_data.json::0.1', + '[HF]Anthropic/hh-rlhf/helpful-base::0.5[RM_PROMPT]:default', + '[HF]HuggingFaceH4/summarize_from_feedback::0.5' + ] + tokenizer: The tokenizer processes some raw text as input and outputs + an Encoding. This argument should not be None. Default to None. + random_seed: + """ - def __init__(self, - task_groups, - sub_dataset_type='file', - tokenizer=None, - random_seed=1024, - ratio_within_datasets=True): + def __init__(self, task_groups, tokenizer=None, random_seed=1024): self._task_group = [] for _task in task_groups: file_path, extra_info = _task.split('::')[0], _task.split('::')[1] prob = float(extra_info.split('[')[0]) - sys_meta = 'default' - rm_meta = 'default' - if '[META]:' in extra_info: - sys_meta = extra_info.split('[META]:')[-1].split('[')[0] - if '[REWARD_META]:' in extra_info: - rm_meta = extra_info.split('[REWARD_META]:')[-1].split('[')[0] + sys_prompt = 'default' + rm_prompt = 'default' + if '[SYS_PROMPT]:' in extra_info: + sys_prompt = extra_info.split('[SYS_PROMPT]:')[-1].split( + '[')[0] + if '[RM_PROMPT]:' in extra_info: + rm_prompt = extra_info.split('[RM_PROMPT]:')[-1].split('[')[0] if prob > 0: - self._task_group.append({ - 'prob': prob, - 'filepath': file_path, - 'sys_meta': sys_meta, - 'rm_meta': rm_meta - }) - print( - f'[DataLoader] Load {_task} with prob:{prob}, sys_meta type: {sys_meta}, reward meta: {rm_meta}' # noqa: E501 - ) + self._task_group.append( + dict( + prob=prob, + filepath=file_path, + sys_prompt=sys_prompt, + rm_prompt=rm_prompt)) + logger.info( + f'[DataLoader] Load {_task} with prob:{prob}, ' + f'sys_prompt type: {sys_prompt}, reward meta: {rm_prompt}') else: - print( - f'[DataLoader] Warning skip file, prob of {file_path} is {prob} ...' 
# noqa: E501 - ) + logger.warning('[DataLoader] skip file, ' + f'prob of {file_path} is {prob} ...') assert len(self._task_group) > 0, 'No data to be trained' - if sub_dataset_type == 'file': - for task in self._task_group: - filepath = task['filepath'] - if '.json' in filepath: - task['dataset'] = FileDataset(filepath, tokenizer, - task['sys_meta'], - task['rm_meta']) - else: - # loading opensource datasets - print(f'Try loading {filepath} from huggingface ...') - task['dataset'] = OpensourceDataset( - filepath, tokenizer, task['sys_meta'], task['rm_meta']) - else: - raise NotImplementedError('Cannot support filelist now.') - self.random_seed = random_seed - self.ratio_within_datasets = ratio_within_datasets - if self.ratio_within_datasets: - sum_prob = sum([task['prob'] for task in self._task_group]) - for task in self._task_group: - task['prob'] = task['prob'] / sum_prob - else: - datasets = [] - for i, task in enumerate(self._task_group): - task['dataset'] = self._get_subset_by_ratio( - task['dataset'], task['prob'], random_seed) - datasets.append(task['dataset']) + datasets = [] + for task in self._task_group: + filepath = task['filepath'] - self.all_dataset = ConcatDataset(datasets) - self.iter_all_dataset = iter(self.all_dataset) + if '[HF]' in filepath: + from xtuner.rlhf.dataset.utils.from_hf import load_from_hf + + # loading & convert & save opensource datasets + hf_dir = filepath.split('[HF]')[-1] + logger.info(f'Loading {hf_dir} with huggingface format ...') + dataset = load_from_hf(hf_dir, tokenizer=tokenizer) + task['dataset'] = JsonDataset( + data_list=dataset['conversation'], + tokenizer=tokenizer, + sys_prompt=task['sys_prompt'], + rm_prompt=task['rm_prompt']) + else: + task['dataset'] = JsonDataset( + filename=filepath, + tokenizer=tokenizer, + sys_prompt=task['sys_prompt'], + rm_prompt=task['rm_prompt']) + task['dataset'] = self._get_subset_by_ratio( + task['dataset'], task['prob'], random_seed) + datasets.append(task['dataset']) + + self.all_dataset = ConcatDataset(datasets) + self.iter_all_dataset = iter(self.all_dataset) + + self.random_seed = random_seed def _get_subset_by_ratio(self, dataset: Dataset, ratio: float, seed: int): np_random = np.random.RandomState(seed) @@ -219,19 +285,4 @@ def _get_subset_by_ratio(self, dataset: Dataset, ratio: float, seed: int): return Subset(dataset, subset_indices) def __iter__(self): - """sample data one task by probs.""" - if self.ratio_within_datasets: - rng = random.Random(self.random_seed) - probs = [task['prob'] for task in self._task_group] - # Initialize task iterator - for task in self._task_group: - task['iterator'] = iter(task['dataset']) - while True: - task = rng.choices(self._task_group, weights=probs)[0] - try: - yield from task['iterator'] - except StopIteration: - task['iterator'] = iter(task['dataset']) - yield from task['iterator'] - else: - yield next(self.iter_all_dataset) + yield next(self.iter_all_dataset) diff --git a/xtuner/rlhf/dataset/message_iter.py b/xtuner/rlhf/dataset/message_iter.py new file mode 100644 index 000000000..0776e32b7 --- /dev/null +++ b/xtuner/rlhf/dataset/message_iter.py @@ -0,0 +1,229 @@ +"""Finetuning dataset.""" +import random +from dataclasses import dataclass +from typing import List + +import numpy as np +from loguru import logger +from torch.utils.data import DataLoader, RandomSampler + +from xtuner.rlhf.dataset.base import (InfiniteDataset, + MultiSourceInBatchDatset, + MultiSourceInDataDatset) + + +@dataclass +class Message: + message: List[dict] + sys_prompt: str = 'default' + 
rm_prompt: str = 'default' + token_ids: List[int] = None + mes_type: str = 'prompt' + + +class MessageIter(): + """Create sequences from dataset. + + Args: + sample_strategy (str) ["in_batch", "in_data"]: + "in_batch": sample data by ratio for every single training batch + "in_data": merge all data by ratio and then sample training batch + """ + + def __init__(self, + message_datasets: list[str] = None, + message_type: str = 'prompt', + tokenizer=None, + max_len: int = 4096, + samples_each_epoch: int = 64, + random_seed: int = 110, + sample_strategy: str = 'in_batch', + **kwargs): + assert message_type in ['prompt', 'pretrain'] + assert sample_strategy in [ + 'in_batch', 'in_data' + ], ("`sample_strategy` should in ['in_batch', 'in_data']," + f' but got {sample_strategy}') + assert message_datasets is not None + self.message_type = message_type + self.sample_strategy = sample_strategy + self.tokenizer = tokenizer + assert self.tokenizer.chat_template is not None, ( + 'Make sure tokenizer has chat_template.') + # message data + self.message_datasets = message_datasets + self.samples_each_epoch = samples_each_epoch + self.max_len = max_len + + self.random_seed = random_seed + self.rng = random.Random(self.random_seed) + np.random.seed(self.random_seed) + random.seed(self.random_seed) + + if self.sample_strategy == 'in_batch': + self._init_in_batch() + elif self.sample_strategy == 'in_data': + self._init_in_data() + else: + raise NotImplementedError( + "sample_strategy should in ['in_batch', 'in_data']," + f' but got {sample_strategy}') + logger.info(f'[MES_ITER] {self.message_type} dataset initialized, ' + f'random seed {self.random_seed}, ' + f'{self.samples_each_epoch} per epoch.\n') + + self.epoch_index = 0 + + def _init_in_data(self): + logger.info('====== Init in data dataset ======') + self.message_dataset = MultiSourceInDataDatset( + task_groups=self.message_datasets, tokenizer=self.tokenizer) + + logger.info('====== Init in data sampler ======') + assert hasattr(self.message_dataset, 'all_dataset') + mes_sampler = RandomSampler(self.message_dataset.all_dataset) + self.mes_dataloader = iter( + DataLoader( + self.message_dataset.all_dataset, + collate_fn=lambda x: x, + sampler=mes_sampler, + batch_size=self.samples_each_epoch)) + + def yield_in_data(self): + logger.info('====== yield data from in_data sampler ======') + mes_sequence = [] + + mes_batch_messages = next(self.mes_dataloader) + for index, message in enumerate(mes_batch_messages): + if message is None: + continue + sequence = self._postprocess_sequence(message) + if sequence is not None: + mes_sequence.append(sequence) + if len(mes_sequence) == self.samples_each_epoch: + break + # TODO, len(mes_sequence) < self.samples_each_epoch, + # tmp: random sample from chosen data + if len(mes_sequence) < self.samples_each_epoch: + missed = self.samples_each_epoch - len(mes_sequence) + logger.warning( + f'[MES_ITER] {self.message_type} {missed} dirty data ...') + for i in range(missed): + mes_sequence.append(mes_sequence[i]) + + assert len( + mes_sequence + ) == self.samples_each_epoch, \ + f'{len(mes_sequence)} == {self.samples_each_epoch}' + + assert len(mes_sequence) == self.samples_each_epoch + logger.info(f'[Epoch {self.epoch_index}] ' + f'sample {len(mes_sequence)} {self.message_type}') + return mes_sequence + + def _init_in_batch(self): + logger.info('====== Init in batch dataset ======') + self.message_dataset = MultiSourceInBatchDatset( + task_groups=self.message_datasets, tokenizer=self.tokenizer) + + logger.info('====== Init 
in batch sampler ======') + samples_cnts = [] + for task in self.message_dataset._task_group: + task['target_num_each_epoch'] = int( + task['prob'] * self.samples_each_epoch + 0.5) + 1 + inner_dataset = InfiniteDataset(task['dataset'], self.rng) + task['iterator'] = iter(inner_dataset) + samples_cnts.append(task['target_num_each_epoch']) + logger.info( + f"[MES_ITER] {task['filepath']}: task prob: {task['prob']}" + f' original number of messages: {len(inner_dataset.data)}' + f" target_num_each_epoch: {task['target_num_each_epoch']}") + assert sum(samples_cnts) >= self.samples_each_epoch + + def yield_in_batch(self): + logger.info('====== yield data from in_batch sampler ======') + mes_sequence = [] + + # epoch_rng only use in this epoch. + epoch_rng = np.random.RandomState(self.epoch_index) + # prepare epoch data + mes_batch_messages = [] + for task in self.message_dataset._task_group: + messages = [] + for _ in range(task['target_num_each_epoch']): + messages.append(next(task['iterator'])) + logger.info(f'[MES_ITER] sample {len(messages)} ' + f"{self.message_type} from {task['filepath']}") + epoch_rng.shuffle(messages) + mes_batch_messages.extend(messages) + epoch_rng.shuffle(mes_batch_messages) + for index, message in enumerate(mes_batch_messages): + sequence = self._postprocess_sequence(message) + if sequence is not None: + mes_sequence.append(sequence) + if len(mes_sequence) == self.samples_each_epoch: + break + + assert len(mes_sequence) == self.samples_each_epoch + logger.info(f'[Epoch {self.epoch_index}] sample ' + f'{len(mes_sequence)} {self.message_type}') + + return mes_sequence + + def __iter__(self): + while True: + if self.sample_strategy == 'in_batch': + yield self.yield_in_batch() + elif self.sample_strategy == 'in_data': + yield self.yield_in_data() + + self.epoch_index += 1 + + def _postprocess_sequence(self, message): + """Post process sequence: tokenization & truncation.""" + message_data = message['data'] + new_meaasage_data = [] + if self.message_type == 'prompt': + for _ in reversed(range(len(message_data))): + if message_data[_]['role'] == 'user': + new_meaasage_data = message_data[:_ + 1] + break + assert new_meaasage_data[-1]['role'] == 'user', \ + f'prompt data last role must user, {new_meaasage_data}' + token_ids = self.tokenizer.apply_chat_template( + new_meaasage_data, + tokenize=True, + add_generation_prompt=True, + return_tensors='pt') + if (token_ids.shape[-1] <= 4) or (token_ids.shape[-1] > + self.max_len): + # TODO truncation?? + logger.warning( + f'[MES_ITER] {self.message_type} message {message} ' + 'is too short or long, skipped...') + return None + elif self.message_type == 'pretrain': + for _ in reversed(range(len(message_data))): + if message_data[_]['role'] == 'assistant': + new_meaasage_data = message_data[:_ + 1] + break + assert new_meaasage_data[-1]['role'] == 'assistant', \ + f'pretrain data last role must assistant, {new_meaasage_data}' + token_ids = self.tokenizer.apply_chat_template( + new_meaasage_data, + tokenize=True, + add_generation_prompt=False, + return_tensors='pt') + + if token_ids.shape[-1] <= 4 or token_ids.shape[-1] > self.max_len: + # TODO truncation?? 
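_postprocess_sequence above trims each message list to its last user turn (prompt data) or last assistant turn (pretrain data) and tokenizes it with the chat template, requesting the generation prompt only for prompt-type data. A usage sketch with the Hugging Face tokenizer API; the model name is a placeholder:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('YOUR_CHAT_MODEL',
                                          trust_remote_code=True)
messages = [{'role': 'user', 'content': 'How to study English?'}]

# prompt-type data: end on a user turn and append the generation prompt
prompt_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors='pt')

# pretrain-type data ends on an assistant turn and uses
# add_generation_prompt=False, so the assistant answer is part of the ids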
+ logger.warning( + f'[MES_ITER] {self.message_type} message {message} ' + 'is too short or long, skipped...') + return None + return Message( + message=new_meaasage_data, + token_ids=token_ids, + sys_prompt=message['sys_prompt'], + rm_prompt=message['rm_prompt'], + mes_type=self.message_type) diff --git a/xtuner/rlhf/dataset/open_datasets/Anthropic_hh_rlhf.py b/xtuner/rlhf/dataset/open_datasets/Anthropic_hh_rlhf.py deleted file mode 100644 index 636cb84b5..000000000 --- a/xtuner/rlhf/dataset/open_datasets/Anthropic_hh_rlhf.py +++ /dev/null @@ -1,88 +0,0 @@ -import os -import re -from typing import List - -import datasets -from torch.utils.data import Dataset - - -def deduplicate(data: List[str]): - """Deduplicate data while preserving order. - - Refer to https://stackoverflow.com/questions/9792664/converting-a-list-to-a-set-changes-element-order # noqa: E501 - """ - return list(dict.fromkeys(data).keys()) - - -class AnthropicHhrlhf(Dataset): - """ - helpful-base: - train: 42,537 - test: 2,312 - - harmless-base: - train: 43,835 - test: 2,354 - - helpful-online: - train: 22,007 - test: 1,137 - - helpful-rejection-sampled: - train: 52,421 - test: 2,749 - - red-team-attempts: - train: 38,961 - """ - - def __init__(self, - path: str = 'Anthropic/hh-rlhf/helpful-base', - test=False): - super().__init__() - parts = path.split('/') - assert 'Anthropic' in parts and 'hh-rlhf' in parts, f'{self.__class__.__name__}: {path}' # noqa: E501 - if parts.index('hh-rlhf') == len(parts) - 1: - data_dir = None - else: - data_dir = parts[-1] - if os.path.exists('data/' + path): - raw_datasets = datasets.load_from_disk('data/' + path) - else: - print( - f'loading Anthropic/hh-rlhf data_dir={data_dir} from huggingface ...' # noqa: E501 - ) - raw_datasets = datasets.load_dataset( - 'Anthropic/hh-rlhf', data_dir=data_dir, trust_remote_code=True) - raw_datasets.save_to_disk('data/' + path) - if test: - raw_data_list = raw_datasets['test']['chosen'] - else: - raw_data_list = raw_datasets['train']['chosen'] - raw_data_list = [d for d in raw_data_list if d is not None] - raw_data_list = deduplicate(raw_data_list) - self.data_list = [ - self.format_chatml(prompt) for prompt in raw_data_list - ] - self.name = self.__class__.__name__ + '-' + data_dir if data_dir else self.__class__.__name__ # noqa: E501 - - def __len__(self): - return len(self.data_list) - - def __getitem__(self, index: int): - return self.data_list[index] - - @staticmethod - def format_chatml(string): - pattern = r'(\n\nHuman|\n\nAssistant)(.+?)(?=(\n\nHuman|\n\nAssistant|$))' # noqa: E501 - matches = re.findall(pattern, string, re.DOTALL) - messages = [] - for match in matches: - role, content = match[0].strip(), match[1].strip() - if role == 'Human': - messages.append({'role': 'user', 'content': content[2:]}) - elif role == 'Assistant': - messages.append({'role': 'assistant', 'content': content[2:]}) - else: - raise NotImplementedError('role must in Human or Assistant') - return messages diff --git a/xtuner/rlhf/dataset/open_datasets/__init__.py b/xtuner/rlhf/dataset/open_datasets/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/xtuner/rlhf/dataset/txt_loader.py b/xtuner/rlhf/dataset/txt_loader.py deleted file mode 100644 index cc20a54da..000000000 --- a/xtuner/rlhf/dataset/txt_loader.py +++ /dev/null @@ -1,267 +0,0 @@ -""" Finetuning dataset. 
""" -import random -from typing import List -import numpy as np -from dataclasses import dataclass -from torch.utils.data import IterableDataset, DataLoader, RandomSampler -from .base import MultiSourceDatset, InfiniteDataset - - -@dataclass -class Message: - message: List[dict] - sys_meta: str = "default" - rm_meta: str = "default" - token_ids: List[int] = None - mes_type: str = "prompt" - - -class TxtMessageDataset(IterableDataset): - """ Create sequences from dataset. - Args: - sample_strategy (str) ["in_batch", "in_data"]: "in_batch": sample data by ratio for every single training batch - "in_data": merge all data by ratio first and then sample training batch - """ - def __init__(self, - prompt_datasets: list[str] = None, - pretrain_datasets: list[str] = None, - tokenizer=None, - max_prompt_len: int = 4096, - max_pretrain_len: int = 4096, - prompt_samples_each_epoch: int = 64, - pretrain_samples_each_epoch: int = 0, - random_seed: int = 110, - sample_strategy: str = "in_batch", - ratio_within_datasets: bool = True, - **kwargs - ): - assert sample_strategy in ["in_batch", "in_data"], f"sample_strategy should in ['in_batch', 'in_data'], but got {sample_strategy}" - self.sample_strategy = sample_strategy - assert prompt_datasets is not None, "[Data error] Specify your data task config" - self.tokenizer = tokenizer - assert self.tokenizer.chat_template is not None, "Make sure tokenizer has chat_template." - - self.prompt_message_dataset = MultiSourceDatset(task_groups=prompt_datasets, - sub_dataset_type="file", - tokenizer=self.tokenizer, - ratio_within_datasets=ratio_within_datasets - ) - if pretrain_samples_each_epoch is not None and pretrain_samples_each_epoch > 0: - assert pretrain_datasets is not None, f"[PT DATA error] samples num {pretrain_samples_each_epoch}, while pretrain_datasets is None" - self.pt_message_dataset = MultiSourceDatset(task_groups=pretrain_datasets, - sub_dataset_type="file", - tokenizer=self.tokenizer, - ratio_within_datasets=ratio_within_datasets - ) - self.pretrain_samples_each_epoch = pretrain_samples_each_epoch - else: - self.pt_message_dataset = None - self.pretrain_samples_each_epoch = 0 - self.prompt_samples_each_epoch = prompt_samples_each_epoch - - self.max_prompt_len = max_prompt_len - self.max_pretrain_len = max_pretrain_len - self.num_samples_each_epoch = self.pretrain_samples_each_epoch + self.prompt_samples_each_epoch - - self.random_seed = random_seed - self.rng = random.Random(self.random_seed) - np.random.seed(self.random_seed) - random.seed(self.random_seed) - - if self.sample_strategy == "in_batch": - self._init_in_batch() - elif self.sample_strategy == "in_data": - self._init_in_data() - else: - raise NotImplementedError(f"sample_strategy should in ['in_batch', 'in_data'], but got {sample_strategy}") - - self.epoch_index = 0 - - def _init_in_data(self): - print(f"========================= Init in data sampler =========================") - if self.pretrain_samples_each_epoch != 0: - assert hasattr(self.pt_message_dataset, "all_dataset") - pt_sampler = RandomSampler(self.pt_message_dataset.all_dataset) - self.pt_dataloader = iter(DataLoader( - self.pt_message_dataset.all_dataset, collate_fn=lambda x: x, sampler=pt_sampler, batch_size=self.pretrain_samples_each_epoch - )) - print(f"[PT data] pretrain data per epoch: {self.pretrain_samples_each_epoch}") - - assert hasattr(self.prompt_message_dataset, "all_dataset") - prompt_sampler = RandomSampler(self.prompt_message_dataset.all_dataset) - self.prompt_dataloader = iter(DataLoader( - 
self.prompt_message_dataset.all_dataset, collate_fn=lambda x: x, sampler=prompt_sampler, batch_size=self.prompt_samples_each_epoch - )) - - print(f"[Prompt data] prompt data per epoch: {self.prompt_samples_each_epoch}") - print(f"[Txt] Training dataset initialized, random seed {self.random_seed}.\n") - - def yield_in_data(self): - print(f"========================= yield data from data sampler =========================") - batch_sequence = [] - prompt_sequence, pretrain_sequence = [], [] - if self.pretrain_samples_each_epoch != 0: - pretrain_batch_messages = next(self.pt_dataloader) - for index, message in enumerate(pretrain_batch_messages): - sequence = self._postprocess_sequence(message, mes_type="pretrain") - if sequence is not None: - assert sequence.mes_type == 'pretrain', f"Data type should be pretrain, but get {sequence.mes_type}" - pretrain_sequence.append(sequence) - if len(pretrain_sequence) == self.pretrain_samples_each_epoch: - break - assert len(pretrain_sequence) == self.pretrain_samples_each_epoch, f"{len(pretrain_sequence)} != {self.pretrain_samples_each_epoch}" - - prompt_batch_messages = next(self.prompt_dataloader) - for index, message in enumerate(prompt_batch_messages): - if message is None: - continue - sequence = self._postprocess_sequence(message, mes_type="prompt") - if sequence is not None: - assert sequence.mes_type == 'prompt', f"Data type should be prompt. but get {sequence.mes_type}" - prompt_sequence.append(sequence) - if len(prompt_sequence) == self.prompt_samples_each_epoch: - break - # TODO, len(prompt_sequence) < self.prompt_samples_each_epoch, random sample from chosen data - if len(prompt_sequence) < self.prompt_samples_each_epoch: - missed = self.prompt_samples_each_epoch - len(prompt_sequence) - print(f"[Warning] {missed} dirty data, use {missed} data from sampled data...") - for i in range(missed): - prompt_sequence.append(prompt_sequence[i]) - - assert len(prompt_sequence) == self.prompt_samples_each_epoch, f"{len(prompt_sequence)} == {self.prompt_samples_each_epoch}" - - print(f"prepare TxtMessageDataset done: {len(prompt_sequence)} prompt & {len(pretrain_sequence)} pretrain, for epoch {self.epoch_index}.") - batch_sequence = prompt_sequence + pretrain_sequence - assert len(batch_sequence) == self.num_samples_each_epoch, "[Epoch {self.epoch_index}] Wrong data len" - return batch_sequence - - def _init_in_batch(self): - print(f"========================= Init in batch sampler =========================") - samples_cnts = [] - pt_data_len = 0 - if self.pretrain_samples_each_epoch != 0: - for task in self.pt_message_dataset._task_group: - task["target_num_each_epoch"] = int(task["prob"] * self.pretrain_samples_each_epoch + 0.5) + 1 - inner_dataset = InfiniteDataset(task["dataset"], self.rng) - task["iterator"] = iter(inner_dataset) - samples_cnts.append(task["target_num_each_epoch"]) - print(f"[Pretrain data] {task['filepath']}: task prob: {task['prob']}, " - f"ori number of messages: {len(inner_dataset.data)}, " - f"target_num_each_epoch: {task['target_num_each_epoch']}") - pt_data_len = sum(samples_cnts) - # TODO - assert pt_data_len >= self.pretrain_samples_each_epoch, f"Make sure there are enough pretrain datas, {pt_data_len} >= {self.pretrain_samples_each_epoch}" - print(f"[PT data] pretrain data per epoch: {self.pretrain_samples_each_epoch}, sampled {pt_data_len}") - - for task in self.prompt_message_dataset._task_group: - task["target_num_each_epoch"] = int(task["prob"] * self.prompt_samples_each_epoch + 0.5) + 1 - inner_dataset = 
InfiniteDataset(task["dataset"], self.rng) - task["iterator"] = iter(inner_dataset) - samples_cnts.append(task["target_num_each_epoch"]) - print(f"{task['filepath']}: task prob: {task['prob']}, " - f"ori number of messages: {len(inner_dataset.data)}, " - f"target_num_each_epoch: {task['target_num_each_epoch']}") - assert (sum(samples_cnts) - pt_data_len) >= self.prompt_samples_each_epoch, "Make sure there are enough prompt datas" - print(f"[Prompt data] prompt data per epoch: {self.prompt_samples_each_epoch}, sampled: {sum(samples_cnts) - pt_data_len}") - - assert sum(samples_cnts) >= self.num_samples_each_epoch, "[Dataset init] sample num error" - # if sum(samples_cnts) <= self.num_samples_each_epoch: - # print(f"[Txt loader] Warning!!! sample nums {sum(samples_cnts)} <= samples {self.num_samples_each_epoch}") - print(f"[Txt] Training dataset initialized, random seed {self.random_seed}.\n") - - def yield_in_batch(self): - print(f"========================= yield data from batch sampler =========================") - batch_sequence = [] - prompt_sequence, pretrain_sequence = [], [] - - # epoch_rng only use in this epoch. - epoch_rng = np.random.RandomState(self.epoch_index) - # prepare epoch data - # print(f"prepare TxtMessageDataset for epoch {self.epoch_index}...") - if self.pretrain_samples_each_epoch != 0 : - pretrain_batch_messages = [] - for task in self.pt_message_dataset._task_group: - messages = [] - for _ in range(task["target_num_each_epoch"]): - messages.append(next(task["iterator"])) - print(f"[Pretrain] prepare {len(messages)} data from {task['filepath']}") - epoch_rng.shuffle(messages) - pretrain_batch_messages.extend(messages) - # if len(pretrain_batch_messages) == self.pretrain_samples_each_epoch: - # break - epoch_rng.shuffle(pretrain_batch_messages) - for index, message in enumerate(pretrain_batch_messages): - sequence = self._postprocess_sequence(message, mes_type="pretrain") - if sequence is not None: - assert sequence.mes_type == 'pretrain', f"Data type should be pretrain, but get {sequence.mes_type}" - pretrain_sequence.append(sequence) - if len(pretrain_sequence) == self.pretrain_samples_each_epoch: - break - assert len(pretrain_sequence) == self.pretrain_samples_each_epoch, f"{len(pretrain_sequence)} != {self.pretrain_samples_each_epoch}" - - prompt_batch_messages = [] - for task in self.prompt_message_dataset._task_group: - messages = [] - for _ in range(task["target_num_each_epoch"]): - messages.append(next(task["iterator"])) - print(f"[Prompt] prepare {len(messages)} data from {task['filepath']}") - epoch_rng.shuffle(messages) - prompt_batch_messages.extend(messages) - epoch_rng.shuffle(prompt_batch_messages) - for index, message in enumerate(prompt_batch_messages): - sequence = self._postprocess_sequence(message, mes_type="prompt") - if sequence is not None: - assert sequence.mes_type == 'prompt', f"Data type should be prompt. 
but get {sequence.mes_type}" - prompt_sequence.append(sequence) - if len(prompt_sequence) == self.prompt_samples_each_epoch: - break - assert len(prompt_sequence) == self.prompt_samples_each_epoch, f"{len(prompt_sequence)} == {self.prompt_samples_each_epoch}" - - print(f"prepare TxtMessageDataset done: {len(prompt_sequence)} prompt & {len(pretrain_sequence)} pretrain, for epoch {self.epoch_index}.") - batch_sequence = prompt_sequence + pretrain_sequence - assert len(batch_sequence) == self.num_samples_each_epoch, "[Epoch {self.epoch_index}] Wrong data len" - return batch_sequence - - def __iter__(self): - while True: - if self.sample_strategy == "in_batch": - yield self.yield_in_batch() - elif self.sample_strategy == "in_data": - yield self.yield_in_data() - - self.epoch_index += 1 - - def _postprocess_sequence(self, message, mes_type=None): - """Post process sequence: tokenization & truncation.""" - message_data = message['data'] - new_meaasage_data = [] - if mes_type == "prompt": - for _ in reversed(range(len(message_data))): - if message_data[_]["role"] == "user": - new_meaasage_data = message_data[:_ + 1] - break - assert new_meaasage_data[-1]["role"] == "user", f"prompt data last role must user, {new_meaasage_data}" - token_ids = self.tokenizer.apply_chat_template(new_meaasage_data, tokenize=True, add_generation_prompt=True, return_tensors="pt") - if token_ids.shape[-1] <= 4 or token_ids.shape[-1] > self.max_prompt_len: - # TODO truncation?? - # raise RuntimeError(f"token_ids is too long: {token_ids.shape[-1]}") - print(f"[TXT Loader] Warning, {mes_type} message {message} is too short or long, skipped...") - return None - elif mes_type == "pretrain": - for _ in reversed(range(len(message_data))): - if message_data[_]["role"] == "assistant": - new_meaasage_data = message_data[:_ + 1] - break - assert new_meaasage_data[-1]["role"] == "assistant", f"pretrain data last role must assistant, {new_meaasage_data}" - token_ids = self.tokenizer.apply_chat_template(new_meaasage_data, tokenize=True, add_generation_prompt=False, return_tensors="pt") - - if token_ids.shape[-1] <= 4 or token_ids.shape[-1] > self.max_pretrain_len: - # TODO truncation?? 
- # raise RuntimeError(f"token_ids is too long: {token_ids.shape[-1]}") - print(f"[TXT Loader] Warning, {mes_type} message {message} is too short or long, skipped...") - return None - return Message(message=new_meaasage_data, - token_ids=token_ids, - sys_meta=message['sys_meta'], - rm_meta=message['rm_meta'], - mes_type=mes_type) diff --git a/xtuner/rlhf/dataset/utils/__init__.py b/xtuner/rlhf/dataset/utils/__init__.py new file mode 100644 index 000000000..ec03048b9 --- /dev/null +++ b/xtuner/rlhf/dataset/utils/__init__.py @@ -0,0 +1,7 @@ +from .collate_fns import message_data_collator, messages_collate_fn +from .map_fns import H4_summarize_map_fn, hhrlhf_map_fn + +__all__ = [ + 'message_data_collator', 'messages_collate_fn', 'hhrlhf_map_fn', + 'H4_summarize_map_fn' +] diff --git a/xtuner/rlhf/dataset/utils/collate_fns.py b/xtuner/rlhf/dataset/utils/collate_fns.py new file mode 100644 index 000000000..c3551f2a0 --- /dev/null +++ b/xtuner/rlhf/dataset/utils/collate_fns.py @@ -0,0 +1,28 @@ +from collections import defaultdict +from functools import partial +from typing import Dict, Sequence + + +def messages_collate_fn( + instances: Sequence[Dict], + return_only_messages: bool = True, +): + + return_dict = defaultdict(list) + messages = [] + + for example in instances: + assert 'conversation' in example.keys() + messages.append(example['conversation']) + for k, v in example.items(): + return_dict[k].append(v) + + if return_only_messages: + return messages + else: + return return_dict + + +def message_data_collator(return_only_messages=True): + return partial( + messages_collate_fn, return_only_messages=return_only_messages) diff --git a/xtuner/rlhf/dataset/utils/from_hf.py b/xtuner/rlhf/dataset/utils/from_hf.py new file mode 100644 index 000000000..e4791d148 --- /dev/null +++ b/xtuner/rlhf/dataset/utils/from_hf.py @@ -0,0 +1,75 @@ +from datasets import load_dataset + +from xtuner.dataset import process_hf_dataset +from xtuner.dataset.map_fns import template_map_fn_factory +from xtuner.rlhf.dataset.utils import H4_summarize_map_fn, hhrlhf_map_fn +from xtuner.utils import PROMPT_TEMPLATE + + +def read_hf_dataset(tokenizer, + path: str = None, + data_dir: str = None, + dataset_map_fn=None, + max_length=8192, + split='train', + prompt_template=PROMPT_TEMPLATE.internlm_chat, + remove_unused_columns=False, + shuffle_before_pack=False, + pack_to_max_length=False): + # https://huggingface.co/datasets/Anthropic/hh-rlhf + template_map_fn = template_map_fn_factory(template=prompt_template) + dataset_org = load_dataset(path, data_dir=data_dir, trust_remote_code=True) + dataset = process_hf_dataset( + dataset=dataset_org, + tokenizer=tokenizer, + max_length=max_length, + split=split, + dataset_map_fn=dataset_map_fn, + template_map_fn=template_map_fn, + remove_unused_columns=remove_unused_columns, + shuffle_before_pack=shuffle_before_pack, + pack_to_max_length=pack_to_max_length) + return dataset + + +def load_from_hf(hf_dir, tokenizer, data_dir=None): + if 'Anthropic/hh-rlhf' in hf_dir: + # train: Dataset({ + # features: ['chosen', 'rejected'], + # num_rows: 160800 + # }) + # test: Dataset({ + # features: ['chosen', 'rejected'], + # num_rows: 8552 + # }) + if data_dir is not None: + data_dir = data_dir + elif 'helpful-base' in hf_dir: + data_dir = 'helpful-base' + elif 'harmless-base' in hf_dir: + data_dir = 'harmless-base' + + dataset = read_hf_dataset( + tokenizer=tokenizer, + path='Anthropic/hh-rlhf', + data_dir=data_dir, + max_length=8192, + split='train', + dataset_map_fn=hhrlhf_map_fn) + if 
'summarize_from_feedback' in hf_dir: + # train_prefs: Dataset({ + # features: ['prompt', 'chosen', 'rejected'], + # num_rows: 92858 + # }) + # train_sft: Dataset({ + # features: ['prompt', 'chosen', 'rejected'], + # num_rows: 92858 + # }) + dataset = read_hf_dataset( + tokenizer=tokenizer, + path='HuggingFaceH4/summarize_from_feedback', + data_dir=data_dir, + max_length=8192, + split='train_prefs', + dataset_map_fn=H4_summarize_map_fn) + return dataset diff --git a/xtuner/rlhf/dataset/utils/map_fns.py b/xtuner/rlhf/dataset/utils/map_fns.py new file mode 100644 index 000000000..f66dd3ec7 --- /dev/null +++ b/xtuner/rlhf/dataset/utils/map_fns.py @@ -0,0 +1,24 @@ +import re + + +def hhrlhf_map_fn(example): + string = example['chosen'] + pattern = r'(\n\nHuman|\n\nAssistant)(.+?)(?=(\n\nHuman|\n\nAssistant|$))' # noqa: E501 + matches = re.findall(pattern, string, re.DOTALL) + messages = [] + for match in matches: + role, content = match[0].strip(), match[1].strip() + if role == 'Human': + messages.append({'role': 'user', 'content': content[2:]}) + elif role == 'Assistant': + messages.append({'role': 'assistant', 'content': content[2:]}) + else: + raise NotImplementedError('role must in Human or Assistant') + return {'conversation': messages} + + +def H4_summarize_map_fn(example): + # prompt = example['prompt'] + chosen = example['chosen'] + # rejected = example['rejected'] + return {'conversation': chosen} diff --git a/xtuner/rlhf/envs/__init__.py b/xtuner/rlhf/envs/__init__.py index e69de29bb..5175867a2 100644 --- a/xtuner/rlhf/envs/__init__.py +++ b/xtuner/rlhf/envs/__init__.py @@ -0,0 +1,3 @@ +from .txt_env import TxtEnv + +__all__ = ['TxtEnv'] diff --git a/xtuner/rlhf/envs/prompt_utils.py b/xtuner/rlhf/envs/prompt_utils.py index 54dfab3cf..9695ec92b 100644 --- a/xtuner/rlhf/envs/prompt_utils.py +++ b/xtuner/rlhf/envs/prompt_utils.py @@ -1,6 +1,4 @@ - - -META_PROMPT = { +SYSTEM_PROMPT = { 'default': '', 'helpful': diff --git a/xtuner/rlhf/envs/txt_env.py b/xtuner/rlhf/envs/txt_env.py index 7490151b8..b385e86a3 100644 --- a/xtuner/rlhf/envs/txt_env.py +++ b/xtuner/rlhf/envs/txt_env.py @@ -1,20 +1,21 @@ -import time +from collections.abc import Iterable from copy import deepcopy import torch from loguru import logger -from torch.utils.data import IterableDataset from ..model_server.base_model_server import BaseModelServer -from .prompt_utils import META_PROMPT +from ..timer import Timer +from .prompt_utils import SYSTEM_PROMPT -class TxtEnv: +class TxtEnv(): """A generic RL environment to generate textual sequences.""" def __init__( self, - dataloader: IterableDataset, + prompt_mes_iter: Iterable, + pretrain_mes_iter: Iterable = None, max_new_tokens: int = 1024, actor_micro_bs: int = 32, reward_micro_bs: int = 32, @@ -25,10 +26,12 @@ def __init__( ): """ Args: - dataloader (IterableDataset): generate rl data iteratively - reward_function: reward function that computes scalar reward for each episode # noqa: E501 + dataloader: generate rl data iteratively + reward_function: reward function that computes scalar reward """ - self.dataloader: IterableDataset = iter(dataloader) + self.prompt_mes_iter = iter(prompt_mes_iter) + self.pretrain_mes_iter = iter( + pretrain_mes_iter) if pretrain_mes_iter else None self.reward_function: BaseModelServer = reward_function self._cur_messagess = [] self.max_new_tokens = max_new_tokens @@ -38,137 +41,130 @@ def __init__( self.generate_kwargs: dict = generate_kwargs def rollout(self, policy_model: BaseModelServer, display=False): - sample_data = 
deepcopy(next(self.dataloader)) + prompt_datas = deepcopy(next(self.prompt_mes_iter)) prompt_input_messages = [] - pretrain_input_messages = [] - for data in sample_data: - if data.sys_meta != 'default': - message = deepcopy([{ - 'role': 'system', - 'content': META_PROMPT[data.sys_meta] - }] + data.message) + for data in prompt_datas: + assert data.mes_type == 'prompt' + if data.sys_prompt != 'default': + message = deepcopy([ + dict( + role='system', content=SYSTEM_PROMPT[data.sys_prompt]) + ] + data.message) else: message = deepcopy(data.message) - if data.mes_type == 'prompt': - prompt_input_messages.append(message) - elif data.mes_type == 'pretrain': - pretrain_input_messages.append(message) - else: - raise TypeError(f'Wrong message type {data.mes_type}') + prompt_input_messages.append(message) # prompt data - s_t = time.time() - print(f'[For Generate]: {prompt_input_messages[0]}') - trajectories = policy_model.generate( - inputs=prompt_input_messages, - micro_batch_size=self.actor_micro_bs, - step=self.max_new_tokens, - output_str=True, - generate_kwargs=self.generate_kwargs) - logger.info( - f'[actor generate] duration: {round(time.time() - s_t, 2)} s, len(inputs): {len(prompt_input_messages)} ' # noqa: E501 - ) + logger.info(f'[For Generate]: {prompt_input_messages[0]}') + with Timer('policy_model.generate'): + trajectories = policy_model.generate( + inputs=prompt_input_messages, + micro_batch_size=self.actor_micro_bs, + step=self.max_new_tokens, + output_str=True, + generate_kwargs=self.generate_kwargs) + logger.info(f'[generate] len: {len(prompt_input_messages)}') if self.async_reward: - reward_output_ref = self.get_reward_async(sample_data, + reward_output_ref = self.get_reward_async(prompt_datas, trajectories) trajectories['reward_output_ref'] = reward_output_ref else: - rewards = self.get_reward(sample_data, trajectories) + rewards = self.get_reward(prompt_datas, trajectories) trajectories['rewards'] = rewards # pretrain data - if len(pretrain_input_messages) > 0: - from ..tokenizer import tokenizer_utils - pretrain_input_ids, pretrain_attention_mask = tokenizer_utils.encode( + if self.pretrain_mes_iter is not None: + pretrain_datas = deepcopy(next(self.pretrain_mes_iter)) + pretrain_input_messages = [] + for data in pretrain_datas: + assert data.mes_type == 'pretrain' + pretrain_input_messages.append(data.message) + + from xtuner.rlhf.tokenizer import tokenizer_utils + pt_input_ids, pt_attention_mask = tokenizer_utils.encode( + pretrain_input_messages, policy_model.tokenizer) - pretrain_labels = torch.nn.functional.pad(pretrain_input_ids[:, 1:], (0, 1), mode="constant", value=-100) - - trajectories.pretrain_data = {"input_ids": pretrain_input_ids, - "labels": pretrain_labels, - "attention_mask": pretrain_attention_mask} - print( - f'[TxtEnv & {policy_model.__class__.__name__}] gets {len(pretrain_input_messages)} pretrain episodes.' # noqa: E501 - ) + pretrain_labels = torch.nn.functional.pad( + pt_input_ids[:, 1:], (0, 1), mode='constant', value=-100) + + trajectories.pretrain_data = { + 'input_ids': pt_input_ids, + 'labels': pretrain_labels, + 'attention_mask': pt_attention_mask + } + logger.info(f'[TxtEnv & {policy_model.__class__.__name__}] \ + gets {pt_input_ids.shape} pretrain data.') else: trajectories.pretrain_data = None return trajectories - # default get_reward() is blocking. 
+ # get_reward_async() needs to call get_reward_collect() + def get_reward_async(self, prompt_datas, policyout): rm_input_messages = [] - for i in range(len(sample_data)): - if sample_data[i].mes_type != "prompt": + for i in range(len(prompt_datas)): + if prompt_datas[i].mes_type != 'prompt': continue - if sample_data[i].rm_meta != 'default': - cur_rm_data = [{ - 'role': 'system', - 'content': META_PROMPT[sample_data[i].rm_meta] - }] + sample_data[i].message + [{ - 'role': - 'assistant', - 'content': - policyout.output_ans_str[i] - }] + if prompt_datas[i].rm_prompt != 'default': + # Conditional Reward Model + # for queries from different domains, use appropriate conditional system prompts # noqa: E501 + # From Alignment section of the InternLM2 Technical Report: + # https://arxiv.org/pdf/2403.17297 + cur_rm_data = [ + dict( + role='system', + content=SYSTEM_PROMPT[prompt_datas[i].rm_prompt]) + ] + prompt_datas[i].message + [ + dict( + role='assistant', content=policyout.output_ans_str[i]) + ] else: - cur_rm_data = sample_data[i].message + [{ - 'role': - 'assistant', - 'content': - policyout.output_ans_str[i] - }] + cur_rm_data = prompt_datas[i].message + [ + dict( + role='assistant', content=policyout.output_ans_str[i]) + ] rm_input_messages.append(cur_rm_data) - print(f'[For Reward]: {rm_input_messages[0]}') - reward_output_ref = self.reward_function.infer_async( - rm_input_messages, - output_logprobs=False, - micro_batch_size=self.reward_micro_bs) - logger.info( - f'[reward infer] async duration: {round(time.time() - s_t, 2)} s') + logger.info(f'[For Reward]: {rm_input_messages[0]}') + with Timer('reward_model.infer_async'): + reward_output_ref = self.reward_function.infer_async( + rm_input_messages, + output_logprobs=False, + micro_batch_size=self.reward_micro_bs) return reward_output_ref def get_reward_collect(self, reward_output_ref): - s_t = time.time() - rm_out = self.reward_function.infer_get(reward_output_ref) - logger.info( - f'[reward infer] async wait duration: {round(time.time() - s_t, 2)} s' # noqa: E501 - ) + with Timer('reward_model.infer_get'): + rm_out = self.reward_function.infer_get(reward_output_ref) rewards = rm_out.logits.squeeze(-1) return rewards - def get_reward(self, sample_data, policyout): - s_t = time.time() + def get_reward(self, prompt_datas, policyout): rm_input_messages = [] - for i in range(len(sample_data)): - if sample_data[i].mes_type != "prompt": + for i in range(len(prompt_datas)): + if prompt_datas[i].mes_type != 'prompt': continue - if sample_data[i].rm_meta != 'default': - cur_rm_data = [{ - 'role': 'system', - 'content': META_PROMPT[sample_data[i].rm_meta] - }] + sample_data[i].message + [{ - 'role': - 'assistant', - 'content': - policyout.output_ans_str[i] - }] + if prompt_datas[i].rm_prompt != 'default': + cur_rm_data = [ + dict( + role='system', + content=SYSTEM_PROMPT[prompt_datas[i].rm_prompt]) + ] + prompt_datas[i].message + [ + dict( + role='assistant', content=policyout.output_ans_str[i]) + ] else: - cur_rm_data = sample_data[i].message + [{ - 'role': - 'assistant', - 'content': - policyout.output_ans_str[i] - }] + cur_rm_data = prompt_datas[i].message + [ + dict( + role='assistant', content=policyout.output_ans_str[i]) + ] rm_input_messages.append(cur_rm_data) - print(f'[For Reward]: {rm_input_messages[0]}') - rm_out = self.reward_function.infer( - rm_input_messages, - output_logprobs=False, - micro_batch_size=self.reward_micro_bs) - logger.info( - f'[reward infer] duration: {round(time.time() - s_t, 2)} s') + logger.info(f'[For Reward]: 
{rm_input_messages[0]}') + with Timer('reward_model.infer'): + rm_out = self.reward_function.infer( + rm_input_messages, + output_logprobs=False, + micro_batch_size=self.reward_micro_bs) rewards = rm_out.logits.squeeze(-1) return rewards diff --git a/xtuner/rlhf/loss/__init__.py b/xtuner/rlhf/loss/__init__.py index e69de29bb..712703598 100644 --- a/xtuner/rlhf/loss/__init__.py +++ b/xtuner/rlhf/loss/__init__.py @@ -0,0 +1,4 @@ +from .actor_loss import PPOPolicyLoss, PretrainLoss +from .critic_loss import CriticLoss + +__all__ = ['PPOPolicyLoss', 'PretrainLoss', 'CriticLoss'] diff --git a/xtuner/rlhf/loss/actor_loss.py b/xtuner/rlhf/loss/actor_loss.py index e5c05cc01..bfd7e5b68 100644 --- a/xtuner/rlhf/loss/actor_loss.py +++ b/xtuner/rlhf/loss/actor_loss.py @@ -1,57 +1,63 @@ from typing import Any import torch +from loguru import logger from ..policy_output import logprobs_from_logits -class ActorLoss(torch.nn.Module): +class PretrainLoss(torch.nn.Module): + """Loss function for flash GPT Language Model.""" + + def __init__(self, label_smoothing=0): + super().__init__() + + if label_smoothing is not None and label_smoothing != 0: + logger.warning(f'Use label_smoothing: {label_smoothing}') + self.label_smoothing = label_smoothing + + # the output will gather output is set in the model, + # so use ordinary loss + self.loss_fn = torch.nn.CrossEntropyLoss( + reduction='mean', label_smoothing=label_smoothing) + + def forward(self, *args): + if len(args) == 3: + # residual is to match prenorm + logits, _, labels = args + elif len(args) == 2: + # When using postnorm + logits, labels = args + else: + raise RuntimeError( + f'The number of criterion inputs are:{len(args)}') + shift_logits = logits.contiguous().view(-1, logits.size(-1)) + shift_labels = labels.contiguous().view(-1) + loss = self.loss_fn(shift_logits, shift_labels) + # There is no need to consider the ignore_index problem here, + # because the loss calculation will be calculated through the calculation range, # noqa: E501 + # and -100 must be outside this range, + # so there is no problem + + return loss + + +class PPOPolicyLoss(torch.nn.Module): """Loss function for actor model.""" - def __init__(self, cliprange: float = 0.2, loss_type: str = 'per_seq'): + def __init__(self, cliprange: float = 0.2): super().__init__() self.cliprange = cliprange - self.loss_type = loss_type - assert self.loss_type in ['per_token', 'per_seq'] - def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask, - loss_factor): + def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask): ratio = (logprobs - old_logprobs).exp() pg_loss1 = -ratio * advantages pg_loss2 = -ratio.clamp(1 - self.cliprange, 1 + self.cliprange) * advantages - if self.loss_type == 'per_seq': - pg_loss = (torch.max(pg_loss1, pg_loss2) * mask).sum() / mask.sum() - elif self.loss_type == 'per_token': - pg_loss = torch.sum( - torch.max(pg_loss1, pg_loss2) * mask) * loss_factor - else: - raise RuntimeError( - f"ActorLoss.loss_type must be ['per_seq', 'per_token'], got {self.loss_type}" # noqa: E501 - ) + pg_loss = (torch.max(pg_loss1, pg_loss2) * mask).sum() / mask.sum() return pg_loss.mean() def forward(self, logits: torch.Tensor, labels: dict[str, Any]): - """Forward function of ActorLoss. - - Args: - logits (Tensor): Forward result of the model. Its shape may be varied. 
# noqa: E501 - For packed forward: (micro_bsz * seqlen, 1), where micro_bsz = 1 # noqa: E501 - For non packed forward: (micro_bsz, seqlen, 1) - - labels (tuple[dict]): Label values which are split by pipeline - schedule into pieces. The length of the list is micro_bsz. Each - element is a dict, representing labels to a batch. - - Note: - The parameter `labels` seems strange because of pj-colossalai's - pipeline schedule mechanism. Labels are delivered to colosslai.Engine # noqa: E501 - in List format, so pipeline schedule split it into micro_bsz pieces, # noqa: E501 - and deliver them to loss_fn by `*args`. - - Returns: - Tensor: Return the final loss - """ assert logits.ndim == 3 mask = labels['mask'] @@ -59,7 +65,6 @@ def forward(self, logits: torch.Tensor, labels: dict[str, Any]): input_ids = labels['input_ids'] old_logprobs = labels['old_logprobs'] advantages = labels['advantages'] - loss_factor = labels['loss_factor'] logpy = logprobs_from_logits( logits=logits[:, :-1, :], labels=input_ids[:, 1:], gather=True) @@ -70,7 +75,5 @@ def forward(self, logits: torch.Tensor, labels: dict[str, Any]): logprobs=logprobs, old_logprobs=old_logprobs, advantages=advantages, - mask=mask, - loss_factor=loss_factor, - ) + mask=mask) return loss diff --git a/xtuner/rlhf/loss/critic_loss.py b/xtuner/rlhf/loss/critic_loss.py index 3ad4e2db6..f043cfe27 100644 --- a/xtuner/rlhf/loss/critic_loss.py +++ b/xtuner/rlhf/loss/critic_loss.py @@ -6,52 +6,19 @@ class CriticLoss(torch.nn.Module): """Loss function for critic model.""" - def __init__(self, - cliprange_value: float = 0.5, - loss_type: str = 'per_seq'): + def __init__(self, cliprange_value: float = 0.5): super().__init__() self.cliprange_value = cliprange_value - self.loss_type = loss_type - assert self.loss_type in ['per_token', 'per_seq'] - def critic_loss_fn(self, values, old_values, returns, mask, loss_factor): + def critic_loss_fn(self, values, old_values, returns, mask): values_clipped = old_values + (values - old_values).clamp( -self.cliprange_value, self.cliprange_value) vf_loss1 = (values_clipped - returns)**2 vf_loss2 = (values - returns)**2 - - if self.loss_type == 'per_seq': - vf_loss = (torch.max(vf_loss1, vf_loss2) * mask).sum() / mask.sum() - elif self.loss_type == 'per_token': - vf_loss = torch.sum( - torch.max(vf_loss1, vf_loss2) * mask * loss_factor) - else: - raise RuntimeError( - f"CriticLoss.loss_type must be ['per_seq', 'per_token'], got {self.loss_type}" # noqa: E501 - ) + vf_loss = (torch.max(vf_loss1, vf_loss2) * mask).sum() / mask.sum() return 0.5 * vf_loss.mean() def forward(self, values: torch.Tensor, labels: dict[str, Any]): - """Forward function of CriticLoss. - - Args: - values (Tensor): Forward result of the model. Its shape may be varied. # noqa: E501 - For packed forward: (micro_bsz * seqlen, 1), where micro_bsz = 1 # noqa: E501 - For non packed forward: (micro_bsz, seqlen, 1) - - labels (Tuple[dict]): Label values which are split by pipeline - schedule into pieces. The length of the list is micro_bsz. Each - element is a dict, representing labels to a batch. - - Note: - The parameter `labels` seems strange because of pj-colossalai's - pipeline schedule mechanism. Labels are delivered to colosslai.Engine # noqa: E501 - in List format, so pipeline schedule split it into micro_bsz pieces, # noqa: E501 - and deliver them to loss_fn by `*args`. 
- - Returns: - Tensor: Return the final loss - """ assert values.ndim == 2 mask = labels['mask'] num_actions = mask.size(1) @@ -59,12 +26,6 @@ def forward(self, values: torch.Tensor, labels: dict[str, Any]): old_values = labels['old_values'] returns = labels['returns'] - loss_factor = labels['loss_factor'] loss = self.critic_loss_fn( - values=values, - old_values=old_values, - returns=returns, - mask=mask, - loss_factor=loss_factor, - ) + values=values, old_values=old_values, returns=returns, mask=mask) return loss diff --git a/xtuner/rlhf/loss/pretrain_loss.py b/xtuner/rlhf/loss/pretrain_loss.py deleted file mode 100644 index 6356291d0..000000000 --- a/xtuner/rlhf/loss/pretrain_loss.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -from loguru import logger - - -class PretrainLoss(torch.nn.Module): - """Loss function for flash GPT Language Model.""" - - def __init__(self, label_smoothing=0): - super().__init__() - - if label_smoothing is not None and label_smoothing != 0: - logger.warning(f'Use label_smoothing: {label_smoothing}') - self.label_smoothing = label_smoothing - - # Here, the output will gather output is set in the model, so use ordinary loss # noqa: E501 - self.loss_fn = torch.nn.CrossEntropyLoss( - reduction='mean', label_smoothing=label_smoothing) - - def forward(self, *args): - if len(args) == 3: - # residual is to match prenorm - logits, _, labels = args - elif len(args) == 2: - # When using postnorm - logits, labels = args - else: - raise RuntimeError( - f'The number of criterion inputs are:{len(args)}') - shift_logits = logits.contiguous().view(-1, logits.size(-1)) - shift_labels = labels.contiguous().view(-1) - loss = self.loss_fn(shift_logits, shift_labels) - # There is no need to consider the ignore_index problem here, because the loss calculation will be # noqa: E501 - # calculated through the calculation range, and -100 must be outside this range, so there is no problem # noqa: E501 - - return loss - diff --git a/xtuner/rlhf/main.py b/xtuner/rlhf/main.py index ea64cc36a..4576b396d 100644 --- a/xtuner/rlhf/main.py +++ b/xtuner/rlhf/main.py @@ -3,16 +3,15 @@ import os import time -import numpy as np from loguru import logger from xtuner.rlhf.config.config import Config from xtuner.rlhf.coordinator import Coordinator -from xtuner.rlhf.dataset.txt_loader import TxtMessageDataset -from xtuner.rlhf.envs.txt_env import TxtEnv -from xtuner.rlhf.repeaters.base import BaseRepeater -from xtuner.rlhf.tokenizer.tokenizer_utils import get_tokenizer -from xtuner.rlhf.trainer.ppo import PPOTrainer +from xtuner.rlhf.dataset import MessageIter +from xtuner.rlhf.envs import TxtEnv +from xtuner.rlhf.repeaters import BaseRepeater +from xtuner.rlhf.timer import Timer +from xtuner.rlhf.trainer import PPOTrainer def parse_args(): @@ -68,19 +67,6 @@ def validate_config(config: Config): logger.info(f'{k}: {v}') logger.info('#################### CONFIG END ####################') - # init dataset - model_path = config['model_configs']['actor']['model_path'] - tokenizer_config = config.get('tokenizer_config', {}) - for model_type in config['model_configs'].keys(): - if 'tokenizer_config' not in config['model_configs'][model_type]: - config['model_configs'][model_type][ - 'tokenizer_config'] = tokenizer_config - tokenizer = get_tokenizer( - model_path, trust_remote_code=True, **tokenizer_config) - dataset_config = config['dataset_config'] - dataset_config['tokenizer'] = tokenizer - txt_loader = TxtMessageDataset(**dataset_config) - # init model cluster_address = args.address if cluster_address != 
'auto': @@ -88,59 +74,69 @@ def validate_config(config: Config): logger.info(f'cluster_address={cluster_address}') coordinator = Coordinator(cluster_address, config['model_configs']) model_dict = coordinator.create_models() - sft_model = model_dict['reference'] + ref_model = model_dict['reference'] actor_model = model_dict['actor'] reward_model = model_dict['reward'] critic_model = model_dict['critic'] - # init txt env + # init prompt & pretrain dataset + prompt_dataset_config = config['prompt_dataset_config'] + prompt_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, **prompt_dataset_config) + pretrain_dataset_config = config['pretrain_dataset_config'] + pretrain_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, **pretrain_dataset_config) + # init txt env rollout_config = config.get('rollout_config', {}) txt_env = TxtEnv( - dataloader=txt_loader, + prompt_mes_iter=prompt_mes_iter, + pretrain_mes_iter=pretrain_mes_iter, reward_function=reward_model, **rollout_config, ) # init repeater repeater_config = config.get('repeater_config', {}) rl_repeater = BaseRepeater( - sft_model=sft_model, + ref_model=ref_model, **repeater_config, ) # init trainer train_config = config.get('train_config', {}) ppo = PPOTrainer( - policy_model=actor_model, value_model=None, **train_config) - pretrain_step = train_config['pretrain_step'] + policy_model=actor_model, critic_model=None, **train_config) + critic_warmup_step = train_config['critic_warmup_step'] save_interval = train_config['save_interval'] - np.set_printoptions(threshold=np.inf) - step = 1 - while True: + max_train_step = train_config.get('max_train_step', float('inf')) + + step = 0 + while step <= max_train_step: s_t = time.time() - trajectories = txt_env.rollout(policy_model=actor_model) - # deal with trajectories - trajectories = rl_repeater.process( - trajectories, - policy_model=actor_model, - value_model=critic_model, - sft_model=None, - env=txt_env) - - # # for value & policy learn - value_loss_ref = ppo.value_learn_async(trajectories, critic_model) - - ppo_loss, pt_loss = None, None - if pretrain_step <= 0: - ppo_loss, pt_loss = ppo.policy_learn(trajectories, actor_model) - logger_train.info( - f'[Policy Train] Step: {step}, ppo loss: {ppo_loss}, pretrain loss: {pt_loss}' # noqa: E501 - ) - - value_loss = ppo.value_learn_get(value_loss_ref, critic_model) + with Timer(f'step {step}: end_to_end'): + trajectories = txt_env.rollout(policy_model=actor_model) + # deal with trajectories + trajectories = rl_repeater.process( + trajectories, + policy_model=actor_model, + critic_model=critic_model, + ref_model=None, + env=txt_env) + + # # for critic & policy learn + critic_loss_ref = ppo.critic_learn_async(trajectories, + critic_model) + + ppo_loss, pt_loss = None, None + if critic_warmup_step <= 0: + ppo_loss, pt_loss = ppo.policy_learn(trajectories, actor_model) + logger_train.info(f'[Policy Train] Step: {step}, \ + ppo loss: {ppo_loss}, pretrain loss: {pt_loss}') + + critic_loss = ppo.critic_learn_get(critic_loss_ref, critic_model) logger_train.info( - f'[Value Train] step: {step}, value loss: {value_loss}') + f'[Critic Train] step: {step}, critic loss: {critic_loss}') logger_train.info(f'rewards: {trajectories.rewards.mean()}') - pretrain_step -= 1 + critic_warmup_step -= 1 if config['rollout_config'].get('write_to_file', True): if not os.path.exists(f'{work_dir}/rollouts'): @@ -163,7 +159,7 @@ def validate_config(config: Config): step=step, policy_loss=ppo_loss, pretrain_loss=pt_loss, - critic_loss=value_loss, + 
critic_loss=critic_loss, ) with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: f.write(json.dumps(summaries) + '\n') diff --git a/xtuner/rlhf/model_backend/generate_utils.py b/xtuner/rlhf/model_backend/generate_utils.py index 15adb5640..095dba9f9 100644 --- a/xtuner/rlhf/model_backend/generate_utils.py +++ b/xtuner/rlhf/model_backend/generate_utils.py @@ -33,7 +33,7 @@ def get_question_answer_mask( def partition_by_micro_batch_size( - input_ids: Union[list[str], torch.Tensor, list[int]], + input_ids: Union[torch.Tensor, list[int]], micro_batch_size: int, attention_mask: torch.Tensor = None, position_ids: torch.Tensor = None, @@ -123,7 +123,7 @@ def partition_list_by_micro_batch_size( micro_batches = [[{} for i in range(length)] for _ in range(num_splits)] if attention_mask is None: attention_mask = [None for _ in range(length)] - if position_ids == None: + if position_ids is None: position_ids = [None for _ in range(length)] for i in range(length): sub_input_ids = input_ids[i] diff --git a/xtuner/rlhf/model_backend/hf_model_runner.py b/xtuner/rlhf/model_backend/hf_model_runner.py index b873a3497..15861cb2d 100644 --- a/xtuner/rlhf/model_backend/hf_model_runner.py +++ b/xtuner/rlhf/model_backend/hf_model_runner.py @@ -25,7 +25,7 @@ from ..utils import set_seed from .dist_utils import init_process_group from .generate_utils import (get_answer_str, get_question_answer_mask, - merge_loss_list, partition_by_micro_batch_size, + partition_by_micro_batch_size, partition_list_by_micro_batch_size) from .ray_actor_group import RayActorGroup from .ray_actor_mixin import RayActorMixin @@ -40,7 +40,7 @@ class HfModelRunner: - """ModelTrainer is capable of training, inference, and generation.""" + """HfModelRunner is capable of training, inference, and generation.""" def __init__(self, model_config): self.model_config: dict = model_config @@ -227,8 +227,6 @@ def train( debug=False, **_ignored, ): - return_list = True - if isinstance(input_ids, torch.Tensor): input_ids = [input_ids] labels = [labels] @@ -662,7 +660,7 @@ def __init__(self, name: str, config: dict): self.released = True num_gpus = get_gpu_requirement(config) self.dp_size = get_dp_size(config) - self.tokenizer_pad_token_id = config.tokenizer_config.pad_token_id + self.tokenizer_pad_token_id = config.tokenizer_config['pad_token_id'] bundles = [{ 'CPU': DEFAULT_NUM_CPUS, 'GPU': DEFAULT_NUM_GPUS diff --git a/xtuner/rlhf/model_backend/vllm_model_runner.py b/xtuner/rlhf/model_backend/vllm_model_runner.py index c04677a1f..6acda480b 100644 --- a/xtuner/rlhf/model_backend/vllm_model_runner.py +++ b/xtuner/rlhf/model_backend/vllm_model_runner.py @@ -244,7 +244,7 @@ def __init__(self, name: str, config: dict): self.config = config self.tp_size = get_tp_size(config) # tensor parallelism self.dp_size = get_dp_size(config) # num of vllm_engines - self.tokenizer_pad_token_id = config.tokenizer_config.pad_token_id + self.tokenizer_pad_token_id = config.tokenizer_config['pad_token_id'] self.ray_actors: list[VllmGeneratorRayActor] = [] # i.e., vllm_engines # Adapted from https://github.com/OpenLLMAI/OpenRLHF/blob/v0.2.5/openrlhf/trainer/ray/vllm_engine.py # noqa: E501 @@ -319,7 +319,7 @@ def generate_async(self, input_ids, attention_mask, *args, **kwargs): def generate_get(self, object_refs, timeout=None): outputs = ray.get(object_refs, timeout=timeout) padding_token_map = { - 'output_ids': self.config.tokenizer_config.pad_token_id + 'output_ids': self.config.tokenizer_config['pad_token_id'] } return concat_policy_outputs(outputs, padding_token_map) 
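Note on the generator changes above: `generate_get` gathers `PolicyOutput`s from several data-parallel vLLM engines whose `output_ids` tensors have different sequence lengths, so each tensor is right-padded with the tokenizer's `pad_token_id` before being concatenated (see `padding_policy_outputs` in the policy_output.py hunk below). The following minimal, self-contained sketch shows just that padding-and-concatenation step; the function name, tensor shapes, and pad id here are illustrative assumptions, not code from the patch.

import torch

def pad_and_concat(output_ids_list, pad_token_id=0):
    # Each element is a (batch, seq_len) tensor from one engine,
    # each with its own seq_len.
    max_len = max(t.shape[1] for t in output_ids_list)
    padded = []
    for t in output_ids_list:
        pad_size = max_len - t.shape[1]
        # Right-pad the last (sequence) dimension with the pad token id.
        padded.append(
            torch.nn.functional.pad(t, (0, pad_size), value=pad_token_id))
    return torch.cat(padded, dim=0)  # merge along the batch dimension

# Example: two engines return batches with different sequence lengths.
a = torch.randint(1, 100, (2, 5))
b = torch.randint(1, 100, (3, 8))
merged = pad_and_concat([a, b], pad_token_id=0)
assert merged.shape == (5, 8)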
diff --git a/xtuner/rlhf/model_server/actor_model_server.py b/xtuner/rlhf/model_server/actor_model_server.py index a18cf43f4..e6f992a51 100644 --- a/xtuner/rlhf/model_server/actor_model_server.py +++ b/xtuner/rlhf/model_server/actor_model_server.py @@ -25,8 +25,8 @@ def initialize_async(self): generator_config['model_path'] = self.model_config['model_path'] generator_config['tokenizer_config'] = self.tokenizer_config - generator_config[ - 'tokenizer_path'] = self.tokenizer_config.tokenizer_path + generator_config['tokenizer_path'] = self.tokenizer_config[ + 'tokenizer_path'] generator_type = generator_config.get('generator_type', None) if generator_type == ENGINE_VLLM: from ..model_backend.vllm_model_runner import \ diff --git a/xtuner/rlhf/model_server/base_model_server.py b/xtuner/rlhf/model_server/base_model_server.py index 884482f4e..f5cc93975 100644 --- a/xtuner/rlhf/model_server/base_model_server.py +++ b/xtuner/rlhf/model_server/base_model_server.py @@ -110,8 +110,8 @@ def train_async(self, position_ids=None, *args, **train_kwargs): - return self.trainer.train_async(input_ids, labels, attention_mask, position_ids, - *args, **train_kwargs) + return self.trainer.train_async(input_ids, labels, attention_mask, + position_ids, *args, **train_kwargs) def train_get(self, object_refs, timeout: Optional[float] = None): return self.trainer.train_get(object_refs, timeout=timeout) @@ -123,8 +123,8 @@ def train(self, position_ids=None, *args, **train_kwargs): - object_refs = self.train_async(input_ids, labels, attention_mask, position_ids, - *args, **train_kwargs) + object_refs = self.train_async(input_ids, labels, attention_mask, + position_ids, *args, **train_kwargs) loss = self.train_get(object_refs) self.log_cuda_mem_stats(remark='[train] ') return loss diff --git a/xtuner/rlhf/policy_output.py b/xtuner/rlhf/policy_output.py index a7287cee6..34aa3c34c 100644 --- a/xtuner/rlhf/policy_output.py +++ b/xtuner/rlhf/policy_output.py @@ -125,17 +125,17 @@ def concat_policy_outputs(policy_outputs: list[PolicyOutput], def padding_policy_outputs(policy_outputs: list[PolicyOutput], - padding_token_map={}): - DEFAULT_PADDING_ID = 0 - RIGHT_PADDING = True + padding_token_map={}, + right_padding=True, + padding_id=0): tensor_keys = union_tensor_keys_from_policy_outputs(policy_outputs) for key in tensor_keys: - padding_id = padding_token_map.get(key, DEFAULT_PADDING_ID) + padding_id = padding_token_map.get(key, padding_id) max_seq_len = find_max_seq_len(policy_outputs, key) for policy_output in policy_outputs: origin_tensor = policy_output[key] padding_size = max_seq_len - origin_tensor.shape[1] - pad = (0, padding_size) if RIGHT_PADDING else (padding_size, 0) + pad = (0, padding_size) if right_padding else (padding_size, 0) padded_tensor = torch.nn.functional.pad( origin_tensor, pad, mode='constant', value=padding_id) policy_output[key] = padded_tensor @@ -147,8 +147,7 @@ def find_max_seq_len(policy_outputs: list[PolicyOutput], key): for policy_output in policy_outputs: if policy_output[key] is None: continue - batch_size, seq_len = policy_output[ - key].shape # assert: only support 2d tensor + batch_size, seq_len = policy_output[key].shape[:2] max_seq_len = seq_len if seq_len > max_seq_len else max_seq_len return max_seq_len @@ -171,4 +170,4 @@ def logprobs_from_logits(logits: torch.Tensor, if not gather or labels is None: return logp logpy = torch.gather(logp, -1, labels.unsqueeze(2)).squeeze(-1) - return logpy.cuda() + return logpy diff --git a/xtuner/rlhf/repeaters/__init__.py 
b/xtuner/rlhf/repeaters/__init__.py index e69de29bb..1ded00298 100644 --- a/xtuner/rlhf/repeaters/__init__.py +++ b/xtuner/rlhf/repeaters/__init__.py @@ -0,0 +1,3 @@ +from .base import BaseRepeater + +__all__ = ['BaseRepeater'] diff --git a/xtuner/rlhf/repeaters/base.py b/xtuner/rlhf/repeaters/base.py index 156fe3e02..ac8833a60 100644 --- a/xtuner/rlhf/repeaters/base.py +++ b/xtuner/rlhf/repeaters/base.py @@ -1,10 +1,8 @@ -import time - import torch -from loguru import logger from ..model_server.base_model_server import BaseModelServer from ..policy_output import PolicyOutput +from ..timer import Timer from .running_mean_std import RunningStates @@ -12,7 +10,7 @@ class BaseRepeater: def __init__( self, - sft_model, + ref_model, actor_micro_bs: int = 8, ref_micro_bs: int = 8, critic_micro_bs: int = 32, @@ -27,7 +25,7 @@ def __init__( fine_grained_rm: bool = False, **_ignored, ): - self.sft_model = sft_model + self.ref_model = ref_model self.actor_micro_bs = actor_micro_bs self.ref_micro_bs = ref_micro_bs self.critic_micro_bs = critic_micro_bs @@ -45,26 +43,27 @@ def process( self, trajectories: PolicyOutput, policy_model: BaseModelServer, - value_model: BaseModelServer, - sft_model: BaseModelServer = None, + critic_model: BaseModelServer, + ref_model: BaseModelServer = None, # only used for async reward model.infer_get() in _get_kl_rewards env=None, ): - value_output_ref = self._get_values_async(trajectories, value_model) + critic_output_ref = self._get_values_async(trajectories, critic_model) action_mask = trajectories['action_mask'] num_actions = action_mask.size(1) - if sft_model is not None: - self.sft_model: BaseModelServer = sft_model - kl_rewards, entropy, kl_distance, policy_logprobs, sft_logprobs = self._get_kl_rewards( # noqa: E501 - trajectories, policy_model, env=env) + if ref_model is not None: + self.ref_model: BaseModelServer = ref_model + (kl_rewards, entropy, kl_distance, policy_logprobs, + ref_logprobs) = self._get_kl_rewards( + trajectories, policy_model, env=env) trajectories['kl'] = (kl_distance * action_mask).sum( axis=-1) / action_mask.sum(axis=-1) trajectories['entropy'] = entropy trajectories['kl_rewards'] = kl_rewards trajectories['policy_logprobs'] = policy_logprobs - trajectories['sft_logprobs'] = sft_logprobs + trajectories['ref_logprobs'] = ref_logprobs - values = self._get_values_collect(value_output_ref, value_model) + values = self._get_values_collect(critic_output_ref, critic_model) old_values = values[:, -num_actions:] advantages, returns = self.get_advantages_and_returns( old_values, kl_rewards, action_mask) @@ -79,24 +78,24 @@ def _get_kl_rewards(self, trajectories: PolicyOutput, policy_model: BaseModelServer, env=None): - s_t = time.time() - policy_output = policy_model.infer_async( - inputs=trajectories.output_ids, - micro_batch_size=self.actor_micro_bs, - attention_mask=trajectories.attention_mask, - output_logits=False, - output_logprobs=True) - sft_output = self.sft_model.infer_async( - inputs=trajectories.output_ids, - micro_batch_size=self.ref_micro_bs, - attention_mask=trajectories.attention_mask, - output_logits=False, - output_logprobs=True) - policy_output = policy_model.infer_get(policy_output) - sft_output = self.sft_model.infer_get(sft_output) - logger.info( - f'[actor & ref infer_async] duration: {round(time.time() - s_t, 2)} s' # noqa: E501 - ) + with Timer('policy_model.infer_async'): + policy_output = policy_model.infer_async( + inputs=trajectories.output_ids, + micro_batch_size=self.actor_micro_bs, + 
attention_mask=trajectories.attention_mask, + output_logits=False, + output_logprobs=True) + with Timer('ref_model.infer_async'): + ref_output = self.ref_model.infer_async( + inputs=trajectories.output_ids, + micro_batch_size=self.ref_micro_bs, + attention_mask=trajectories.attention_mask, + output_logits=False, + output_logprobs=True) + with Timer('ref_model.infer_get'): + policy_output = policy_model.infer_get(policy_output) + with Timer('ref_model.infer_get'): + ref_output = self.ref_model.infer_get(ref_output) # Experimental if env.async_reward: @@ -111,18 +110,19 @@ def _get_kl_rewards(self, if self.norm_rewards: self.running_states.update(clipped_rewards) - norm_reward_score = (clipped_rewards - self.running_states.mean) / ( - self.running_states.var.sqrt() + 1e-8) + norm_reward_score = (clipped_rewards - + self.running_states.mean) / ( + self.running_states.var.sqrt() + 1e-8) action_mask = trajectories.action_mask num_actions = action_mask.size(1) policy_logprobs = policy_output.logprobs[:, -num_actions:] - sft_logprobs = sft_output.logprobs[:, -num_actions:] + ref_logprobs = ref_output.logprobs[:, -num_actions:] if self.kl_coeff <= 0.0: self.kl_coeff = 0.0 # compute_approx_kl - log_ratio = policy_logprobs - sft_logprobs + log_ratio = policy_logprobs - ref_logprobs kl = log_ratio * action_mask kl_reward = -self.kl_coeff * kl @@ -138,43 +138,36 @@ def _get_kl_rewards(self, entropy = -(policy_logprobs * action_mask).sum(axis=-1) / action_mask.sum(axis=-1) - return reward, entropy, kl, policy_logprobs, sft_logprobs + return reward, entropy, kl, policy_logprobs, ref_logprobs def _get_values(self, trajectories: PolicyOutput, - value_model: BaseModelServer): - s_t = time.time() - value_output = value_model.infer( - inputs=trajectories.output_ids, - attention_mask=trajectories.attention_mask, - output_logits=True, - micro_batch_size=self.critic_micro_bs, - ) - logger.info( - f'[critic infer] duration: {round(time.time() - s_t, 2)} s') - raw_values = value_output.logits.squeeze(-1) + critic_model: BaseModelServer): + with Timer('critic_model.infer'): + critic_output = critic_model.infer( + inputs=trajectories.output_ids, + attention_mask=trajectories.attention_mask, + output_logits=True, + micro_batch_size=self.critic_micro_bs, + ) + raw_values = critic_output.logits.squeeze(-1) return raw_values def _get_values_async(self, trajectories: PolicyOutput, - value_model: BaseModelServer): - s_t = time.time() - value_output_ref = value_model.infer_async( - inputs=trajectories.output_ids, - attention_mask=trajectories.attention_mask, - output_logits=True, - micro_batch_size=self.critic_micro_bs, - ) - logger.info( - f'[critic infer] async duration: {round(time.time() - s_t, 2)} s') - return value_output_ref - - def _get_values_collect(self, value_output_ref, - value_model: BaseModelServer): - s_t = time.time() - value_output = value_model.infer_get(value_output_ref) - raw_values = value_output.logits.squeeze(-1) - logger.info( - f'[critic infer] async wait duration: {round(time.time() - s_t, 2)} s' # noqa: E501 - ) + critic_model: BaseModelServer): + with Timer('critic_model.infer_async'): + critic_output_ref = critic_model.infer_async( + inputs=trajectories.output_ids, + attention_mask=trajectories.attention_mask, + output_logits=True, + micro_batch_size=self.critic_micro_bs, + ) + return critic_output_ref + + def _get_values_collect(self, critic_output_ref, + critic_model: BaseModelServer): + with Timer('critic_model.infer_get'): + critic_output = critic_model.infer_get(critic_output_ref) + 
raw_values = critic_output.logits.squeeze(-1) return raw_values def get_advantages_and_returns( @@ -184,9 +177,10 @@ def get_advantages_and_returns( action_mask: torch.Tensor, ): # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134 # noqa: E501 - """Function that computes advantages and returns from rewards and values. - Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347 - Note that rewards may include a KL divergence loss term. + """Function that computes advantages and returns from rewards and + values. Calculated as in the original PPO paper: + https://arxiv.org/abs/1707.06347 Note that rewards may include a KL + divergence loss term. Advantages looks like this: Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... @@ -195,12 +189,6 @@ def get_advantages_and_returns( Returns looks like this: Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... - - Args: - values: Tensor of shape (batch_size, response_size) - rewards: Tensor of shape (batch_size, response_size) - response_length: Length of the response sequence - use_whitening: Whether to use whitening (ie. normalize advantages) or not """ lastgaelam = 0 advantages_reversed = [] @@ -212,8 +200,10 @@ def get_advantages_and_returns( for t in reversed(range(response_length)): nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0 - # Since old_rewards and old_values are masked with action_mask, i.e. they have - # 0's at pad tokens, delta will be 0 if current t is at a pad token, so will lastgaelam + # Since old_rewards and old_values are masked with action_mask, + # i.e. they have 0's at pad tokens, + # delta will be 0 if current t is at a pad token, + # so will lastgaelam delta = rewards[:, t] + self.gamma * nextvalues - values[:, t] lastgaelam = delta + self.gamma * self.gae_lambda * lastgaelam advantages_reversed.append(lastgaelam) diff --git a/xtuner/rlhf/tokenizer/tokenizer_utils.py b/xtuner/rlhf/tokenizer/tokenizer_utils.py index d782c6eb0..a21e615b9 100644 --- a/xtuner/rlhf/tokenizer/tokenizer_utils.py +++ b/xtuner/rlhf/tokenizer/tokenizer_utils.py @@ -1,12 +1,9 @@ from typing import Optional, Union +from loguru import logger from transformers import (AutoTokenizer, LlamaTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) -from ..logger import init_logger - -logger = init_logger(__name__) - PADDING_SIDE = 'left' diff --git a/xtuner/rlhf/trainer/__init__.py b/xtuner/rlhf/trainer/__init__.py index e69de29bb..855182fb3 100644 --- a/xtuner/rlhf/trainer/__init__.py +++ b/xtuner/rlhf/trainer/__init__.py @@ -0,0 +1,3 @@ +from .ppo import PPOTrainer + +__all__ = ['PPOTrainer'] diff --git a/xtuner/rlhf/trainer/ppo.py b/xtuner/rlhf/trainer/ppo.py index ccfbeb61f..28564de8d 100644 --- a/xtuner/rlhf/trainer/ppo.py +++ b/xtuner/rlhf/trainer/ppo.py @@ -1,11 +1,6 @@ -import time - -import torch from loguru import logger -from ..loss.actor_loss import ActorLoss -from ..loss.critic_loss import CriticLoss -from ..loss.pretrain_loss import PretrainLoss +from ..loss import CriticLoss, PPOPolicyLoss, PretrainLoss from ..model_server.base_model_server import BaseModelServer from ..timer import Timer @@ -13,19 +8,19 @@ class PPOTrainer: def __init__( - self, - actor_micro_bs=2, - critic_micro_bs=2, - policy_learn_time=1, - value_learn_time=1, - policy_minibatch=None, - value_minibatch=None, - ppo_loss_weight=1.0, - pretrain_loss_weight=0.5, - pretrain_criterion=PretrainLoss(label_smoothing=0), - policy_criterion=ActorLoss(cliprange=0.2, 
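# A small numeric check (not from this diff) of the GAE recursion implemented
# in get_advantages_and_returns() above, for one 3-token response with every
# position valid; the gamma/lambda values are illustrative:
import torch

gamma, lam = 1.0, 0.99
values = torch.tensor([[0.5, 0.4, 0.3]])    # V_t from the critic
rewards = torch.tensor([[0.0, 0.0, 1.0]])   # KL-shaped rewards, scalar at EOS

lastgaelam = 0.0
advantages_reversed = []
for t in reversed(range(3)):
    nextvalues = values[:, t + 1] if t < 2 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]
    lastgaelam = delta + gamma * lam * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)
returns = advantages + values
# advantages ~= [[0.4871, 0.5930, 0.7000]] and returns = advantages + values,
# matching the closed form quoted in the docstring.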
loss_type='per_seq'), - value_criterion=CriticLoss(cliprange_value=0.5, loss_type='per_seq'), - **kwargs, + self, + actor_micro_bs=2, + critic_micro_bs=2, + policy_learn_time=1, + critic_learn_time=1, + policy_minibatch=None, + critic_minibatch=None, + ppo_loss_weight=1.0, + pretrain_loss_weight=0.5, + pretrain_criterion=PretrainLoss(label_smoothing=0), + policy_criterion=PPOPolicyLoss(cliprange=0.2), + critic_criterion=CriticLoss(cliprange_value=0.5), + **kwargs, ): self.actor_micro_bs = actor_micro_bs @@ -34,15 +29,15 @@ def __init__( self.policy_learn_time = policy_learn_time self.policy_minibatch = policy_minibatch - # value - self.value_learn_time = value_learn_time - self.value_minibatch = value_minibatch + # critic + self.critic_learn_time = critic_learn_time + self.critic_minibatch = critic_minibatch self.ppo_loss_weight = ppo_loss_weight self.pretrain_loss_weight = pretrain_loss_weight self.pretrain_criterion = pretrain_criterion self.policy_criterion = policy_criterion - self.value_criterion = value_criterion + self.critic_criterion = critic_criterion def policy_learn(self, trajectories, policy_model: BaseModelServer): if self.policy_minibatch is None: @@ -79,7 +74,6 @@ def policy_learn(self, trajectories, policy_model: BaseModelServer): trajectories.output_ids[begin:end, :] ) == self.policy_minibatch, '[Policy learn] make sure len(policy_batch_inputs) == self.policy_minibatch' # noqa: E501 - loss_factor = 1.0 train_lables = [ dict( input_ids=trajectories.output_ids[begin:end, :], @@ -87,14 +81,12 @@ def policy_learn(self, trajectories, policy_model: BaseModelServer): begin:end, :], advantages=trajectories.advantages[begin:end, :], mask=trajectories.action_mask[begin:end, :], - loss_factor=torch.tensor(loss_factor), ), ] # pretrain data if trajectories.pretrain_data is not None: - logger.info( - f'[Policy Train] policy train with pretrain data {trajectories.pretrain_data["input_ids"].shape}' - ) + logger.info(f'[Policy Train] pretrain data \ + {trajectories.pretrain_data["input_ids"].shape}') train_input_ids.append( trajectories.pretrain_data['input_ids']) train_lables.append(trajectories.pretrain_data['labels']) @@ -105,90 +97,87 @@ def policy_learn(self, trajectories, policy_model: BaseModelServer): loss_weights.append(self.pretrain_loss_weight) micro_batch_size.append(self.actor_micro_bs) - s_t = time.time() - p_loss = policy_model.train( - input_ids=train_input_ids, - labels=train_lables, - attention_mask=train_attention_mask, - # position_ids=train_position_ids, - criterion=train_criterion, - loss_weights=loss_weights, - micro_batch_size=micro_batch_size) + with Timer('policy_model.train'): + p_loss = policy_model.train( + input_ids=train_input_ids, + labels=train_lables, + attention_mask=train_attention_mask, + # position_ids=train_position_ids, + criterion=train_criterion, + loss_weights=loss_weights, + micro_batch_size=micro_batch_size) if isinstance(p_loss, list): ppo_loss.append(p_loss[0].item()) pretrain_loss.append(p_loss[1].item()) logger.info( - f'[Policy Train] duration: {round(time.time() - s_t, 2)} s, prompt data: {train_input_ids[0].shape}, ppo loss: {p_loss[0].item()}; pretrain data: {train_input_ids[1].shape}, pretrain loss: {p_loss[1].item()}' + f'[Policy Train] prompt data: {train_input_ids[0].shape}, ppo loss: {p_loss[0].item()}; pretrain data: {train_input_ids[1].shape}, pretrain loss: {p_loss[1].item()}' # noqa: E501 ) else: ppo_loss.append(p_loss.item()) logger.info( - f'[Policy Train] duration: {round(time.time() - s_t, 2)} s, prompt data: 
{train_input_ids[0].shape}, ppo loss: {p_loss.item()}' + f'[Policy Train] prompt data: {train_input_ids[0].shape}, ppo loss: {p_loss.item()}' # noqa: E501 ) with Timer('policy_model.sync_model'): policy_model.sync_model() return ppo_loss, pretrain_loss - def value_learn_async(self, trajectories, value_model: BaseModelServer): - if self.value_minibatch is None: - self.value_minibatch = len(trajectories.output_ids) - value_updates = len(trajectories.output_ids) // self.value_minibatch - value_loss = [] - assert value_updates == 1 and self.policy_learn_time == 1, f'value_updates={value_updates} * self.policy_learn_time={self.policy_learn_time} > 1' # noqa: E501 - s_t = time.time() - value_batch_inputs, labels = self._value_learn_prepare( - 0, 0, trajectories, value_updates) - v_loss_ref = value_model.train_async( - input_ids=value_batch_inputs['input_ids'], - labels=labels, - attention_mask=value_batch_inputs['attention_mask'], - criterion=self.value_criterion, - micro_batch_size=self.critic_micro_bs, - ) - logger.info( - f'[critic train] async duration: {round(time.time() - s_t, 2)} s, {self.value_minibatch} batch' # noqa: E501 - ) - value_loss.append(v_loss_ref) - return value_loss - - def value_learn_get(self, value_loss_ref, value_model: BaseModelServer): - with Timer('value_model.train_get'): + def critic_learn_async(self, trajectories, critic_model: BaseModelServer): + if self.critic_minibatch is None: + self.critic_minibatch = len(trajectories.output_ids) + critic_updates = len(trajectories.output_ids) // self.critic_minibatch + critic_loss = [] + assert critic_updates == 1 and self.policy_learn_time == 1, f'critic_updates={critic_updates} * self.policy_learn_time={self.policy_learn_time} > 1' # noqa: E501 + with Timer('critic_model.train_async'): + critic_batch_inputs, labels = self._critic_learn_prepare( + 0, 0, trajectories, critic_updates) + v_loss_ref = critic_model.train_async( + input_ids=critic_batch_inputs['input_ids'], + labels=labels, + attention_mask=critic_batch_inputs['attention_mask'], + criterion=self.critic_criterion, + micro_batch_size=self.critic_micro_bs, + ) + logger.info(f'[critic train] {self.critic_minibatch} batch') + critic_loss.append(v_loss_ref) + return critic_loss + + def critic_learn_get(self, critic_loss_ref, critic_model: BaseModelServer): + with Timer('critic_model.train_get'): return [ - value_model.train_get(ref).item() for ref in value_loss_ref + critic_model.train_get(ref).item() for ref in critic_loss_ref ] - def value_learn(self, trajectories, value_model: BaseModelServer): - if self.value_minibatch is None: - self.value_minibatch = len(trajectories.output_ids) - value_updates = len(trajectories.output_ids) // self.value_minibatch - value_loss = [] + def critic_learn(self, trajectories, critic_model: BaseModelServer): + if self.critic_minibatch is None: + self.critic_minibatch = len(trajectories.output_ids) + critic_updates = len(trajectories.output_ids) // self.critic_minibatch + critic_loss = [] for learn_i in range(self.policy_learn_time): - for step_i in range(value_updates): - s_t = time.time() - value_batch_inputs, labels = self._value_learn_prepare( - step_i, learn_i, trajectories, value_updates) - v_loss = value_model.train( - input_ids=value_batch_inputs['input_ids'], - labels=labels, - attention_mask=value_batch_inputs['attention_mask'], - criterion=self.value_criterion, - micro_batch_size=self.critic_micro_bs, - ) - logger.info( - f'[critic train] duration: {round(time.time() - s_t, 2)} s, {self.value_minibatch} batch,value loss: 
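# critic_learn_async()/critic_learn_get() above form a fire-and-collect pair:
# submit the critic update, run the policy update while it is in flight, then
# wait for the critic loss. A sketch of the intended driver code (the model
# server objects are assumed to come from the coordinator):
def train_one_step(ppo, trajectories, policy_model, critic_model):
    critic_loss_ref = ppo.critic_learn_async(trajectories, critic_model)
    # The policy update overlaps the in-flight critic update.
    ppo_loss, pretrain_loss = ppo.policy_learn(trajectories, policy_model)
    critic_loss = ppo.critic_learn_get(critic_loss_ref, critic_model)
    return ppo_loss, pretrain_loss, critic_loss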
{v_loss.item()}' # noqa: E501 - ) - value_loss.append(v_loss.item()) - return value_loss - - def _value_learn_prepare(self, step_i, learn_i, trajectories, - value_updates): - logger.info('[Value Train] start value trains {}/{} | {}'.format( - step_i + 1, value_updates, learn_i + 1)) - begin = step_i * self.value_minibatch - end = begin + self.value_minibatch - value_batch_inputs = { + for step_i in range(critic_updates): + with Timer('critic_model.train'): + critic_batch_inputs, labels = self._critic_learn_prepare( + step_i, learn_i, trajectories, critic_updates) + v_loss = critic_model.train( + input_ids=critic_batch_inputs['input_ids'], + labels=labels, + attention_mask=critic_batch_inputs['attention_mask'], + criterion=self.critic_criterion, + micro_batch_size=self.critic_micro_bs, + ) + logger.info(f'[Critic train] {self.critic_minibatch} batch, \ + critic loss: {v_loss.item()}') + critic_loss.append(v_loss.item()) + return critic_loss + + def _critic_learn_prepare(self, step_i, learn_i, trajectories, + critic_updates): + logger.info('[Critic Train] start critic trains {}/{} | {}'.format( + step_i + 1, critic_updates, learn_i + 1)) + begin = step_i * self.critic_minibatch + end = begin + self.critic_minibatch + critic_batch_inputs = { 'input_ids': trajectories.output_ids[begin:end, :], 'old_values': trajectories.old_values[begin:end, :], 'returns': trajectories.returns[begin:end, :], @@ -196,14 +185,12 @@ def _value_learn_prepare(self, step_i, learn_i, trajectories, 'attention_mask': trajectories.attention_mask[begin:end, :] } assert len( - value_batch_inputs['input_ids'] - ) == self.value_minibatch, '[Value learn] make sure len(value_batch_inputs) == self.value_minibatch' # noqa: E501 + critic_batch_inputs['input_ids'] + ) == self.critic_minibatch, '[critic learn] make sure len(critic_batch_inputs) == self.critic_minibatch' # noqa: E501 - loss_factor = 1.0 labels = dict( - old_values=value_batch_inputs['old_values'], - returns=value_batch_inputs['returns'], - mask=value_batch_inputs['action_mask'], - loss_factor=torch.tensor(loss_factor), + old_values=critic_batch_inputs['old_values'], + returns=critic_batch_inputs['returns'], + mask=critic_batch_inputs['action_mask'], ) - return value_batch_inputs, labels + return critic_batch_inputs, labels From 6adc32cfd8591cfb0d37ebb37f7a35ad71f8ff71 Mon Sep 17 00:00:00 2001 From: lishuaibin Date: Wed, 19 Jun 2024 15:34:50 +0800 Subject: [PATCH 04/37] rename,remove --- examples/rlhf/four_model_8gpu.py | 39 ++-- examples/rlhf/four_model_vllm_8gpu.py | 39 ++-- examples/rlhf/quick_start.md | 4 +- xtuner/rlhf/config/config_consts.py | 4 +- xtuner/rlhf/coordinator.py | 14 +- xtuner/rlhf/dataset/message_iter.py | 18 +- xtuner/rlhf/envs/base.py | 12 + xtuner/rlhf/envs/txt_env.py | 50 ++-- .../rlhf/envs/{prompt_utils.py => utils.py} | 0 xtuner/rlhf/loss/__init__.py | 2 +- .../loss/{actor_loss.py => policy_loss.py} | 6 +- xtuner/rlhf/main.py | 43 ++-- xtuner/rlhf/model_backend/hf_model_runner.py | 25 +- xtuner/rlhf/model_backend/vllm_worker_wrap.py | 2 +- xtuner/rlhf/model_server/__init__.py | 10 + xtuner/rlhf/model_server/base_model_server.py | 9 +- ...model_server.py => policy_model_server.py} | 8 +- .../rlhf/model_server/reward_model_server.py | 8 +- xtuner/rlhf/policy_output.py | 59 ++--- xtuner/rlhf/repeaters/__init__.py | 5 +- xtuner/rlhf/repeaters/base.py | 215 +----------------- xtuner/rlhf/repeaters/kl_gae.py | 208 +++++++++++++++++ .../{running_mean_std.py => utils.py} | 0 xtuner/rlhf/tokenizer/__init__.py | 3 + 
xtuner/rlhf/tokenizer/tokenizer_utils.py | 17 +- xtuner/rlhf/trainer/ppo.py | 139 ++++++----- xtuner/rlhf/utils.py | 3 +- 27 files changed, 476 insertions(+), 466 deletions(-) create mode 100644 xtuner/rlhf/envs/base.py rename xtuner/rlhf/envs/{prompt_utils.py => utils.py} (100%) rename xtuner/rlhf/loss/{actor_loss.py => policy_loss.py} (94%) rename xtuner/rlhf/model_server/{actor_model_server.py => policy_model_server.py} (94%) create mode 100644 xtuner/rlhf/repeaters/kl_gae.py rename xtuner/rlhf/repeaters/{running_mean_std.py => utils.py} (100%) diff --git a/examples/rlhf/four_model_8gpu.py b/examples/rlhf/four_model_8gpu.py index b4ebb31ce..9a105668d 100644 --- a/examples/rlhf/four_model_8gpu.py +++ b/examples/rlhf/four_model_8gpu.py @@ -9,17 +9,27 @@ PRETRAIN_BATCH_SIZE = 32 GENERATE_MICRO_BATCH_SIZE = 16 -AC_INFER_MICRO_BATCH_SIZE = 8 -REF_INFER_MICRO_BATCH_SIZE = 8 +INFER_MICRO_BATCH_SIZE = 8 TRAIN_MICRO_BATCH_SIZE = 2 ZERO_STAGE = 3 -ACTOR_DP_SIZE = 2 +POLICY_DP_SIZE = 2 CRITIC_DP_SIZE = 2 -ACTOR_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE - ) // ACTOR_DP_SIZE // TRAIN_MICRO_BATCH_SIZE +POLICY_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE + ) // POLICY_DP_SIZE // TRAIN_MICRO_BATCH_SIZE CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE # noqa: E501 +# checkout generate config +assert PROMPT_BATCH_SIZE % GENERATE_MICRO_BATCH_SIZE == 0 +assert PROMPT_BATCH_SIZE % POLICY_DP_SIZE == 0 +# checkout infer config +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * POLICY_DP_SIZE) == 0 +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 +# checkout learn config +assert (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * + POLICY_DP_SIZE) == 0 +assert (PROMPT_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 + MODEL_DTYPE = 'auto' tokenizer_config = dict( @@ -29,7 +39,7 @@ ) rollout_config = dict( - actor_micro_bs=GENERATE_MICRO_BATCH_SIZE, + policy_micro_bs=GENERATE_MICRO_BATCH_SIZE, reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, max_new_tokens=MAX_ANSWER_LEN, write_to_file=True, @@ -47,20 +57,19 @@ ) repeater_config = dict( - actor_micro_bs=AC_INFER_MICRO_BATCH_SIZE, - critic_micro_bs=AC_INFER_MICRO_BATCH_SIZE, - ref_micro_bs=REF_INFER_MICRO_BATCH_SIZE, + policy_micro_bs=INFER_MICRO_BATCH_SIZE, + critic_micro_bs=INFER_MICRO_BATCH_SIZE, + ref_micro_bs=INFER_MICRO_BATCH_SIZE, kl_coeff=0.01, gamma=1.0, gae_lambda=0.99, clip_reward_min=-5, clip_reward_max=5, - answer_end_id=92542, norm_rewards=True, ) train_config = dict( - actor_micro_bs=TRAIN_MICRO_BATCH_SIZE, + policy_micro_bs=TRAIN_MICRO_BATCH_SIZE, critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, ppo_loss_weight=1.0, pretrain_loss_weight=0.5, @@ -70,9 +79,9 @@ ) model_configs = dict( - actor=dict( + policy=dict( model_path='internlm/internlm2-chat-1_8b-sft', - model_type='actor', + model_type='policy', trainer_config=dict( torch_dtype=MODEL_DTYPE, trainer_type='huggingface', @@ -85,7 +94,7 @@ lr_decay_rate=1, ), parallel=dict( - data=dict(size=ACTOR_DP_SIZE, mode='deepspeed'), + data=dict(size=POLICY_DP_SIZE, mode='deepspeed'), tensor=dict(size=1, mode='1d'), pipeline=dict(size=1, interleaved_overlap=False), sequence=False, @@ -112,7 +121,7 @@ 'grad_accum_dtype': 'fp32' }, 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, - 'gradient_accumulation_steps': ACTOR_GRADIENT_ACC_STEP, + 'gradient_accumulation_steps': POLICY_GRADIENT_ACC_STEP, 'train_batch_size': PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE, }, ), diff --git 
a/examples/rlhf/four_model_vllm_8gpu.py b/examples/rlhf/four_model_vllm_8gpu.py index 5da931ff7..81da45d03 100644 --- a/examples/rlhf/four_model_vllm_8gpu.py +++ b/examples/rlhf/four_model_vllm_8gpu.py @@ -9,17 +9,27 @@ PRETRAIN_BATCH_SIZE = 32 GENERATE_MICRO_BATCH_SIZE = 16 -AC_INFER_MICRO_BATCH_SIZE = 8 -REF_INFER_MICRO_BATCH_SIZE = 8 +INFER_MICRO_BATCH_SIZE = 8 TRAIN_MICRO_BATCH_SIZE = 2 ZERO_STAGE = 3 -ACTOR_DP_SIZE = 2 +POLICY_DP_SIZE = 2 CRITIC_DP_SIZE = 2 -ACTOR_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE - ) // ACTOR_DP_SIZE // TRAIN_MICRO_BATCH_SIZE +POLICY_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE + ) // POLICY_DP_SIZE // TRAIN_MICRO_BATCH_SIZE CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE # noqa: E501 +# checkout generate config +assert PROMPT_BATCH_SIZE % GENERATE_MICRO_BATCH_SIZE == 0 +assert PROMPT_BATCH_SIZE % POLICY_DP_SIZE == 0 +# checkout infer config +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * POLICY_DP_SIZE) == 0 +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 +# checkout learn config +assert (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * + POLICY_DP_SIZE) == 0 +assert (PROMPT_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 + MODEL_DTYPE = 'auto' tokenizer_config = dict( @@ -29,7 +39,7 @@ ) rollout_config = dict( - actor_micro_bs=GENERATE_MICRO_BATCH_SIZE, + policy_micro_bs=GENERATE_MICRO_BATCH_SIZE, reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, max_new_tokens=MAX_ANSWER_LEN, write_to_file=True, @@ -47,20 +57,19 @@ ) repeater_config = dict( - actor_micro_bs=AC_INFER_MICRO_BATCH_SIZE, - critic_micro_bs=AC_INFER_MICRO_BATCH_SIZE, - ref_micro_bs=REF_INFER_MICRO_BATCH_SIZE, + policy_micro_bs=INFER_MICRO_BATCH_SIZE, + critic_micro_bs=INFER_MICRO_BATCH_SIZE, + ref_micro_bs=INFER_MICRO_BATCH_SIZE, kl_coeff=0.01, gamma=1.0, gae_lambda=0.99, clip_reward_min=-5, clip_reward_max=5, - answer_end_id=92542, norm_rewards=True, ) train_config = dict( - actor_micro_bs=TRAIN_MICRO_BATCH_SIZE, + policy_micro_bs=TRAIN_MICRO_BATCH_SIZE, critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, ppo_loss_weight=1.0, pretrain_loss_weight=0.5, @@ -70,9 +79,9 @@ ) model_configs = dict( - actor=dict( + policy=dict( model_path='internlm/internlm2-chat-1_8b-sft', - model_type='actor', + model_type='policy', trainer_config=dict( torch_dtype=MODEL_DTYPE, trainer_type='huggingface', @@ -85,7 +94,7 @@ lr_decay_rate=1, ), parallel=dict( - data=dict(size=ACTOR_DP_SIZE, mode='deepspeed'), + data=dict(size=POLICY_DP_SIZE, mode='deepspeed'), tensor=dict(size=1, mode='1d'), pipeline=dict(size=1, interleaved_overlap=False), sequence=False, @@ -112,7 +121,7 @@ 'grad_accum_dtype': 'fp32' }, 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, - 'gradient_accumulation_steps': ACTOR_GRADIENT_ACC_STEP, + 'gradient_accumulation_steps': POLICY_GRADIENT_ACC_STEP, 'train_batch_size': PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE, }, ), diff --git a/examples/rlhf/quick_start.md b/examples/rlhf/quick_start.md index 8cc5cb494..bbd9b9d26 100644 --- a/examples/rlhf/quick_start.md +++ b/examples/rlhf/quick_start.md @@ -33,6 +33,6 @@ pip uninstall cupy-cuda12x -y pip install cupy-cuda11x==12.1 python -m cupyx.tools.install_library --library nccl --cuda 11.x -# 启动任务 -xtuner rlhf -c examples/rlhf/four_model_vllm_8gpu.py +# 启动任务,首次启动建议添加 HF_ENDPOINT=https://hf-mirror.com 方便数据集加载 +HF_ENDPOINT=https://hf-mirror.com xtuner rlhf -c examples/rlhf/four_model_vllm_8gpu.py ``` diff --git 
a/xtuner/rlhf/config/config_consts.py b/xtuner/rlhf/config/config_consts.py index a54be13db..03c64aa43 100644 --- a/xtuner/rlhf/config/config_consts.py +++ b/xtuner/rlhf/config/config_consts.py @@ -1,7 +1,7 @@ # keywords for config files -# model type (actor, critic, reward, reference, ...) for `model_type` -MODEL_TYPE_ACTOR = 'actor' +# model type (policy, critic, reward, reference, ...) for `model_type` +MODEL_TYPE_POLICY = 'policy' MODEL_TYPE_REFERENCE = 'reference' MODEL_TYPE_REWARD = 'reward' MODEL_TYPE_CRITIC = 'critic' diff --git a/xtuner/rlhf/coordinator.py b/xtuner/rlhf/coordinator.py index b64e15ff3..60c214c44 100644 --- a/xtuner/rlhf/coordinator.py +++ b/xtuner/rlhf/coordinator.py @@ -3,14 +3,12 @@ import ray from loguru import logger -from .config.config_consts import (MODEL_TYPE_ACTOR, MODEL_TYPE_CRITIC, +from .config.config_consts import (MODEL_TYPE_CRITIC, MODEL_TYPE_POLICY, MODEL_TYPE_REFERENCE, MODEL_TYPE_REWARD) from .config.config_utils import get_resource_requirement -from .model_server.actor_model_server import ActorModelServer -from .model_server.base_model_server import BaseModelServer -from .model_server.critic_model_server import CriticModelServer -from .model_server.ref_model_server import RefModelServer -from .model_server.reward_model_server import RewardModelServer +from .model_server import (BaseModelServer, CriticModelServer, + PolicyModelServer, RefModelServer, + RewardModelServer) ROOT_PATH = Path(__file__).parents[1].resolve() @@ -60,8 +58,8 @@ def create_models(self) -> dict[str, BaseModelServer]: self.model_dict = {} for model_name, model_config in self.model_configs.items(): model_type = model_config['model_type'] - if model_type == MODEL_TYPE_ACTOR: - self.model_dict[model_name] = ActorModelServer( + if model_type == MODEL_TYPE_POLICY: + self.model_dict[model_name] = PolicyModelServer( model_name, model_config) elif model_type == MODEL_TYPE_CRITIC: self.model_dict[model_name] = CriticModelServer( diff --git a/xtuner/rlhf/dataset/message_iter.py b/xtuner/rlhf/dataset/message_iter.py index 0776e32b7..e7f918f5b 100644 --- a/xtuner/rlhf/dataset/message_iter.py +++ b/xtuner/rlhf/dataset/message_iter.py @@ -75,11 +75,11 @@ def __init__(self, self.epoch_index = 0 def _init_in_data(self): - logger.info('====== Init in data dataset ======') + logger.info(f'Init {self.message_type} in data dataset ...') self.message_dataset = MultiSourceInDataDatset( task_groups=self.message_datasets, tokenizer=self.tokenizer) - logger.info('====== Init in data sampler ======') + logger.info(f'Init {self.message_type} in data sampler ...') assert hasattr(self.message_dataset, 'all_dataset') mes_sampler = RandomSampler(self.message_dataset.all_dataset) self.mes_dataloader = iter( @@ -90,7 +90,8 @@ def _init_in_data(self): batch_size=self.samples_each_epoch)) def yield_in_data(self): - logger.info('====== yield data from in_data sampler ======') + logger.info('yielding data from ' + f'{self.message_type} in_data sampler ...') mes_sequence = [] mes_batch_messages = next(self.mes_dataloader) @@ -122,11 +123,11 @@ def yield_in_data(self): return mes_sequence def _init_in_batch(self): - logger.info('====== Init in batch dataset ======') + logger.info(f'Init {self.message_type} in batch dataset ...') self.message_dataset = MultiSourceInBatchDatset( task_groups=self.message_datasets, tokenizer=self.tokenizer) - logger.info('====== Init in batch sampler ======') + logger.info(f'Init {self.message_type} in batch sampler ...') samples_cnts = [] for task in self.message_dataset._task_group: 
task['target_num_each_epoch'] = int( @@ -141,7 +142,8 @@ def _init_in_batch(self): assert sum(samples_cnts) >= self.samples_each_epoch def yield_in_batch(self): - logger.info('====== yield data from in_batch sampler ======') + logger.info('yield data from ' + f'{self.message_type} in_batch sampler ...') mes_sequence = [] # epoch_rng only use in this epoch. @@ -200,7 +202,7 @@ def _postprocess_sequence(self, message): # TODO truncation?? logger.warning( f'[MES_ITER] {self.message_type} message {message} ' - 'is too short or long, skipped...') + 'is too short or long, skipped.') return None elif self.message_type == 'pretrain': for _ in reversed(range(len(message_data))): @@ -219,7 +221,7 @@ def _postprocess_sequence(self, message): # TODO truncation?? logger.warning( f'[MES_ITER] {self.message_type} message {message} ' - 'is too short or long, skipped...') + 'is too short or long, skipped.') return None return Message( message=new_meaasage_data, diff --git a/xtuner/rlhf/envs/base.py b/xtuner/rlhf/envs/base.py new file mode 100644 index 000000000..6f3f1e84e --- /dev/null +++ b/xtuner/rlhf/envs/base.py @@ -0,0 +1,12 @@ +class EnvBase: + """`EnvBase` is the base class of different environments. + + `env` is responsible to generate the trajectory data. + """ + + def __init__(self): + pass + + def rollout(self, *args, **kwargs): + """define rollout.""" + raise NotImplementedError diff --git a/xtuner/rlhf/envs/txt_env.py b/xtuner/rlhf/envs/txt_env.py index b385e86a3..b13cbfc7f 100644 --- a/xtuner/rlhf/envs/txt_env.py +++ b/xtuner/rlhf/envs/txt_env.py @@ -6,41 +6,40 @@ from ..model_server.base_model_server import BaseModelServer from ..timer import Timer -from .prompt_utils import SYSTEM_PROMPT +from .base import EnvBase +from .utils import SYSTEM_PROMPT -class TxtEnv(): +class TxtEnv(EnvBase): """A generic RL environment to generate textual sequences.""" def __init__( self, + policy_model: BaseModelServer, + reward_model: BaseModelServer, prompt_mes_iter: Iterable, pretrain_mes_iter: Iterable = None, max_new_tokens: int = 1024, - actor_micro_bs: int = 32, + policy_micro_bs: int = 32, reward_micro_bs: int = 32, - reward_function: BaseModelServer = None, async_reward: bool = True, generate_kwargs: dict = None, **_ignored, ): - """ - Args: - dataloader: generate rl data iteratively - reward_function: reward function that computes scalar reward - """ + self.policy_model = policy_model + self.reward_model = reward_model + self.prompt_mes_iter = iter(prompt_mes_iter) self.pretrain_mes_iter = iter( pretrain_mes_iter) if pretrain_mes_iter else None - self.reward_function: BaseModelServer = reward_function - self._cur_messagess = [] + self.max_new_tokens = max_new_tokens - self.actor_micro_bs = actor_micro_bs + self.policy_micro_bs = policy_micro_bs self.reward_micro_bs = reward_micro_bs self.async_reward = async_reward self.generate_kwargs: dict = generate_kwargs - def rollout(self, policy_model: BaseModelServer, display=False): + def rollout(self, display=True): prompt_datas = deepcopy(next(self.prompt_mes_iter)) prompt_input_messages = [] for data in prompt_datas: @@ -54,15 +53,17 @@ def rollout(self, policy_model: BaseModelServer, display=False): message = deepcopy(data.message) prompt_input_messages.append(message) # prompt data - logger.info(f'[For Generate]: {prompt_input_messages[0]}') + if display: + logger.info( + f'[TXT_ENV For Generate]: \n{prompt_input_messages[0]}') with Timer('policy_model.generate'): - trajectories = policy_model.generate( + trajectories = self.policy_model.generate( 
inputs=prompt_input_messages, - micro_batch_size=self.actor_micro_bs, + micro_batch_size=self.policy_micro_bs, step=self.max_new_tokens, output_str=True, generate_kwargs=self.generate_kwargs) - logger.info(f'[generate] len: {len(prompt_input_messages)}') + logger.info(f'[Generate] len: {len(prompt_input_messages)}') if self.async_reward: reward_output_ref = self.get_reward_async(prompt_datas, @@ -80,9 +81,9 @@ def rollout(self, policy_model: BaseModelServer, display=False): assert data.mes_type == 'pretrain' pretrain_input_messages.append(message) - from xtuner.rlhf.tokenizer import tokenizer_utils - pt_input_ids, pt_attention_mask = tokenizer_utils.encode( - pretrain_input_messages, policy_model.tokenizer) + from xtuner.rlhf.tokenizer import encode_inputs + pt_input_ids, pt_attention_mask = encode_inputs( + pretrain_input_messages, self.policy_model.tokenizer) pretrain_labels = torch.nn.functional.pad( pt_input_ids[:, 1:], (0, 1), mode='constant', value=-100) @@ -91,8 +92,7 @@ def rollout(self, policy_model: BaseModelServer, display=False): 'labels': pretrain_labels, 'attention_mask': pt_attention_mask } - logger.info(f'[TxtEnv & {policy_model.__class__.__name__}] \ - gets {pt_input_ids.shape} pretrain data.') + logger.info(f'[TxtEnv] gets {pt_input_ids.shape} pretrain data.') else: trajectories.pretrain_data = None @@ -127,7 +127,7 @@ def get_reward_async(self, prompt_datas, policyout): logger.info(f'[For Reward]: {rm_input_messages[0]}') with Timer('reward_model.infer_async'): - reward_output_ref = self.reward_function.infer_async( + reward_output_ref = self.reward_model.infer_async( rm_input_messages, output_logprobs=False, micro_batch_size=self.reward_micro_bs) @@ -135,7 +135,7 @@ def get_reward_async(self, prompt_datas, policyout): def get_reward_collect(self, reward_output_ref): with Timer('reward_model.infer_get'): - rm_out = self.reward_function.infer_get(reward_output_ref) + rm_out = self.reward_model.infer_get(reward_output_ref) rewards = rm_out.logits.squeeze(-1) return rewards @@ -162,7 +162,7 @@ def get_reward(self, prompt_datas, policyout): logger.info(f'[For Reward]: {rm_input_messages[0]}') with Timer('reward_model.infer'): - rm_out = self.reward_function.infer( + rm_out = self.reward_model.infer( rm_input_messages, output_logprobs=False, micro_batch_size=self.reward_micro_bs) diff --git a/xtuner/rlhf/envs/prompt_utils.py b/xtuner/rlhf/envs/utils.py similarity index 100% rename from xtuner/rlhf/envs/prompt_utils.py rename to xtuner/rlhf/envs/utils.py diff --git a/xtuner/rlhf/loss/__init__.py b/xtuner/rlhf/loss/__init__.py index 712703598..ed50f738d 100644 --- a/xtuner/rlhf/loss/__init__.py +++ b/xtuner/rlhf/loss/__init__.py @@ -1,4 +1,4 @@ -from .actor_loss import PPOPolicyLoss, PretrainLoss from .critic_loss import CriticLoss +from .policy_loss import PPOPolicyLoss, PretrainLoss __all__ = ['PPOPolicyLoss', 'PretrainLoss', 'CriticLoss'] diff --git a/xtuner/rlhf/loss/actor_loss.py b/xtuner/rlhf/loss/policy_loss.py similarity index 94% rename from xtuner/rlhf/loss/actor_loss.py rename to xtuner/rlhf/loss/policy_loss.py index bfd7e5b68..f09bfb76c 100644 --- a/xtuner/rlhf/loss/actor_loss.py +++ b/xtuner/rlhf/loss/policy_loss.py @@ -43,13 +43,13 @@ def forward(self, *args): class PPOPolicyLoss(torch.nn.Module): - """Loss function for actor model.""" + """Loss function for policy model.""" def __init__(self, cliprange: float = 0.2): super().__init__() self.cliprange = cliprange - def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask): + def policy_loss_fn(self, 
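# With async_reward=True the reward request is only *submitted* during
# rollout(): the trajectories carry a `reward_output_ref`, and the scores are
# fetched later through get_reward_collect() while the reward model works in
# the background (the KL/GAE repeater in this series does exactly that).
# A condensed sketch of the round trip, with the surrounding driver code
# assumed:
def rollout_with_async_reward(txt_env):
    trajectories = txt_env.rollout(display=True)
    # ...policy/ref/critic inference can run here while rewards are pending...
    rewards = txt_env.get_reward_collect(trajectories['reward_output_ref'])
    trajectories['reward_output_ref'] = None
    trajectories['rewards'] = rewards
    return trajectories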
logprobs, old_logprobs, advantages, mask): ratio = (logprobs - old_logprobs).exp() pg_loss1 = -ratio * advantages pg_loss2 = -ratio.clamp(1 - self.cliprange, @@ -71,7 +71,7 @@ def forward(self, logits: torch.Tensor, labels: dict[str, Any]): num_actions = mask.size(1) logprobs = logpy[:, -num_actions:] - loss = self.actor_loss_fn( + loss = self.policy_loss_fn( logprobs=logprobs, old_logprobs=old_logprobs, advantages=advantages, diff --git a/xtuner/rlhf/main.py b/xtuner/rlhf/main.py index 4576b396d..d7319c33c 100644 --- a/xtuner/rlhf/main.py +++ b/xtuner/rlhf/main.py @@ -9,7 +9,7 @@ from xtuner.rlhf.coordinator import Coordinator from xtuner.rlhf.dataset import MessageIter from xtuner.rlhf.envs import TxtEnv -from xtuner.rlhf.repeaters import BaseRepeater +from xtuner.rlhf.repeaters import KLGAERepeater from xtuner.rlhf.timer import Timer from xtuner.rlhf.trainer import PPOTrainer @@ -36,8 +36,8 @@ def parse_args(): def validate_config(config: Config): assert config['model_configs'] is not None - assert config['model_configs']['actor'] is not None - assert config['model_configs']['actor']['model_path'] is not None + assert config['model_configs']['policy'] is not None + assert config['model_configs']['policy']['model_path'] is not None assert config['dataset_config'] is not None assert config['rollout_config'] is not None assert config['rollout_config']['generate_kwargs'] is not None @@ -75,7 +75,7 @@ def validate_config(config: Config): coordinator = Coordinator(cluster_address, config['model_configs']) model_dict = coordinator.create_models() ref_model = model_dict['reference'] - actor_model = model_dict['actor'] + policy_model = model_dict['policy'] reward_model = model_dict['reward'] critic_model = model_dict['critic'] @@ -90,49 +90,48 @@ def validate_config(config: Config): # init txt env rollout_config = config.get('rollout_config', {}) txt_env = TxtEnv( + policy_model=policy_model, + reward_model=reward_model, prompt_mes_iter=prompt_mes_iter, pretrain_mes_iter=pretrain_mes_iter, - reward_function=reward_model, **rollout_config, ) # init repeater repeater_config = config.get('repeater_config', {}) - rl_repeater = BaseRepeater( + ppo_repeater = KLGAERepeater( ref_model=ref_model, + policy_model=policy_model, + critic_model=critic_model, + env=txt_env, **repeater_config, ) # init trainer train_config = config.get('train_config', {}) ppo = PPOTrainer( - policy_model=actor_model, critic_model=None, **train_config) + policy_model=policy_model, critic_model=critic_model, **train_config) critic_warmup_step = train_config['critic_warmup_step'] save_interval = train_config['save_interval'] max_train_step = train_config.get('max_train_step', float('inf')) - step = 0 + step = 1 while step <= max_train_step: s_t = time.time() with Timer(f'step {step}: end_to_end'): - trajectories = txt_env.rollout(policy_model=actor_model) + trajectories = txt_env.rollout(display=True) # deal with trajectories - trajectories = rl_repeater.process( - trajectories, - policy_model=actor_model, - critic_model=critic_model, - ref_model=None, - env=txt_env) + trajectories = ppo_repeater.process(trajectories) # # for critic & policy learn - critic_loss_ref = ppo.critic_learn_async(trajectories, - critic_model) + critic_loss = ppo.critic_learn(trajectories) + # critic_loss_ref = ppo.critic_learn_async(trajectories) ppo_loss, pt_loss = None, None if critic_warmup_step <= 0: - ppo_loss, pt_loss = ppo.policy_learn(trajectories, actor_model) + ppo_loss, pt_loss = ppo.policy_learn(trajectories) logger_train.info(f'[Policy 
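# Numeric intuition (not from this diff) for the clipped surrogate in
# PPOPolicyLoss above, assuming the usual max(pg_loss1, pg_loss2) combination
# with cliprange=0.2: once the ratio moves past 1 +/- cliprange in the
# direction favoured by the advantage, the clipped term dominates and the
# gradient through the ratio vanishes.
import torch

advantage = torch.tensor(1.0)
for ratio in (0.7, 1.0, 1.1, 1.5):
    r = torch.tensor(ratio)
    pg_loss1 = -r * advantage
    pg_loss2 = -r.clamp(1 - 0.2, 1 + 0.2) * advantage
    loss = torch.max(pg_loss1, pg_loss2)
    # ratio=1.5 -> loss=-1.2 (clipped);  ratio=0.7 -> loss=-0.7 (unclipped).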
Train] Step: {step}, \ ppo loss: {ppo_loss}, pretrain loss: {pt_loss}') - critic_loss = ppo.critic_learn_get(critic_loss_ref, critic_model) + # critic_loss = ppo.critic_learn_get(critic_loss_ref) logger_train.info( f'[Critic Train] step: {step}, critic loss: {critic_loss}') logger_train.info(f'rewards: {trajectories.rewards.mean()}') @@ -164,7 +163,7 @@ def validate_config(config: Config): with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: f.write(json.dumps(summaries) + '\n') - step += 1 logger_train.info(f'[end to end] duration: {time.time() - s_t} s') - if step % save_interval == 0: - actor_model.save_model(f'{work_dir}/ckpt/{step}/') + if (step % save_interval == 0) or (step == max_train_step): + policy_model.save_model(f'{work_dir}/ckpt/{step}/') + step += 1 diff --git a/xtuner/rlhf/model_backend/hf_model_runner.py b/xtuner/rlhf/model_backend/hf_model_runner.py index 15861cb2d..bb56898e4 100644 --- a/xtuner/rlhf/model_backend/hf_model_runner.py +++ b/xtuner/rlhf/model_backend/hf_model_runner.py @@ -21,7 +21,7 @@ from ..config.config_utils import get_dp_size, get_gpu_requirement from ..policy_output import (PolicyOutput, concat_policy_outputs, logprobs_from_logits) -from ..tokenizer import tokenizer_utils +from ..tokenizer import get_tokenizer from ..utils import set_seed from .dist_utils import init_process_group from .generate_utils import (get_answer_str, get_question_answer_mask, @@ -77,7 +77,7 @@ def initialize(self): # 2. Tokenizer tokenizer_path = self.model_config.get('tokenizer_path', model_path) tokenizer_config = self.model_config.get('tokenizer_config', {}) - self.tokenizer = tokenizer_utils.get_tokenizer( + self.tokenizer = get_tokenizer( tokenizer_path, trust_remote_code=True, **tokenizer_config) # 3. Trainer @@ -188,7 +188,7 @@ def compute_loss( labels = labels.to(self.device) loss = criterion(logits, labels) elif isinstance(labels, dict): - # OPT. C) Use customized loss function, see loss/actor_loss.py + # OPT. 
C) Use customized loss function, see loss/policy_loss.py logits: torch.Tensor = self.model( **batch, use_cache=False, return_dict=True).logits for k, v in labels.items(): @@ -338,7 +338,7 @@ def _infer( @torch.no_grad() def infer( self, - inputs: Union[torch.Tensor, list[dict], list[list[dict]]], + input_ids: torch.Tensor, micro_batch_size: Optional[ int] = -1, # -1: use the entire input as one batch tokenizer=None, # Only used for reward models @@ -353,12 +353,6 @@ def infer( ) -> PolicyOutput: self.info_rank0( f'[{self.model_type}] self.infer() kwargs: {infer_kwargs}') - if not isinstance(inputs, torch.Tensor): - input_ids, attention_mask = tokenizer_utils.encode( - inputs, self.tokenizer) - else: - input_ids = inputs - input_ids = input_ids.to(self.device) if attention_mask is not None: attention_mask = attention_mask.to(self.device) @@ -484,7 +478,7 @@ def _generate( @torch.no_grad() def generate( self, - inputs: Union[torch.Tensor, str, list[str]], + input_ids: torch.Tensor, micro_batch_size: Optional[ int] = -1, # -1: use the entire input as one batch attention_mask=None, @@ -500,11 +494,6 @@ def generate( ) -> PolicyOutput: self.info_rank0( f'[{self.model_type}] self.generate() kwargs: {generate_kwargs}') - if not isinstance(inputs, torch.Tensor): - input_ids, attention_mask = tokenizer_utils.encode( - inputs, self.tokenizer, add_generation_prompt=True) - else: - input_ids = inputs input_ids = input_ids.to(self.device) if attention_mask is not None: assert isinstance(attention_mask, torch.Tensor) @@ -781,7 +770,7 @@ def infer_async(self, input_ids, attention_mask, *args, **kwargs): assert len(micro_batches) == self.dp_size return [ self.ray_actors[index].infer.remote( - inputs=micro_batch['input_ids'], + input_ids=micro_batch['input_ids'], attention_mask=micro_batch['attention_mask'], *args, **kwargs, @@ -807,7 +796,7 @@ def generate_async(self, input_ids, attention_mask, *args, **kwargs): assert len(micro_batches) == self.dp_size return [ self.ray_actors[index].generate.remote( - inputs=micro_batch['input_ids'], + input_ids=micro_batch['input_ids'], attention_mask=micro_batch['attention_mask'], *args, **kwargs, diff --git a/xtuner/rlhf/model_backend/vllm_worker_wrap.py b/xtuner/rlhf/model_backend/vllm_worker_wrap.py index 843941e81..daef742e4 100644 --- a/xtuner/rlhf/model_backend/vllm_worker_wrap.py +++ b/xtuner/rlhf/model_backend/vllm_worker_wrap.py @@ -64,7 +64,7 @@ def init_process_group(self, master_address, master_port, rank_offset, f'rank={rank}, world_size={world_size}, group_name={group_name}') def update_weight(self, name, dtype, shape, empty_cache=False): - """Broadcast weight to all vllm workers from source rank 0 (actor + """Broadcast weight to all vllm workers from source rank 0 (policy model)""" if torch.distributed.get_rank() == 0: logger.debug( diff --git a/xtuner/rlhf/model_server/__init__.py b/xtuner/rlhf/model_server/__init__.py index e69de29bb..d60547baa 100644 --- a/xtuner/rlhf/model_server/__init__.py +++ b/xtuner/rlhf/model_server/__init__.py @@ -0,0 +1,10 @@ +from .base_model_server import BaseModelServer +from .critic_model_server import CriticModelServer +from .policy_model_server import PolicyModelServer +from .ref_model_server import RefModelServer +from .reward_model_server import RewardModelServer + +__all__ = [ + 'BaseModelServer', 'PolicyModelServer', 'RefModelServer', + 'CriticModelServer', 'RewardModelServer' +] diff --git a/xtuner/rlhf/model_server/base_model_server.py b/xtuner/rlhf/model_server/base_model_server.py index f5cc93975..143553d2f 
100644 --- a/xtuner/rlhf/model_server/base_model_server.py +++ b/xtuner/rlhf/model_server/base_model_server.py @@ -8,7 +8,7 @@ from ..config.config_consts import ENGINE_HUGGINGFACE, ENGINE_INTERNEVO from ..model_backend.hf_model_runner import HfModelRunnerRayActorGroup from ..model_backend.models.modeling_internlm2_p import InternLM2ForCausalLM -from ..tokenizer import tokenizer_utils +from ..tokenizer import encode_inputs, get_tokenizer DEFAULT_GET_TIMEOUT = 600.0 # 10 min @@ -37,7 +37,7 @@ def init_tokenizer_and_config(self, model_config): else: tokenizer_path = model_config['model_path'] - self.tokenizer = tokenizer_utils.get_tokenizer( + self.tokenizer = get_tokenizer( tokenizer_path, trust_remote_code=True, **tokenizer_config) tokenizer_config['tokenizer_path'] = tokenizer_path @@ -65,7 +65,7 @@ def initialize_async(self): self.init_trainer_config(self.model_config, self.tokenizer_config) trainer_type = self.trainer_config.get('trainer_type', - 'huggingface').lower() + ENGINE_HUGGINGFACE).lower() if trainer_type == ENGINE_HUGGINGFACE: self.trainer = HfModelRunnerRayActorGroup( name=f'{self.model_name}_trainer', config=self.trainer_config) @@ -83,8 +83,7 @@ def initialize_get(self): # Inference def infer_async(self, inputs, attention_mask=None, *args, **infer_kwargs): if not isinstance(inputs, torch.Tensor): - input_ids, attention_mask = tokenizer_utils.encode( - inputs, self.tokenizer) + input_ids, attention_mask = encode_inputs(inputs, self.tokenizer) else: input_ids = inputs return self.trainer.infer_async( diff --git a/xtuner/rlhf/model_server/actor_model_server.py b/xtuner/rlhf/model_server/policy_model_server.py similarity index 94% rename from xtuner/rlhf/model_server/actor_model_server.py rename to xtuner/rlhf/model_server/policy_model_server.py index e6f992a51..bbe819347 100644 --- a/xtuner/rlhf/model_server/actor_model_server.py +++ b/xtuner/rlhf/model_server/policy_model_server.py @@ -4,11 +4,11 @@ from loguru import logger from ..config.config_consts import ENGINE_VLLM -from ..tokenizer import tokenizer_utils +from ..tokenizer import encode_inputs from .base_model_server import BaseModelServer -class ActorModelServer(BaseModelServer): +class PolicyModelServer(BaseModelServer): # Initialize def initialize_async(self): super().initialize_async() @@ -56,14 +56,14 @@ def generate_async(self, input_ids = inputs elif isinstance(inputs, list): if not self.generator_eq_trainer: - input_ids, attention_mask = tokenizer_utils.encode( + input_ids, attention_mask = encode_inputs( inputs, self.tokenizer, return_tensors=None, padding=False, add_generation_prompt=True) else: - input_ids, attention_mask = tokenizer_utils.encode( + input_ids, attention_mask = encode_inputs( inputs, self.tokenizer, add_generation_prompt=True) else: raise NotImplementedError(f'unknown inputs: {inputs}') diff --git a/xtuner/rlhf/model_server/reward_model_server.py b/xtuner/rlhf/model_server/reward_model_server.py index de275c0d9..b658c1513 100644 --- a/xtuner/rlhf/model_server/reward_model_server.py +++ b/xtuner/rlhf/model_server/reward_model_server.py @@ -2,7 +2,7 @@ from transformers import AutoConfig from ..model_backend.models.critical_and_reward import get_reward_model -from ..tokenizer import tokenizer_utils +from ..tokenizer import encode_inputs from ..utils import expand_reward_token_id from .base_model_server import BaseModelServer @@ -16,6 +16,9 @@ def get_model_class(self, model_path): def init_tokenizer_and_config(self, model_config): super().init_tokenizer_and_config(self.model_config) + # specify 
`reward_token_id`` to get scalar reward of a sequence + # according to the `Rward Model` training strategy, + # which is set to `pad_token_id` by default self.reward_token_id = self.tokenizer.pad_token_id model_path = model_config['model_path'] auto_config = AutoConfig.from_pretrained( @@ -26,8 +29,7 @@ def init_tokenizer_and_config(self, model_config): # Inference def infer_async(self, inputs, attention_mask=None, *args, **infer_kwargs): if not isinstance(inputs, torch.Tensor): - input_ids, attention_mask = tokenizer_utils.encode( - inputs, self.tokenizer) + input_ids, attention_mask = encode_inputs(inputs, self.tokenizer) else: input_ids = inputs diff --git a/xtuner/rlhf/policy_output.py b/xtuner/rlhf/policy_output.py index 34aa3c34c..af7129410 100644 --- a/xtuner/rlhf/policy_output.py +++ b/xtuner/rlhf/policy_output.py @@ -20,30 +20,6 @@ class PolicyOutput(ModelOutput): question_mask: Optional[torch.Tensor] = None answer_mask: Optional[torch.Tensor] = None - def __eq__(self, other: ModelOutput): - if len(self.keys()) != len(other.keys()): - return False - for k, v in self.items(): - if k not in other: - return False - vother = other[k] - - if isinstance(v, torch.Tensor): - if not torch.equal(v, vother): - return False - elif isinstance(v, tuple): # tuple(torch.Tensor) - for i, j in zip(v, vother): - if isinstance(i, torch.Tensor): - if not torch.equal(i, j): - return False - else: - if i != j: - return False - else: - if v != vother: - return False - return True - def to(self, device): for k, v in self.items(): if isinstance(v, torch.Tensor): @@ -61,8 +37,8 @@ def union_keys_from_policy_outputs(policy_outputs: list[PolicyOutput]) -> list: all_keys = set() for po in policy_outputs: all_keys = all_keys.union(set(po.keys())) - return list( - all_keys) # e.g., return ["output_str", "output_ids", "loss", ...] + # e.g., return ["output_str", "output_ids", "loss", ...] + return list(all_keys) def union_tensor_keys_from_policy_outputs( @@ -70,19 +46,22 @@ def union_tensor_keys_from_policy_outputs( all_keys = set() for po in policy_outputs: all_keys = all_keys.union(set(po.get_tensor_keys())) - return list(all_keys) # e.g., return ["output_ids", "loss", ...] + # e.g., return ["output_ids", "loss", ...] 
+ return list(all_keys) def concat_policy_outputs(policy_outputs: list[PolicyOutput], padding_token_map: dict = None) -> PolicyOutput: if isinstance(policy_outputs, PolicyOutput): - return policy_outputs # Wrong input type + # Wrong input type + return policy_outputs elif policy_outputs is None or len(policy_outputs) == 0: return PolicyOutput(None) elif len(policy_outputs) == 1: return policy_outputs[0] - if padding_token_map is not None: # padding + # padding + if padding_token_map is not None: policy_outputs = padding_policy_outputs(policy_outputs, padding_token_map) @@ -92,15 +71,18 @@ def concat_policy_outputs(policy_outputs: list[PolicyOutput], for po in policy_outputs: value = po[key] if value is not None: - break # get the first non-empty value + # get the first non-empty value + break if value is None: - continue # skip if all values are None + # skip if all values are None + continue if isinstance(value, torch.Tensor): concated[key] = torch.cat( [po[key] for po in policy_outputs if po[key] is not None], dim=0) - elif isinstance(value, list): # e.g., list[str] + elif isinstance(value, list): + # e.g., list[str] concated[key] = [] for po in policy_outputs: if po[key] is not None: @@ -153,21 +135,20 @@ def find_max_seq_len(policy_outputs: list[PolicyOutput], key): def logprobs_from_logits(logits: torch.Tensor, - labels: torch.Tensor = None, + labels: torch.Tensor, gather: bool = True) -> torch.Tensor: r""" - Adapted from: https://github.com/huggingface/trl/blob/main/trl/core.py#L95 + Adapted from: https://github.com/huggingface/trl/blob/main/trl/core.py#L131 Example: ```python - >>> logits, _, values = model(**input_kwargs) + >>> logits, _ = model(**input_kwargs) >>> input_ids = input_kwargs["input_ids"] >>> logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) ```""" - - logp = torch.nn.functional.log_softmax(logits, dim=-1) - if not gather or labels is None: + logp = torch.nn.functional.log_softmax(logits, dim=2) + if not gather: return logp - logpy = torch.gather(logp, -1, labels.unsqueeze(2)).squeeze(-1) + logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1) return logpy diff --git a/xtuner/rlhf/repeaters/__init__.py b/xtuner/rlhf/repeaters/__init__.py index 1ded00298..14ca68a07 100644 --- a/xtuner/rlhf/repeaters/__init__.py +++ b/xtuner/rlhf/repeaters/__init__.py @@ -1,3 +1,4 @@ -from .base import BaseRepeater +from .base import RepeaterBase +from .kl_gae import KLGAERepeater -__all__ = ['BaseRepeater'] +__all__ = ['RepeaterBase', 'KLGAERepeater'] diff --git a/xtuner/rlhf/repeaters/base.py b/xtuner/rlhf/repeaters/base.py index ac8833a60..5a6e63054 100644 --- a/xtuner/rlhf/repeaters/base.py +++ b/xtuner/rlhf/repeaters/base.py @@ -1,212 +1,15 @@ -import torch - -from ..model_server.base_model_server import BaseModelServer from ..policy_output import PolicyOutput -from ..timer import Timer -from .running_mean_std import RunningStates - - -class BaseRepeater: - - def __init__( - self, - ref_model, - actor_micro_bs: int = 8, - ref_micro_bs: int = 8, - critic_micro_bs: int = 32, - kl_coeff=0.02, - gamma=1.0, - gae_lambda=0.95, - norm_adv=False, - clip_reward_min: int = -5, - clip_reward_max: int = 5, - norm_rewards=True, - reward_scale: bool = False, - fine_grained_rm: bool = False, - **_ignored, - ): - self.ref_model = ref_model - self.actor_micro_bs = actor_micro_bs - self.ref_micro_bs = ref_micro_bs - self.critic_micro_bs = critic_micro_bs - self.kl_coeff = kl_coeff - self.gamma = gamma - self.gae_lambda = gae_lambda - # rewards - self.clip_reward_min = 
clip_reward_min - self.clip_reward_max = clip_reward_max - self.norm_rewards = norm_rewards - if self.norm_rewards: - self.running_states = RunningStates(epsilon=0) - - def process( - self, - trajectories: PolicyOutput, - policy_model: BaseModelServer, - critic_model: BaseModelServer, - ref_model: BaseModelServer = None, - # only used for async reward model.infer_get() in _get_kl_rewards - env=None, - ): - critic_output_ref = self._get_values_async(trajectories, critic_model) - action_mask = trajectories['action_mask'] - num_actions = action_mask.size(1) - if ref_model is not None: - self.ref_model: BaseModelServer = ref_model - (kl_rewards, entropy, kl_distance, policy_logprobs, - ref_logprobs) = self._get_kl_rewards( - trajectories, policy_model, env=env) - trajectories['kl'] = (kl_distance * action_mask).sum( - axis=-1) / action_mask.sum(axis=-1) - trajectories['entropy'] = entropy - trajectories['kl_rewards'] = kl_rewards - trajectories['policy_logprobs'] = policy_logprobs - trajectories['ref_logprobs'] = ref_logprobs - - values = self._get_values_collect(critic_output_ref, critic_model) - old_values = values[:, -num_actions:] - advantages, returns = self.get_advantages_and_returns( - old_values, kl_rewards, action_mask) - - trajectories['advantages'] = advantages - trajectories['returns'] = returns - trajectories['old_values'] = old_values - - return trajectories - - def _get_kl_rewards(self, - trajectories: PolicyOutput, - policy_model: BaseModelServer, - env=None): - with Timer('policy_model.infer_async'): - policy_output = policy_model.infer_async( - inputs=trajectories.output_ids, - micro_batch_size=self.actor_micro_bs, - attention_mask=trajectories.attention_mask, - output_logits=False, - output_logprobs=True) - with Timer('ref_model.infer_async'): - ref_output = self.ref_model.infer_async( - inputs=trajectories.output_ids, - micro_batch_size=self.ref_micro_bs, - attention_mask=trajectories.attention_mask, - output_logits=False, - output_logprobs=True) - with Timer('ref_model.infer_get'): - policy_output = policy_model.infer_get(policy_output) - with Timer('ref_model.infer_get'): - ref_output = self.ref_model.infer_get(ref_output) - - # Experimental - if env.async_reward: - rewards = env.get_reward_collect(trajectories['reward_output_ref']) - trajectories['reward_output_ref'] = None - trajectories['rewards'] = rewards - # Experimental - - clipped_rewards = torch.clamp( - rewards, min=self.clip_reward_min, max=self.clip_reward_max) - trajectories['clipped_rewards'] = clipped_rewards - - if self.norm_rewards: - self.running_states.update(clipped_rewards) - norm_reward_score = (clipped_rewards - - self.running_states.mean) / ( - self.running_states.var.sqrt() + 1e-8) - action_mask = trajectories.action_mask - num_actions = action_mask.size(1) - - policy_logprobs = policy_output.logprobs[:, -num_actions:] - ref_logprobs = ref_output.logprobs[:, -num_actions:] - - if self.kl_coeff <= 0.0: - self.kl_coeff = 0.0 - # compute_approx_kl - log_ratio = policy_logprobs - ref_logprobs - kl = log_ratio * action_mask - kl_reward = -self.kl_coeff * kl - - eos_indices = action_mask.size( - 1) - 1 - action_mask.long().fliplr().argmax( - dim=1, keepdim=True) - last_reward = torch.zeros_like(kl).scatter_( - dim=1, - index=eos_indices, - src=norm_reward_score.unsqueeze(1).to(kl.dtype)) - - reward = last_reward + kl_reward - - entropy = -(policy_logprobs * - action_mask).sum(axis=-1) / action_mask.sum(axis=-1) - return reward, entropy, kl, policy_logprobs, ref_logprobs - - def _get_values(self, 
trajectories: PolicyOutput, - critic_model: BaseModelServer): - with Timer('critic_model.infer'): - critic_output = critic_model.infer( - inputs=trajectories.output_ids, - attention_mask=trajectories.attention_mask, - output_logits=True, - micro_batch_size=self.critic_micro_bs, - ) - raw_values = critic_output.logits.squeeze(-1) - return raw_values - - def _get_values_async(self, trajectories: PolicyOutput, - critic_model: BaseModelServer): - with Timer('critic_model.infer_async'): - critic_output_ref = critic_model.infer_async( - inputs=trajectories.output_ids, - attention_mask=trajectories.attention_mask, - output_logits=True, - micro_batch_size=self.critic_micro_bs, - ) - return critic_output_ref - - def _get_values_collect(self, critic_output_ref, - critic_model: BaseModelServer): - with Timer('critic_model.infer_get'): - critic_output = critic_model.infer_get(critic_output_ref) - raw_values = critic_output.logits.squeeze(-1) - return raw_values - def get_advantages_and_returns( - self, - values: torch.Tensor, - rewards: torch.Tensor, - action_mask: torch.Tensor, - ): - # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134 # noqa: E501 - """Function that computes advantages and returns from rewards and - values. Calculated as in the original PPO paper: - https://arxiv.org/abs/1707.06347 Note that rewards may include a KL - divergence loss term. - Advantages looks like this: - Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... - - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... +class RepeaterBase: + """`RepeaterBase` is the base class of different repeaters. - Returns looks like this: - Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... - + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... - """ - lastgaelam = 0 - advantages_reversed = [] - response_length = rewards.size(1) + `repeater` is responsible to deal with the trajectory data. + """ - # Mask invalid responses - values = action_mask * values - rewards = action_mask * rewards + def __init__(self): + pass - for t in reversed(range(response_length)): - nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0 - # Since old_rewards and old_values are masked with action_mask, - # i.e. 
they have 0's at pad tokens, - # delta will be 0 if current t is at a pad token, - # so will lastgaelam - delta = rewards[:, t] + self.gamma * nextvalues - values[:, t] - lastgaelam = delta + self.gamma * self.gae_lambda * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], dim=1) - returns = advantages + values - return advantages.detach(), returns + def process(self, trajectories: PolicyOutput, *args, **kwargs): + """define process, such as get GAEs.""" + raise NotImplementedError diff --git a/xtuner/rlhf/repeaters/kl_gae.py b/xtuner/rlhf/repeaters/kl_gae.py new file mode 100644 index 000000000..a1304ac53 --- /dev/null +++ b/xtuner/rlhf/repeaters/kl_gae.py @@ -0,0 +1,208 @@ +import torch + +from ..model_server.base_model_server import BaseModelServer +from ..policy_output import PolicyOutput +from ..timer import Timer +from .base import RepeaterBase +from .utils import RunningStates + + +class KLGAERepeater(RepeaterBase): + + def __init__( + self, + ref_model: BaseModelServer, + policy_model: BaseModelServer, + critic_model: BaseModelServer, + policy_micro_bs: int = 8, + ref_micro_bs: int = 8, + critic_micro_bs: int = 32, + kl_coeff=0.01, + gamma=1.0, + gae_lambda=0.99, + clip_reward_min: int = -5, + clip_reward_max: int = 5, + norm_rewards=True, + norm_adv=False, + env=None, + **_ignored, + ): + # models + self.ref_model = ref_model + self.policy_model = policy_model + self.critic_model = critic_model + + self.policy_micro_bs = policy_micro_bs + self.ref_micro_bs = ref_micro_bs + self.critic_micro_bs = critic_micro_bs + self.kl_coeff = kl_coeff + self.gamma = gamma + self.gae_lambda = gae_lambda + # rewards + self.clip_reward_min = clip_reward_min + self.clip_reward_max = clip_reward_max + self.norm_rewards = norm_rewards + if self.norm_rewards: + self.running_states = RunningStates(epsilon=0) + self.norm_adv = norm_adv + + # only used for async reward model.infer_get() in _get_kl_rewards + self.env = env + + def process(self, trajectories: PolicyOutput): + critic_output_ref = self._get_values_async(trajectories) + action_mask = trajectories['action_mask'] + num_actions = action_mask.size(1) + (kl_rewards, entropy, kl_distance, policy_logprobs, + ref_logprobs) = self._get_kl_rewards(trajectories) + trajectories['kl'] = (kl_distance * action_mask).sum( + axis=-1) / action_mask.sum(axis=-1) + trajectories['entropy'] = entropy + trajectories['kl_rewards'] = kl_rewards + trajectories['policy_logprobs'] = policy_logprobs + trajectories['ref_logprobs'] = ref_logprobs + + values = self._get_values_collect(critic_output_ref) + old_values = values[:, -num_actions:] + advantages, returns = self.get_advantages_and_returns( + old_values, kl_rewards, action_mask) + if self.norm_adv: + advantages = (advantages - advantages.mean()) / ( + advantages.std() + 1e-8) + trajectories['advantages'] = advantages + trajectories['returns'] = returns + trajectories['old_values'] = old_values + + return trajectories + + def _get_kl_rewards(self, trajectories: PolicyOutput): + with Timer('policy_model.infer_async'): + policy_output = self.policy_model.infer_async( + inputs=trajectories.output_ids, + micro_batch_size=self.policy_micro_bs, + attention_mask=trajectories.attention_mask, + output_logits=False, + output_logprobs=True) + with Timer('ref_model.infer_async'): + ref_output = self.ref_model.infer_async( + inputs=trajectories.output_ids, + micro_batch_size=self.ref_micro_bs, + attention_mask=trajectories.attention_mask, + output_logits=False, + 
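# RunningStates (imported from .utils, formerly running_mean_std.py) is
# instantiated in KLGAERepeater above and used just below to normalise the
# clipped rewards; its body is not shown in this diff. The call sites only
# need streaming `mean`, `var` and an `update()`; a minimal batch-wise variant
# using the parallel mean/variance update -- the real class may differ:
import torch


class RunningStates:

    def __init__(self, epsilon: float = 0):
        self.mean = torch.tensor(0.0)
        self.var = torch.tensor(0.0)
        self.count = epsilon

    def update(self, x: torch.Tensor):
        batch_mean = x.mean()
        batch_var = x.var(unbiased=False)
        batch_count = x.numel()
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m2 = (self.var * self.count + batch_var * batch_count +
              delta.pow(2) * self.count * batch_count / total)
        self.mean, self.var, self.count = new_mean, m2 / total, total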
output_logprobs=True) + with Timer('ref_model.infer_get'): + policy_output = self.policy_model.infer_get(policy_output) + with Timer('ref_model.infer_get'): + ref_output = self.ref_model.infer_get(ref_output) + + # Experimental + if self.env.async_reward: + rewards = self.env.get_reward_collect( + trajectories['reward_output_ref']) + trajectories['reward_output_ref'] = None + trajectories['rewards'] = rewards + # Experimental + + clipped_rewards = torch.clamp( + rewards, min=self.clip_reward_min, max=self.clip_reward_max) + trajectories['clipped_rewards'] = clipped_rewards + + if self.norm_rewards: + self.running_states.update(clipped_rewards) + norm_reward_score = (clipped_rewards - + self.running_states.mean) / ( + self.running_states.var.sqrt() + 1e-8) + action_mask = trajectories.action_mask + num_actions = action_mask.size(1) + + policy_logprobs = policy_output.logprobs[:, -num_actions:] + ref_logprobs = ref_output.logprobs[:, -num_actions:] + + if self.kl_coeff <= 0.0: + self.kl_coeff = 0.0 + # compute_approx_kl + log_ratio = policy_logprobs - ref_logprobs + kl = log_ratio * action_mask + kl_reward = -self.kl_coeff * kl + + eos_indices = action_mask.size( + 1) - 1 - action_mask.long().fliplr().argmax( + dim=1, keepdim=True) + last_reward = torch.zeros_like(kl).scatter_( + dim=1, + index=eos_indices, + src=norm_reward_score.unsqueeze(1).to(kl.dtype)) + + reward = last_reward + kl_reward + + entropy = -(policy_logprobs * + action_mask).sum(axis=-1) / action_mask.sum(axis=-1) + return reward, entropy, kl, policy_logprobs, ref_logprobs + + def _get_values(self, trajectories: PolicyOutput): + with Timer('critic_model.infer'): + critic_output = self.critic_model.infer( + inputs=trajectories.output_ids, + attention_mask=trajectories.attention_mask, + output_logits=True, + micro_batch_size=self.critic_micro_bs, + ) + raw_values = critic_output.logits.squeeze(-1) + return raw_values + + def _get_values_async(self, trajectories: PolicyOutput): + with Timer('critic_model.infer_async'): + critic_output_ref = self.critic_model.infer_async( + inputs=trajectories.output_ids, + attention_mask=trajectories.attention_mask, + output_logits=True, + micro_batch_size=self.critic_micro_bs, + ) + return critic_output_ref + + def _get_values_collect(self, critic_output_ref): + with Timer('critic_model.infer_get'): + critic_output = self.critic_model.infer_get(critic_output_ref) + raw_values = critic_output.logits.squeeze(-1) + return raw_values + + def get_advantages_and_returns( + self, + values: torch.Tensor, + rewards: torch.Tensor, + action_mask: torch.Tensor, + ): + # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134 # noqa: E501 + """Function that computes advantages and returns from rewards and + values. Calculated as in the original PPO paper: + https://arxiv.org/abs/1707.06347 Note that rewards may include a KL + divergence loss term. + + Advantages looks like this: + Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... + + Returns looks like this: + Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + ... + + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ... + """ + lastgaelam = 0 + advantages_reversed = [] + response_length = rewards.size(1) + + # Mask invalid responses + values = action_mask * values + rewards = action_mask * rewards + + for t in reversed(range(response_length)): + nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0 + # Since old_rewards and old_values are masked with action_mask, + # i.e. 
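# Illustrative only: a self-contained toy re-run of the reward shaping and GAE
# recursion implemented above, using made-up tensors and the KLGAERepeater
# defaults (kl_coeff=0.01, gamma=1.0, gae_lambda=0.99). All values below are
# invented for illustration; only the formulas mirror the code in this file.
import torch

kl_coeff, gamma, gae_lambda = 0.01, 1.0, 0.99

action_mask = torch.tensor([[1., 1., 1., 1.],
                            [1., 1., 1., 0.]])        # second sample has one pad token
policy_logprobs = torch.tensor([[-1.0, -0.5, -0.7, -0.9],
                                [-0.8, -0.6, -0.4,  0.0]])
ref_logprobs = torch.tensor([[-1.1, -0.4, -0.7, -1.0],
                             [-0.9, -0.5, -0.4,  0.0]])
reward_score = torch.tensor([0.7, -0.3])              # clipped/normalized scalar reward per sample

# Per-token KL penalty, plus the scalar reward scattered onto the last
# non-padded action token, as in _get_kl_rewards above.
kl = (policy_logprobs - ref_logprobs) * action_mask
kl_reward = -kl_coeff * kl
eos_indices = action_mask.size(1) - 1 - action_mask.fliplr().argmax(dim=1, keepdim=True)
last_reward = torch.zeros_like(kl).scatter_(1, eos_indices, reward_score.unsqueeze(1))
rewards = (last_reward + kl_reward) * action_mask

# GAE recursion, as in get_advantages_and_returns above.
values = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                       [0.2, 0.1, 0.0, 0.0]]) * action_mask
lastgaelam, advantages_reversed = 0.0, []
for t in reversed(range(rewards.size(1))):
    nextvalues = values[:, t + 1] if t < rewards.size(1) - 1 else 0.0
    delta = rewards[:, t] + gamma * nextvalues - values[:, t]
    lastgaelam = delta + gamma * gae_lambda * lastgaelam
    advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)   # fed to the policy loss
returns = advantages + values                                # regression target for the critic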
they have 0's at pad tokens, + # delta will be 0 if current t is at a pad token, + # so will lastgaelam + delta = rewards[:, t] + self.gamma * nextvalues - values[:, t] + lastgaelam = delta + self.gamma * self.gae_lambda * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], dim=1) + returns = advantages + values + return advantages.detach(), returns diff --git a/xtuner/rlhf/repeaters/running_mean_std.py b/xtuner/rlhf/repeaters/utils.py similarity index 100% rename from xtuner/rlhf/repeaters/running_mean_std.py rename to xtuner/rlhf/repeaters/utils.py diff --git a/xtuner/rlhf/tokenizer/__init__.py b/xtuner/rlhf/tokenizer/__init__.py index e69de29bb..e59b36fff 100644 --- a/xtuner/rlhf/tokenizer/__init__.py +++ b/xtuner/rlhf/tokenizer/__init__.py @@ -0,0 +1,3 @@ +from .tokenizer_utils import encode_inputs, get_tokenizer + +__all__ = ['get_tokenizer', 'encode_inputs'] diff --git a/xtuner/rlhf/tokenizer/tokenizer_utils.py b/xtuner/rlhf/tokenizer/tokenizer_utils.py index a21e615b9..2a1539aa7 100644 --- a/xtuner/rlhf/tokenizer/tokenizer_utils.py +++ b/xtuner/rlhf/tokenizer/tokenizer_utils.py @@ -1,7 +1,7 @@ from typing import Optional, Union from loguru import logger -from transformers import (AutoTokenizer, LlamaTokenizer, PreTrainedTokenizer, +from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) PADDING_SIDE = 'left' @@ -36,19 +36,6 @@ def get_tokenizer( raise RuntimeError(err_msg) from e else: raise e - except OSError as e: - if 'Incorrect path_or_model_id' in str(e): # e.g., v13.model - tokenizer = LlamaTokenizer.from_pretrained( - tokenizer_name, - *args, - trust_remote_code=trust_remote_code, - tokenizer_revision=tokenizer_revision, - padding_side=padding_side, - **kwargs, - ) - logger.warning('Using LlamaTokenizer.') - else: - raise e except AttributeError as e: raise e @@ -61,7 +48,7 @@ def get_tokenizer( return tokenizer -def encode( +def encode_inputs( inputs: Union[list[str], list[list[dict]]], tokenizer, return_tensors='pt', diff --git a/xtuner/rlhf/trainer/ppo.py b/xtuner/rlhf/trainer/ppo.py index 28564de8d..655c91f94 100644 --- a/xtuner/rlhf/trainer/ppo.py +++ b/xtuner/rlhf/trainer/ppo.py @@ -9,7 +9,9 @@ class PPOTrainer: def __init__( self, - actor_micro_bs=2, + policy_model: BaseModelServer, + critic_model: BaseModelServer, + policy_micro_bs=2, critic_micro_bs=2, policy_learn_time=1, critic_learn_time=1, @@ -23,25 +25,29 @@ def __init__( **kwargs, ): - self.actor_micro_bs = actor_micro_bs - self.critic_micro_bs = critic_micro_bs # policy + self.policy_model = policy_model self.policy_learn_time = policy_learn_time self.policy_minibatch = policy_minibatch - - # critic - self.critic_learn_time = critic_learn_time - self.critic_minibatch = critic_minibatch + self.policy_micro_bs = policy_micro_bs self.ppo_loss_weight = ppo_loss_weight self.pretrain_loss_weight = pretrain_loss_weight self.pretrain_criterion = pretrain_criterion self.policy_criterion = policy_criterion + + # critic + self.critic_model = critic_model + self.critic_learn_time = critic_learn_time + self.critic_minibatch = critic_minibatch + self.critic_micro_bs = critic_micro_bs + self.critic_criterion = critic_criterion - def policy_learn(self, trajectories, policy_model: BaseModelServer): + def policy_learn(self, trajectories): if self.policy_minibatch is None: self.policy_minibatch = len(trajectories.output_ids) + assert len(trajectories.output_ids) % self.policy_minibatch == 0 policy_updates = len(trajectories.output_ids) // 
self.policy_minibatch ppo_loss = [] pretrain_loss = [] @@ -55,24 +61,13 @@ def policy_learn(self, trajectories, policy_model: BaseModelServer): begin = i * self.policy_minibatch end = begin + self.policy_minibatch - train_input_ids = [ - trajectories.output_ids[begin:end, :], - ] + train_input_ids = [trajectories.output_ids[begin:end, :]] train_attention_mask = [ - trajectories.attention_mask[begin:end, :], - ] - train_criterion = [ - self.policy_criterion, - ] - loss_weights = [ - self.ppo_loss_weight, - ] - micro_batch_size = [ - self.actor_micro_bs, + trajectories.attention_mask[begin:end, :] ] - assert len( - trajectories.output_ids[begin:end, :] - ) == self.policy_minibatch, '[Policy learn] make sure len(policy_batch_inputs) == self.policy_minibatch' # noqa: E501 + train_criterion = [self.policy_criterion] + loss_weights = [self.ppo_loss_weight] + micro_batch_size = [self.policy_micro_bs] train_lables = [ dict( @@ -85,8 +80,9 @@ def policy_learn(self, trajectories, policy_model: BaseModelServer): ] # pretrain data if trajectories.pretrain_data is not None: - logger.info(f'[Policy Train] pretrain data \ - {trajectories.pretrain_data["input_ids"].shape}') + logger.info( + '[Policy Train] pretrain data ' + f'{trajectories.pretrain_data["input_ids"].shape}') train_input_ids.append( trajectories.pretrain_data['input_ids']) train_lables.append(trajectories.pretrain_data['labels']) @@ -95,10 +91,10 @@ def policy_learn(self, trajectories, policy_model: BaseModelServer): trajectories.pretrain_data['attention_mask']) train_criterion.append(self.pretrain_criterion) loss_weights.append(self.pretrain_loss_weight) - micro_batch_size.append(self.actor_micro_bs) + micro_batch_size.append(self.policy_micro_bs) with Timer('policy_model.train'): - p_loss = policy_model.train( + p_loss = self.policy_model.train( input_ids=train_input_ids, labels=train_lables, attention_mask=train_attention_mask, @@ -119,55 +115,33 @@ def policy_learn(self, trajectories, policy_model: BaseModelServer): ) with Timer('policy_model.sync_model'): - policy_model.sync_model() + self.policy_model.sync_model() return ppo_loss, pretrain_loss - def critic_learn_async(self, trajectories, critic_model: BaseModelServer): + def critic_learn(self, trajectories): if self.critic_minibatch is None: self.critic_minibatch = len(trajectories.output_ids) + assert len(trajectories.output_ids) % self.critic_minibatch == 0 critic_updates = len(trajectories.output_ids) // self.critic_minibatch critic_loss = [] - assert critic_updates == 1 and self.policy_learn_time == 1, f'critic_updates={critic_updates} * self.policy_learn_time={self.policy_learn_time} > 1' # noqa: E501 - with Timer('critic_model.train_async'): - critic_batch_inputs, labels = self._critic_learn_prepare( - 0, 0, trajectories, critic_updates) - v_loss_ref = critic_model.train_async( - input_ids=critic_batch_inputs['input_ids'], - labels=labels, - attention_mask=critic_batch_inputs['attention_mask'], - criterion=self.critic_criterion, - micro_batch_size=self.critic_micro_bs, - ) - logger.info(f'[critic train] {self.critic_minibatch} batch') - critic_loss.append(v_loss_ref) - return critic_loss - def critic_learn_get(self, critic_loss_ref, critic_model: BaseModelServer): - with Timer('critic_model.train_get'): - return [ - critic_model.train_get(ref).item() for ref in critic_loss_ref - ] - - def critic_learn(self, trajectories, critic_model: BaseModelServer): - if self.critic_minibatch is None: - self.critic_minibatch = len(trajectories.output_ids) - critic_updates = 
len(trajectories.output_ids) // self.critic_minibatch - critic_loss = [] - - for learn_i in range(self.policy_learn_time): + for learn_i in range(self.critic_learn_time): for step_i in range(critic_updates): + logger.info( + '[Critic Train] start critic trains {}/{} | {}'.format( + step_i + 1, critic_updates, learn_i + 1)) with Timer('critic_model.train'): critic_batch_inputs, labels = self._critic_learn_prepare( step_i, learn_i, trajectories, critic_updates) - v_loss = critic_model.train( + v_loss = self.critic_model.train( input_ids=critic_batch_inputs['input_ids'], labels=labels, attention_mask=critic_batch_inputs['attention_mask'], criterion=self.critic_criterion, micro_batch_size=self.critic_micro_bs, ) - logger.info(f'[Critic train] {self.critic_minibatch} batch, \ - critic loss: {v_loss.item()}') + logger.info(f'[Critic train] {self.critic_minibatch} batch, ' + f'critic loss: {v_loss.item()}') critic_loss.append(v_loss.item()) return critic_loss @@ -177,16 +151,12 @@ def _critic_learn_prepare(self, step_i, learn_i, trajectories, step_i + 1, critic_updates, learn_i + 1)) begin = step_i * self.critic_minibatch end = begin + self.critic_minibatch - critic_batch_inputs = { - 'input_ids': trajectories.output_ids[begin:end, :], - 'old_values': trajectories.old_values[begin:end, :], - 'returns': trajectories.returns[begin:end, :], - 'action_mask': trajectories.action_mask[begin:end, :], - 'attention_mask': trajectories.attention_mask[begin:end, :] - } - assert len( - critic_batch_inputs['input_ids'] - ) == self.critic_minibatch, '[critic learn] make sure len(critic_batch_inputs) == self.critic_minibatch' # noqa: E501 + critic_batch_inputs = dict( + input_ids=trajectories.output_ids[begin:end, :], + old_values=trajectories.old_values[begin:end, :], + returns=trajectories.returns[begin:end, :], + action_mask=trajectories.action_mask[begin:end, :], + attention_mask=trajectories.attention_mask[begin:end, :]) labels = dict( old_values=critic_batch_inputs['old_values'], @@ -194,3 +164,32 @@ def _critic_learn_prepare(self, step_i, learn_i, trajectories, mask=critic_batch_inputs['action_mask'], ) return critic_batch_inputs, labels + + def critic_learn_async(self, trajectories): + if self.critic_minibatch is None: + self.critic_minibatch = len(trajectories.output_ids) + assert len(trajectories.output_ids) % self.critic_minibatch == 0 + critic_updates = len(trajectories.output_ids) // self.critic_minibatch + critic_loss = [] + assert critic_updates == 1 and self.policy_learn_time == 1, \ + '[WIP] `critic_learn_async` support learn async in loop' + with Timer('critic_model.train_async'): + critic_batch_inputs, labels = self._critic_learn_prepare( + 0, 0, trajectories, critic_updates) + v_loss_ref = self.critic_model.train_async( + input_ids=critic_batch_inputs['input_ids'], + labels=labels, + attention_mask=critic_batch_inputs['attention_mask'], + criterion=self.critic_criterion, + micro_batch_size=self.critic_micro_bs, + ) + logger.info(f'[critic train] {self.critic_minibatch} batch') + critic_loss.append(v_loss_ref) + return critic_loss + + def critic_learn_get(self, critic_loss_ref): + with Timer('critic_model.train_get'): + return [ + self.critic_model.train_get(ref).item() + for ref in critic_loss_ref + ] diff --git a/xtuner/rlhf/utils.py b/xtuner/rlhf/utils.py index be21d519c..bee9ae986 100644 --- a/xtuner/rlhf/utils.py +++ b/xtuner/rlhf/utils.py @@ -29,8 +29,7 @@ def expand_reward_token_id(reward_token_id: int, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, 
pad_token_id=0): - assert len(input_ids.shape) == 2, \ - f'expand_reward_token_id error, len(input_ids.shape()) = {len(input_ids.shape())}' # noqa: E501 + assert len(input_ids.shape) == 2 new_input_ids = torch.zeros((input_ids.shape[0], input_ids.shape[1] + 1), dtype=input_ids.dtype).to(input_ids.device) new_attention_mask = torch.zeros_like( From 9599838d6c2797aa401a39847de1fe93329ea92d Mon Sep 17 00:00:00 2001 From: lishuaibin Date: Wed, 19 Jun 2024 17:26:16 +0800 Subject: [PATCH 05/37] save tokenizer --- xtuner/rlhf/model_server/base_model_server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xtuner/rlhf/model_server/base_model_server.py b/xtuner/rlhf/model_server/base_model_server.py index 143553d2f..6fec04d46 100644 --- a/xtuner/rlhf/model_server/base_model_server.py +++ b/xtuner/rlhf/model_server/base_model_server.py @@ -154,6 +154,8 @@ def state_dict_get(self): def save_model(self, path): self.trainer.save_model(path) + if self.tokenizer is not None: + self.tokenizer.save_pretrained(path) # Misc. def set_seed(self, seed: int = None): From bb96eef7107207606660ad931e1b22a0b8e1d62d Mon Sep 17 00:00:00 2001 From: lishuaibin Date: Wed, 19 Jun 2024 19:37:38 +0800 Subject: [PATCH 06/37] rm/move models --- xtuner/rlhf/model_backend/models/__init__.py | 0 .../models/configuration_internlm2.py | 159 -- .../models/modeling_internlm2_p.py | 1536 ----------------- xtuner/rlhf/model_server/base_model_server.py | 3 - .../rlhf/model_server/critic_model_server.py | 2 +- .../rlhf/model_server/reward_model_server.py | 2 +- .../utils.py} | 1 + 7 files changed, 3 insertions(+), 1700 deletions(-) delete mode 100644 xtuner/rlhf/model_backend/models/__init__.py delete mode 100644 xtuner/rlhf/model_backend/models/configuration_internlm2.py delete mode 100644 xtuner/rlhf/model_backend/models/modeling_internlm2_p.py rename xtuner/rlhf/{model_backend/models/critical_and_reward.py => model_server/utils.py} (97%) diff --git a/xtuner/rlhf/model_backend/models/__init__.py b/xtuner/rlhf/model_backend/models/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/xtuner/rlhf/model_backend/models/configuration_internlm2.py b/xtuner/rlhf/model_backend/models/configuration_internlm2.py deleted file mode 100644 index c76e1407f..000000000 --- a/xtuner/rlhf/model_backend/models/configuration_internlm2.py +++ /dev/null @@ -1,159 +0,0 @@ -# flake8: noqa: E501 -# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. -# -# This code is based on transformers/src/transformers/models/llama/configuration_llama.py -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
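# Illustrative sketch (not taken verbatim from the patch) of how the pieces
# refactored above fit together in one PPO iteration after this series: the
# model servers are injected into KLGAERepeater / PPOTrainer once, the learn
# calls take only trajectories, and save_model() now writes the tokenizer
# alongside the weights. The model servers and text environment are assumed to
# be constructed elsewhere (coordinator + config) and arrive here as arguments.
from xtuner.rlhf.repeaters.kl_gae import KLGAERepeater
from xtuner.rlhf.trainer.ppo import PPOTrainer


def ppo_step(policy_model, critic_model, ref_model, txt_env, save_path=None):
    repeater = KLGAERepeater(
        ref_model=ref_model,
        policy_model=policy_model,
        critic_model=critic_model,
        kl_coeff=0.01,
        gamma=1.0,
        gae_lambda=0.99,
        env=txt_env)
    ppo = PPOTrainer(
        policy_model=policy_model,
        critic_model=critic_model,
        policy_micro_bs=2,
        critic_micro_bs=2)

    trajectories = txt_env.rollout()               # assumed env API: generation + (async) reward
    trajectories = repeater.process(trajectories)  # KL rewards, advantages, returns

    # Kick off the critic update, train the policy, then collect the critic loss.
    critic_loss_ref = ppo.critic_learn_async(trajectories)
    ppo_loss, pretrain_loss = ppo.policy_learn(trajectories)
    critic_loss = ppo.critic_learn_get(critic_loss_ref)

    if save_path is not None:
        policy_model.save_model(save_path)         # weights + tokenizer
    return ppo_loss, pretrain_loss, critic_loss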
-"""InternLM2 model configuration.""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - -INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -# Modified from transformers.model.llama.configuration_llama.LlamaConfig -class InternLM2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate - an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a - configuration with the defaults will yield a similar configuration to that of the InternLM2-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`InternLM2Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-12): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. 
- tie_word_embeddings(`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - Example: - - """ - model_type = 'internlm2' - _auto_class = 'AutoConfig' - - def __init__( # pylint: disable=W0102 - self, - vocab_size=103168, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act='silu', - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - reward_token_id=92527, - two_linear_reward_head=False, - tie_word_embeddings=False, - bias=True, - rope_theta=10000, - rope_scaling=None, - attn_implementation='eager', - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.bias = bias - - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.rope_scaling = rope_scaling - self._rope_scaling_validation() - - self.attn_implementation = attn_implementation - if self.attn_implementation is None: - self.attn_implementation = 'eager' - - self.reward_token_id = reward_token_id - self.two_linear_reward_head = two_linear_reward_head - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - def _rope_scaling_validation(self): - """Validate the `rope_scaling` configuration.""" - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, - dict) or len(self.rope_scaling) != 2: - raise ValueError( - '`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, ' - f'got {self.rope_scaling}') - rope_scaling_type = self.rope_scaling.get('type', None) - rope_scaling_factor = self.rope_scaling.get('factor', None) - if rope_scaling_type is None or rope_scaling_type not in [ - 'linear', 'dynamic' - ]: - raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance( - rope_scaling_factor, float) or rope_scaling_factor < 1.0: - raise ValueError( - f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}" - ) diff --git a/xtuner/rlhf/model_backend/models/modeling_internlm2_p.py b/xtuner/rlhf/model_backend/models/modeling_internlm2_p.py deleted file mode 100644 index 9d87f94ad..000000000 --- a/xtuner/rlhf/model_backend/models/modeling_internlm2_p.py +++ /dev/null @@ -1,1536 +0,0 @@ -# flake8: noqa: E501 -# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. -# -# This code is based on transformers/src/transformers/models/llama/modeling_llama.py -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch InternLM2 model.""" -import math -import queue -import threading -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from einops import rearrange -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.activations import ACT2FN -from transformers.modeling_outputs import (BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import (add_start_docstrings, - add_start_docstrings_to_model_forward, logging, - replace_return_docstrings) - -try: - from transformers.generation.streamers import BaseStreamer -except: # pylint: disable=bare-except - BaseStreamer = None - -from .configuration_internlm2 import InternLM2Config - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = 'InternLM2Config' - -flash_attn_func, flash_attn_varlen_func = None, None -pad_input, index_first_axis, unpad_input = None, None, None - - -def _import_flash_attn(): - global flash_attn_func, flash_attn_varlen_func - global pad_input, index_first_axis, unpad_input - try: - from flash_attn import flash_attn_func as _flash_attn_func - from flash_attn import \ - flash_attn_varlen_func as _flash_attn_varlen_func - from flash_attn.bert_padding import \ - index_first_axis as _index_first_axis - from flash_attn.bert_padding import pad_input as _pad_input - from flash_attn.bert_padding import unpad_input as _unpad_input - flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func - pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input - except ImportError: - raise ImportError('flash_attn is not installed.') - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad( - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0): - """Make causal mask used for bi-directional self-attention.""" - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), - torch.tensor(torch.finfo(dtype).min, device=device), - device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([ - torch.zeros( - tgt_len, past_key_values_length, dtype=dtype, device=device), - mask - ], - dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, - tgt_len + past_key_values_length) - - -# Copied 
from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, - dtype: torch.dtype, - tgt_len: Optional[int] = None): - """Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, - src_seq_len]`.""" - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, - src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), - torch.finfo(dtype).min) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2 -class InternLM2RMSNorm(nn.Module): - - def __init__(self, hidden_size, eps=1e-6): - """InternLM2RMSNorm is equivalent to T5LayerNorm.""" - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + - self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -# Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2 -class InternLM2RotaryEmbedding(nn.Module): - - def __init__(self, - dim, - max_position_embeddings=2048, - base=10000, - device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / ( - self.base - **(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer('inv_freq', inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype()) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - 'cos_cached', emb.cos().to(dtype), persistent=False) - self.register_buffer( - 'sin_cached', emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache( - seq_len=seq_len, device=x.device, dtype=torch.float32) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -# Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2 -class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding): - """InternLM2RotaryEmbedding extended with linear scaling. 
- - Credits to the Reddit user /u/kaiokendev - """ - - def __init__(self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - 'cos_cached', emb.cos().to(dtype), persistent=False) - self.register_buffer( - 'sin_cached', emb.sin().to(dtype), persistent=False) - - -# Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2 -class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding): - """InternLM2RotaryEmbedding extended with Dynamic NTK scaling. - - Credits to the Reddit users /u/bloc97 and /u/emozilla. - """ - - def __init__(self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ((self.scaling_factor * seq_len / - self.max_position_embeddings) - - (self.scaling_factor - 1))**( - self.dim / (self.dim - 2)) - inv_freq = 1.0 / ( - base - **(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer('inv_freq', inv_freq, persistent=False) - - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - 'cos_cached', emb.cos().to(dtype), persistent=False) - self.register_buffer( - 'sin_cached', emb.sin().to(dtype), persistent=False) - - -# Copied from transformers.model.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors.""" - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class InternLM2MLP(nn.Module): - - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.w1 = nn.Linear( - self.hidden_size, self.intermediate_size, bias=False) - self.w3 = nn.Linear( - self.hidden_size, self.intermediate_size, bias=False) - self.w2 = nn.Linear( - self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x)) - - return 
down_proj - - -# Copied from transformers.model.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """This is the equivalent of torch.repeat_interleave(x, dim=1, - repeats=n_rep). - - The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to - (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, - None, :, :].expand(batch, - num_key_value_heads, - n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, - head_dim) - - -# Modified from transformers.model.llama.modeling_llama.LlamaAttention -class InternLM2Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper.""" - - def __init__(self, config: InternLM2Config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.is_causal = True - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}' - f' and `num_heads`: {self.num_heads}).') - - self.wqkv = nn.Linear( - self.hidden_size, - (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, - bias=config.bias, - ) - - self.wo = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=config.bias) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = InternLM2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.config.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling['type'] - scaling_factor = self.config.rope_scaling['factor'] - if scaling_type == 'dynamic': - self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.config.rope_theta, - scaling_factor=scaling_factor, - ) - elif scaling_type == 'linear': - self.rotary_emb = InternLM2LinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.config.rope_theta, - scaling_factor=scaling_factor, - ) - else: - raise ValueError( - "Currently we only support rotary embedding's type being 'dynamic' or 'linear'." - ) - return self.rotary_emb - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, - self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - if 'padding_mask' in kwargs: - warnings.warn( - 'Passing `padding_mask` is deprecated and will be removed in v4.37. 
' - 'Please make sure use `attention_mask` instead.`') - - bsz, q_len, _ = hidden_states.size() - - qkv_states = self.wqkv(hidden_states) - - qkv_states = rearrange( - qkv_states, - 'b q (h gs d) -> b q h gs d', - gs=2 + self.num_key_value_groups, - d=self.head_dim, - ) - - query_states = qkv_states[..., :self.num_key_value_groups, :] - query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d') - key_states = qkv_states[..., -2, :] - value_states = qkv_states[..., -1, :] - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose( - 2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is' - f' {attn_weights.size()}') - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is' - f' {attn_output.size()}') - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.wo(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2 -class InternLM2FlashAttention2(InternLM2Attention): - """InternLM2 flash attention module. - - This module inherits from `InternLM2Attention` as the weights of the module - stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal - with padding tokens in case the input contains any of them. 
- """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - # InternLM2FlashAttention2 attention does not support output_attentions - if 'padding_mask' in kwargs: - warnings.warn( - 'Passing `padding_mask` is deprecated and will be removed in v4.37. ' - 'Please make sure use `attention_mask` instead.`') - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop('padding_mask') - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - qkv_states = self.wqkv(hidden_states) - - qkv_states = rearrange( - qkv_states, - 'b q (h gs d) -> b q h gs d', - gs=2 + self.num_key_value_groups, - d=self.head_dim, - ) - - query_states = qkv_states[..., :self.num_key_value_groups, :] - query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d') - key_states = qkv_states[..., -2, :] - value_states = qkv_states[..., -1, :] - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = self._flash_attention_forward(query_states, key_states, - value_states, - attention_mask, q_len) - attn_output = attn_output.reshape(bsz, q_len, - self.hidden_size).contiguous() - attn_output = self.wo(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def _flash_attention_forward(self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) - """ - # Contains at least one padding token in the sequence - causal = self.is_causal and query_length != 1 - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input( - query_states, key_states, value_states, attention_mask, - query_length) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, - query_length) - else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal) - - return attn_output - - def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, - query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( - attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, - head_dim), indices_k) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, - head_dim), indices_k) - - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, - head_dim), indices_k) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. 
- attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( - query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q.to(torch.int64), - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -INTERNLM2_ATTENTION_CLASSES = { - 'eager': InternLM2Attention, - 'flash_attention_2': InternLM2FlashAttention2, -} - - -# Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer -class InternLM2DecoderLayer(nn.Module): - - def __init__(self, config: InternLM2Config): - super().__init__() - self.hidden_size = config.hidden_size - - self.attention = INTERNLM2_ATTENTION_CLASSES[ - config.attn_implementation]( - config=config) - - self.feed_forward = InternLM2MLP(config) - self.attention_norm = InternLM2RMSNorm( - config.hidden_size, eps=config.rms_norm_eps) - self.ffn_norm = InternLM2RMSNorm( - config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, - torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - if 'padding_mask' in kwargs: - warnings.warn( - 'Passing `padding_mask` is deprecated and will be removed in v4.37. ' - 'Please make sure use `attention_mask` instead.`') - - residual = hidden_states - - hidden_states = self.attention_norm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.attention( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.ffn_norm(hidden_states) - hidden_states = self.feed_forward(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states, ) - - if output_attentions: - outputs += (self_attn_weights, ) - - if use_cache: - outputs += (present_key_value, ) - - return outputs - - -InternLM2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`InternLM2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2 -@add_start_docstrings( - 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.', - InternLM2_START_DOCSTRING, -) -class InternLM2PreTrainedModel(PreTrainedModel): - config_class = InternLM2Config - base_model_prefix = 'model' - supports_gradient_checkpointing = True - _no_split_modules = ['InternLM2DecoderLayer'] - _skip_keys_device_placement = 'past_key_values' - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -InternLM2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or - when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`. 
- - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -# Modified from transformers.model.llama.modeling_llama.LlamaModel -@add_start_docstrings( - 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.', - InternLM2_START_DOCSTRING, -) -class InternLM2Model(InternLM2PreTrainedModel): - """Transformer decoder consisting of *config.num_hidden_layers* layers. 
- Each layer is a [`InternLM2DecoderLayer`] - - Args: - config: InternLM2Config - """ - - _auto_class = 'AutoModel' - - def __init__(self, config: InternLM2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.config = config - - self.tok_embeddings = nn.Embedding(config.vocab_size, - config.hidden_size, - self.padding_idx) - - self.layers = nn.ModuleList([ - InternLM2DecoderLayer(config) - for _ in range(config.num_hidden_layers) - ]) - self.norm = InternLM2RMSNorm( - config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.tok_embeddings - - def set_input_embeddings(self, value): - self.tok_embeddings = value - - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, - inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, - tgt_len=input_shape[-1]).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else - expanded_attn_mask + combined_attention_mask) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else - self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.config.attn_implementation == 'flash_attention_2': - _import_flash_attn() - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - 'You cannot specify both input_ids and inputs_embeds at the same time' - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError( - 'You have to specify either input_ids or inputs_embeds') - - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + 
past_key_values_length, - dtype=torch.long, - device=device) - position_ids = position_ids.unsqueeze(0) - - if inputs_embeds is None: - inputs_embeds = self.tok_embeddings(input_ids) - - if self.config.attn_implementation == 'flash_attention_2': - # 2d mask is passed through the layers - attention_mask = attention_mask if ( - attention_mask is not None and 0 in attention_mask) else None - else: - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, - past_key_values_length) - - # embed positions - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states, ) - - past_key_value = past_key_values[ - idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += ( - layer_outputs[2 if output_attentions else 1], ) - - if output_attentions: - all_self_attns += (layer_outputs[1], ) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states, ) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v for v in - [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class NormHead(nn.Module): - - def __init__(self, hidden_size, vocab_size, bias=False): - super().__init__() - self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size))) - self.first_flag = True - - def forward(self, hidden_states): - norm_weight = nn.functional.normalize(self.weight) - return nn.functional.linear(hidden_states, norm_weight) - - -# Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM -class InternLM2ForCausalLM(InternLM2PreTrainedModel): - _auto_class = 'AutoModelForCausalLM' - - _tied_weights_keys = ['output.weight'] - - def __init__(self, config): - super().__init__(config) - self.model = InternLM2Model(config) - self.vocab_size = config.vocab_size - self.output = NormHead( - config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - self.norm_head = True 
- self.first_eval_flag = True - self.tmp_weight = None - - def get_input_embeddings(self): - return self.model.tok_embeddings - - def set_input_embeddings(self, value): - self.model.tok_embeddings = value - - def get_output_embeddings(self): - return self.output - - def set_output_embeddings(self, new_embeddings): - self.output = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) - @replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, InternLM2ForCausalLM - - >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else - self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - - logits = self.output(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits, ) + outputs[1:] - return (loss, ) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation(self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs): - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - position_ids = kwargs.get('position_ids', None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1]:] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {'inputs_embeds': inputs_embeds} - else: - model_inputs = {'input_ids': input_ids} - - model_inputs.update({ - 'position_ids': position_ids, - 'past_key_values': past_key_values, - 'use_cache': kwargs.get('use_cache'), - 'attention_mask': attention_mask, - }) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += (tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past), ) - return reordered_past - - def build_inputs(self, - tokenizer, - query: str, - history: List[Tuple[str, str]] = [], - meta_instruction=''): - if tokenizer.add_bos_token: - prompt = '' - else: - prompt = tokenizer.bos_token - if meta_instruction: - prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n""" - for record in history: - prompt += 
f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n""" - prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n""" - return tokenizer([prompt], return_tensors='pt') - - @torch.no_grad() - def chat( - self, - tokenizer, - query: str, - history: List[Tuple[str, str]] = [], - streamer: Optional[BaseStreamer] = None, - max_new_tokens: int = 1024, - do_sample: bool = True, - temperature: float = 0.8, - top_p: float = 0.8, - meta_instruction: - str = 'You are an AI assistant whose name is InternLM (书生·浦语).\n' - '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n' - '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.', - **kwargs, - ): - inputs = self.build_inputs(tokenizer, query, history, meta_instruction) - inputs = { - k: v.to(self.device) - for k, v in inputs.items() if torch.is_tensor(v) - } - # also add end-of-assistant token in eos token id to avoid unnecessary generation - eos_token_id = [ - tokenizer.eos_token_id, - tokenizer.convert_tokens_to_ids(['<|im_end|>'])[0] - ] - outputs = self.generate( - **inputs, - streamer=streamer, - max_new_tokens=max_new_tokens, - do_sample=do_sample, - temperature=temperature, - top_p=top_p, - eos_token_id=eos_token_id, - **kwargs, - ) - outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):] - response = tokenizer.decode(outputs, skip_special_tokens=True) - response = response.split('<|im_end|>')[0] - history = history + [(query, response)] - return response, history - - @torch.no_grad() - def stream_chat( - self, - tokenizer, - query: str, - history: List[Tuple[str, str]] = [], - max_new_tokens: int = 1024, - do_sample: bool = True, - temperature: float = 0.8, - top_p: float = 0.8, - **kwargs, - ): - """Return a generator in format: (response, history) Eg. - - ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')]) ('你好,有什么可以帮助您的吗?', [('你好', - '你好,有什么可以帮助您的吗?')]) - """ - if BaseStreamer is None: - raise ModuleNotFoundError( - 'The version of `transformers` is too low. 
Please make sure ' - 'that you have installed `transformers>=4.28.0`.') - - response_queue = queue.Queue(maxsize=20) - - class ChatStreamer(BaseStreamer): - - def __init__(self, tokenizer) -> None: - super().__init__() - self.tokenizer = tokenizer - self.queue = response_queue - self.query = query - self.history = history - self.response = '' - self.cache = [] - self.received_inputs = False - self.queue.put( - (self.response, history + [(self.query, self.response)])) - - def put(self, value): - if len(value.shape) > 1 and value.shape[0] > 1: - raise ValueError('ChatStreamer only supports batch size 1') - elif len(value.shape) > 1: - value = value[0] - - if not self.received_inputs: - # The first received value is input_ids, ignore here - self.received_inputs = True - return - - self.cache.extend(value.tolist()) - token = self.tokenizer.decode( - self.cache, skip_special_tokens=True) - if token.strip() != '<|im_end|>': - self.response = self.response + token - history = self.history + [(self.query, self.response)] - self.queue.put((self.response, history)) - self.cache = [] - else: - self.end() - - def end(self): - self.queue.put(None) - - def stream_producer(): - return self.chat( - tokenizer=tokenizer, - query=query, - streamer=ChatStreamer(tokenizer=tokenizer), - history=history, - max_new_tokens=max_new_tokens, - do_sample=do_sample, - temperature=temperature, - top_p=top_p, - **kwargs, - ) - - def consumer(): - producer = threading.Thread(target=stream_producer) - producer.start() - while True: - res = response_queue.get() - if res is None: - return - yield res - - return consumer() - - -# Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2 -@add_start_docstrings( - """ - The InternLM2 Model transformer with a sequence classification head on top (linear layer). - - [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification, - as other causal models (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). 
- """, - InternLM2_START_DOCSTRING, -) -class InternLM2ForSequenceClassification(InternLM2PreTrainedModel): - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = InternLM2Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.tok_embeddings - - def set_input_embeddings(self, value): - self.model.tok_embeddings = value - - @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError( - 'Cannot handle batch sizes > 1 if no padding token is defined.' 
- ) - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.eq( - input_ids, self.config.pad_token_id).int().argmax(-1) - - 1).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), - sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = 'regression' - elif self.num_labels > 1 and (labels.dtype == torch.long - or labels.dtype == torch.int): - self.config.problem_type = 'single_label_classification' - else: - self.config.problem_type = 'multi_label_classification' - - if self.config.problem_type == 'regression': - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == 'single_label_classification': - loss_fct = CrossEntropyLoss() - loss = loss_fct( - pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == 'multi_label_classification': - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits, ) + transformer_outputs[1:] - return ((loss, ) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/xtuner/rlhf/model_server/base_model_server.py b/xtuner/rlhf/model_server/base_model_server.py index 6fec04d46..73702aefa 100644 --- a/xtuner/rlhf/model_server/base_model_server.py +++ b/xtuner/rlhf/model_server/base_model_server.py @@ -7,7 +7,6 @@ from ..config.config_consts import ENGINE_HUGGINGFACE, ENGINE_INTERNEVO from ..model_backend.hf_model_runner import HfModelRunnerRayActorGroup -from ..model_backend.models.modeling_internlm2_p import InternLM2ForCausalLM from ..tokenizer import encode_inputs, get_tokenizer DEFAULT_GET_TIMEOUT = 600.0 # 10 min @@ -56,8 +55,6 @@ def init_trainer_config(self, model_config, tokenizer_config): def get_model_class(self, model_path): # will be changed in subclasses - if model_path == 'internlm/internlm2-chat-1_8b-sft': - return InternLM2ForCausalLM return AutoModelForCausalLM def initialize_async(self): diff --git a/xtuner/rlhf/model_server/critic_model_server.py b/xtuner/rlhf/model_server/critic_model_server.py index fe5afc5b2..ee35aa829 100644 --- a/xtuner/rlhf/model_server/critic_model_server.py +++ b/xtuner/rlhf/model_server/critic_model_server.py @@ -1,5 +1,5 @@ -from ..model_backend.models.critical_and_reward import get_critic_model from .base_model_server import BaseModelServer +from .utils import get_critic_model class CriticModelServer(BaseModelServer): diff --git a/xtuner/rlhf/model_server/reward_model_server.py b/xtuner/rlhf/model_server/reward_model_server.py index b658c1513..84e5e42af 100644 --- a/xtuner/rlhf/model_server/reward_model_server.py +++ b/xtuner/rlhf/model_server/reward_model_server.py @@ -1,10 +1,10 @@ import torch from transformers import AutoConfig -from ..model_backend.models.critical_and_reward import get_reward_model from ..tokenizer import encode_inputs from ..utils import expand_reward_token_id from .base_model_server import BaseModelServer +from .utils import get_reward_model class 
RewardModelServer(BaseModelServer): diff --git a/xtuner/rlhf/model_backend/models/critical_and_reward.py b/xtuner/rlhf/model_server/utils.py similarity index 97% rename from xtuner/rlhf/model_backend/models/critical_and_reward.py rename to xtuner/rlhf/model_server/utils.py index bb1e3697a..8180f2278 100644 --- a/xtuner/rlhf/model_backend/models/critical_and_reward.py +++ b/xtuner/rlhf/model_server/utils.py @@ -1,3 +1,4 @@ +# Adopted from https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/models/model.py#L134 # noqa: E501 from typing import Optional import torch From 5d6fb35d8b89464355ddeb988f99927c00612ff0 Mon Sep 17 00:00:00 2001 From: Zhu Zhihao Date: Thu, 20 Jun 2024 06:43:46 +0000 Subject: [PATCH 07/37] fix: zero3 for bigger models --- xtuner/rlhf/model_backend/hf_model_runner.py | 46 ++++++++++---------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/xtuner/rlhf/model_backend/hf_model_runner.py b/xtuner/rlhf/model_backend/hf_model_runner.py index bb56898e4..877892d58 100644 --- a/xtuner/rlhf/model_backend/hf_model_runner.py +++ b/xtuner/rlhf/model_backend/hf_model_runner.py @@ -50,6 +50,28 @@ def initialize(self): envs = self.model_config.get('envs', {}) for key, value in envs.items(): os.environ[key] = value + + # Parallel Settings + parallel: dict = self.model_config['parallel'] + assert parallel['tensor']['size'] == 1 # TODO: support TP + assert parallel['pipeline']['size'] == 1 # TODO: support PP + self.step = 0 + self.zero_stage = 1 + mixed_precision = self.model_config.get('mixed_precision', None) + if parallel['data'].get('mode') == ENGINE_PLUGIN_FSDP: + self.accelerator = Accelerator( + fsdp_plugin=FullyShardedDataParallelPlugin()) + self.zero_stage = 3 + elif parallel['data'].get('mode') == ENGINE_PLUGIN_DEEPSPEED: + from accelerate import DeepSpeedPlugin + + ds_config = self.model_config['deepspeed_config'] # requisite + self.accelerator = Accelerator( + deepspeed_plugin=DeepSpeedPlugin(ds_config)) + self.zero_stage = ds_config['zero_optimization']['stage'] + else: + self.accelerator = Accelerator(mixed_precision=mixed_precision) + self.zero_stage = 0 # 1. Model model_path = self.model_config.get('model_path') @@ -60,7 +82,7 @@ def initialize(self): AutoModelForCausalLM) self.model: PreTrainedModel = model_class.from_pretrained( pretrained_model_name_or_path=model_path, - device_map='auto', + device_map=None if self.zero_stage==3 else "auto", torch_dtype=torch_dtype, trust_remote_code=True, attn_implementation='flash_attention_2' @@ -81,29 +103,9 @@ def initialize(self): tokenizer_path, trust_remote_code=True, **tokenizer_config) # 3. 
Trainer - parallel: dict = self.model_config['parallel'] - assert parallel['tensor']['size'] == 1 # TODO: support TP - assert parallel['pipeline']['size'] == 1 # TODO: support PP - self.step = 0 - self.zero_stage = 1 - mixed_precision = self.model_config.get('mixed_precision', None) - if parallel['data'].get('mode') == ENGINE_PLUGIN_FSDP: - self.accelerator = Accelerator( - fsdp_plugin=FullyShardedDataParallelPlugin()) - self.zero_stage = 3 - elif parallel['data'].get('mode') == ENGINE_PLUGIN_DEEPSPEED: - from accelerate import DeepSpeedPlugin - - ds_config = self.model_config['deepspeed_config'] # requisite - self.accelerator = Accelerator( - deepspeed_plugin=DeepSpeedPlugin(ds_config)) - self.zero_stage = ds_config['zero_optimization']['stage'] - else: - self.accelerator = Accelerator(mixed_precision=mixed_precision) - self.zero_stage = 0 - train_kwargs = self.model_config.get('train_kwargs') if train_kwargs is None: # requires no training + self.model = self.accelerator.prepare(self.model) if self.zero_stage==3 else self.model self.device = self.accelerator.device logger.info( f'[{self.model_type}] __init__() done without train_kwargs.') From 795995d9b9e34d3cd109f2f2db3df3de424a9736 Mon Sep 17 00:00:00 2001 From: Zhu Zhihao Date: Thu, 20 Jun 2024 07:01:06 +0000 Subject: [PATCH 08/37] precommit check --- xtuner/rlhf/model_backend/hf_model_runner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xtuner/rlhf/model_backend/hf_model_runner.py b/xtuner/rlhf/model_backend/hf_model_runner.py index 877892d58..90f18f9e7 100644 --- a/xtuner/rlhf/model_backend/hf_model_runner.py +++ b/xtuner/rlhf/model_backend/hf_model_runner.py @@ -50,7 +50,7 @@ def initialize(self): envs = self.model_config.get('envs', {}) for key, value in envs.items(): os.environ[key] = value - + # Parallel Settings parallel: dict = self.model_config['parallel'] assert parallel['tensor']['size'] == 1 # TODO: support TP @@ -82,7 +82,7 @@ def initialize(self): AutoModelForCausalLM) self.model: PreTrainedModel = model_class.from_pretrained( pretrained_model_name_or_path=model_path, - device_map=None if self.zero_stage==3 else "auto", + device_map=None if self.zero_stage == 3 else 'auto', torch_dtype=torch_dtype, trust_remote_code=True, attn_implementation='flash_attention_2' @@ -105,7 +105,8 @@ def initialize(self): # 3. 
Trainer train_kwargs = self.model_config.get('train_kwargs') if train_kwargs is None: # requires no training - self.model = self.accelerator.prepare(self.model) if self.zero_stage==3 else self.model + self.model = self.accelerator.prepare( + self.model) if self.zero_stage == 3 else self.model self.device = self.accelerator.device logger.info( f'[{self.model_type}] __init__() done without train_kwargs.') From 1ff6ab36e20c361376bc6f035bc95acb5dc6cf6a Mon Sep 17 00:00:00 2001 From: lishuaibin Date: Thu, 20 Jun 2024 19:35:48 +0800 Subject: [PATCH 09/37] resume --- examples/rlhf/four_model_8gpu.py | 3 ++ examples/rlhf/four_model_vllm_8gpu.py | 3 ++ xtuner/rlhf/dataset/base.py | 5 ++- xtuner/rlhf/envs/txt_env.py | 8 ++++ xtuner/rlhf/main.py | 18 +++++--- xtuner/rlhf/model_backend/hf_model_runner.py | 43 +++++++++++++------ xtuner/rlhf/model_server/base_model_server.py | 4 +- 7 files changed, 60 insertions(+), 24 deletions(-) diff --git a/examples/rlhf/four_model_8gpu.py b/examples/rlhf/four_model_8gpu.py index 9a105668d..15df7eaff 100644 --- a/examples/rlhf/four_model_8gpu.py +++ b/examples/rlhf/four_model_8gpu.py @@ -1,6 +1,7 @@ ####################################################################### # Settings # ####################################################################### +RESUME_STEP=-1 MAX_PROMPT_LEN = 1024 MAX_ANSWER_LEN = 1024 MAX_PRETRAIN_LEN = 8192 @@ -43,6 +44,7 @@ reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, max_new_tokens=MAX_ANSWER_LEN, write_to_file=True, + resume_step=RESUME_STEP, generate_kwargs={ 'do_sample': True, 'temperature': 1.0, @@ -76,6 +78,7 @@ critic_warmup_step=20, save_interval=40, max_train_step=400, + resume_step=RESUME_STEP, ) model_configs = dict( diff --git a/examples/rlhf/four_model_vllm_8gpu.py b/examples/rlhf/four_model_vllm_8gpu.py index 81da45d03..cce6bafd5 100644 --- a/examples/rlhf/four_model_vllm_8gpu.py +++ b/examples/rlhf/four_model_vllm_8gpu.py @@ -1,6 +1,7 @@ ####################################################################### # Settings # ####################################################################### +RESUME_STEP=-1 MAX_PROMPT_LEN = 1024 MAX_ANSWER_LEN = 1024 MAX_PRETRAIN_LEN = 8192 @@ -43,6 +44,7 @@ reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, max_new_tokens=MAX_ANSWER_LEN, write_to_file=True, + resume_step=RESUME_STEP, generate_kwargs={ 'do_sample': True, 'temperature': 1.0, @@ -76,6 +78,7 @@ critic_warmup_step=20, save_interval=40, max_train_step=400, + resume_step=RESUME_STEP, ) model_configs = dict( diff --git a/xtuner/rlhf/dataset/base.py b/xtuner/rlhf/dataset/base.py index a9a4330d6..ef4ee9fd4 100644 --- a/xtuner/rlhf/dataset/base.py +++ b/xtuner/rlhf/dataset/base.py @@ -114,7 +114,8 @@ def __init__(self, task_groups, tokenizer=None, random_seed=1024): rm_prompt=rm_prompt)) logger.info( f'[DataLoader] Load {_task} with prob:{prob}, ' - f'sys_prompt type: {sys_prompt}, reward meta: {rm_prompt}') + f'sys_prompt type: {sys_prompt}, ' + f'reward prompt type: {rm_prompt}') else: logger.warning('[DataLoader] skip file, ' f'prob of {file_path} is {prob} ...') @@ -127,7 +128,7 @@ def __init__(self, task_groups, tokenizer=None, random_seed=1024): # loading & convert & save opensource datasets hf_dir = filepath.split('[HF]')[-1] - logger.info(f'Loading {hf_dir} with huggingface format ...') + logger.info(f'Loading {hf_dir} from huggingface ...') dataset = load_from_hf(hf_dir, tokenizer=tokenizer) task['dataset'] = IterDataset( data_list=dataset['conversation'], diff --git a/xtuner/rlhf/envs/txt_env.py b/xtuner/rlhf/envs/txt_env.py index 
b13cbfc7f..56aa3ebc1 100644 --- a/xtuner/rlhf/envs/txt_env.py +++ b/xtuner/rlhf/envs/txt_env.py @@ -24,6 +24,7 @@ def __init__( reward_micro_bs: int = 32, async_reward: bool = True, generate_kwargs: dict = None, + resume_step=-1, **_ignored, ): self.policy_model = policy_model @@ -38,8 +39,15 @@ def __init__( self.reward_micro_bs = reward_micro_bs self.async_reward = async_reward self.generate_kwargs: dict = generate_kwargs + self.resume_step = resume_step def rollout(self, display=True): + while self.resume_step > 0: + logger.info(f"[Resume] {self.resume_step} consuming data...") + trained_prompt_datas = next(self.prompt_mes_iter) + if self.pretrain_mes_iter is not None: + trained_pretrain_datas = next(self.pretrain_mes_iter) + self.resume_step -= 1 prompt_datas = deepcopy(next(self.prompt_mes_iter)) prompt_input_messages = [] for data in prompt_datas: diff --git a/xtuner/rlhf/main.py b/xtuner/rlhf/main.py index d7319c33c..e89556220 100644 --- a/xtuner/rlhf/main.py +++ b/xtuner/rlhf/main.py @@ -112,24 +112,27 @@ def validate_config(config: Config): critic_warmup_step = train_config['critic_warmup_step'] save_interval = train_config['save_interval'] max_train_step = train_config.get('max_train_step', float('inf')) + resume_step = train_config.get('resume_step', -1) - step = 1 + step = max(0, resume_step) while step <= max_train_step: s_t = time.time() with Timer(f'step {step}: end_to_end'): + # generate trajectories trajectories = txt_env.rollout(display=True) + # deal with trajectories trajectories = ppo_repeater.process(trajectories) - # # for critic & policy learn + # critic & policy learn critic_loss = ppo.critic_learn(trajectories) # critic_loss_ref = ppo.critic_learn_async(trajectories) ppo_loss, pt_loss = None, None if critic_warmup_step <= 0: ppo_loss, pt_loss = ppo.policy_learn(trajectories) - logger_train.info(f'[Policy Train] Step: {step}, \ - ppo loss: {ppo_loss}, pretrain loss: {pt_loss}') + logger_train.info(f'[Policy Train] Step: {step}, ' + f'ppo loss: {ppo_loss}, pretrain loss: {pt_loss}') # critic_loss = ppo.critic_learn_get(critic_loss_ref) logger_train.info( @@ -162,8 +165,9 @@ def validate_config(config: Config): ) with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: f.write(json.dumps(summaries) + '\n') - logger_train.info(f'[end to end] duration: {time.time() - s_t} s') - if (step % save_interval == 0) or (step == max_train_step): - policy_model.save_model(f'{work_dir}/ckpt/{step}/') + step += 1 + if (step % save_interval == 0) or (step == max_train_step): + policy_model.save(f'{work_dir}/ckpt/policy_model/{step}') + critic_model.save(f'{work_dir}/ckpt/critic_model/{step}') diff --git a/xtuner/rlhf/model_backend/hf_model_runner.py b/xtuner/rlhf/model_backend/hf_model_runner.py index 90f18f9e7..cfe62b8a5 100644 --- a/xtuner/rlhf/model_backend/hf_model_runner.py +++ b/xtuner/rlhf/model_backend/hf_model_runner.py @@ -556,20 +556,37 @@ def get_state_dict(self): def set_seed(self, seed=None): set_seed(seed) - def save_model(self, path): + def save(self, path): + # for resume + self.accelerator.wait_for_everyone() + self.accelerator.save_state(os.path.join(path, 'saved_state')) + + # save model, tokenizer, step if not self.accelerator.is_main_process: self.accelerator.get_state_dict(self.model) return - unwrapped_model = self.accelerator.unwrap_model(self.model) - if not os.path.exists(path): - os.makedirs(path) - unwrapped_model.save_pretrained( - path, - is_main_process=True, - save_function=self.accelerator.save, - 
state_dict=self.accelerator.get_state_dict(self.model), - ) - logger.info(f'save model to {path}') + else: + path = os.path.normpath(path) + logger.info(f'[Train step {self.step}] ' + f'Saving {self.model_type} to {path} ...') + # save model + unwrapped_model = self.accelerator.unwrap_model(self.model) + unwrapped_model.save_pretrained( + path, + is_main_process=True, + save_function=self.accelerator.save, + state_dict=self.accelerator.get_state_dict(self.model), + ) + # save tokenizer + if self.tokenizer is not None: + self.tokenizer.save_pretrained(path) + # step + torch.save(self.step, os.path.join(path, f'{self.step}.step')) + logger.info(f'{self.model_type} saved.') + + def info_rank0(self, content): + if self.accelerator.is_main_process: + logger.info(content) def info_rank0(self, content): if self.accelerator.is_main_process: @@ -842,8 +859,8 @@ def release_resources(self): remove_placement_group(self.placement_group) self.released = True - def save_model(self, path): - ray.get([actor.save_model.remote(path) for actor in self.ray_actors]) + def save(self, path): + ray.get([actor.save.remote(path) for actor in self.ray_actors]) def init_process_group(self, generator): refs = [ diff --git a/xtuner/rlhf/model_server/base_model_server.py b/xtuner/rlhf/model_server/base_model_server.py index 73702aefa..63526233e 100644 --- a/xtuner/rlhf/model_server/base_model_server.py +++ b/xtuner/rlhf/model_server/base_model_server.py @@ -149,8 +149,8 @@ def state_dict_get(self): return ray.get( self.trainer.get_state_dict(), timeout=DEFAULT_GET_TIMEOUT) - def save_model(self, path): - self.trainer.save_model(path) + def save(self, path): + self.trainer.save(path) if self.tokenizer is not None: self.tokenizer.save_pretrained(path) From e1c8c61096ba33b065c1156ec971fa7956cff9ba Mon Sep 17 00:00:00 2001 From: lishuaibin Date: Fri, 21 Jun 2024 14:37:19 +0800 Subject: [PATCH 10/37] fix resume/pretrain_data --- examples/rlhf/four_model_8gpu.py | 2 +- examples/rlhf/four_model_vllm_8gpu.py | 4 +-- xtuner/rlhf/dataset/base.py | 7 ++--- xtuner/rlhf/dataset/message_iter.py | 16 +++++++++- xtuner/rlhf/envs/txt_env.py | 8 ++--- xtuner/rlhf/main.py | 9 ++++-- xtuner/rlhf/model_backend/hf_model_runner.py | 31 +++++++++++++------- 7 files changed, 51 insertions(+), 26 deletions(-) diff --git a/examples/rlhf/four_model_8gpu.py b/examples/rlhf/four_model_8gpu.py index 15df7eaff..a058b1ab0 100644 --- a/examples/rlhf/four_model_8gpu.py +++ b/examples/rlhf/four_model_8gpu.py @@ -1,7 +1,7 @@ ####################################################################### # Settings # ####################################################################### -RESUME_STEP=-1 +RESUME_STEP = -1 MAX_PROMPT_LEN = 1024 MAX_ANSWER_LEN = 1024 MAX_PRETRAIN_LEN = 8192 diff --git a/examples/rlhf/four_model_vllm_8gpu.py b/examples/rlhf/four_model_vllm_8gpu.py index cce6bafd5..d4ba5d4c3 100644 --- a/examples/rlhf/four_model_vllm_8gpu.py +++ b/examples/rlhf/four_model_vllm_8gpu.py @@ -1,13 +1,13 @@ ####################################################################### # Settings # ####################################################################### -RESUME_STEP=-1 +RESUME_STEP = -1 MAX_PROMPT_LEN = 1024 MAX_ANSWER_LEN = 1024 MAX_PRETRAIN_LEN = 8192 PROMPT_BATCH_SIZE = 256 -PRETRAIN_BATCH_SIZE = 32 +PRETRAIN_BATCH_SIZE = 32 # 0 GENERATE_MICRO_BATCH_SIZE = 16 INFER_MICRO_BATCH_SIZE = 8 diff --git a/xtuner/rlhf/dataset/base.py b/xtuner/rlhf/dataset/base.py index ef4ee9fd4..f870ae5f2 100644 --- a/xtuner/rlhf/dataset/base.py +++ 
b/xtuner/rlhf/dataset/base.py @@ -112,10 +112,9 @@ def __init__(self, task_groups, tokenizer=None, random_seed=1024): filepath=file_path, sys_prompt=sys_prompt, rm_prompt=rm_prompt)) - logger.info( - f'[DataLoader] Load {_task} with prob:{prob}, ' - f'sys_prompt type: {sys_prompt}, ' - f'reward prompt type: {rm_prompt}') + logger.info(f'[DataLoader] Load {_task} with prob:{prob}, ' + f'sys_prompt type: {sys_prompt}, ' + f'reward prompt type: {rm_prompt}') else: logger.warning('[DataLoader] skip file, ' f'prob of {file_path} is {prob} ...') diff --git a/xtuner/rlhf/dataset/message_iter.py b/xtuner/rlhf/dataset/message_iter.py index e7f918f5b..94c78b83a 100644 --- a/xtuner/rlhf/dataset/message_iter.py +++ b/xtuner/rlhf/dataset/message_iter.py @@ -35,7 +35,7 @@ def __init__(self, message_type: str = 'prompt', tokenizer=None, max_len: int = 4096, - samples_each_epoch: int = 64, + samples_each_epoch: int = 0, random_seed: int = 110, sample_strategy: str = 'in_batch', **kwargs): @@ -44,6 +44,12 @@ def __init__(self, 'in_batch', 'in_data' ], ("`sample_strategy` should in ['in_batch', 'in_data']," f' but got {sample_strategy}') + if (message_datasets is None) or (samples_each_epoch == 0): + logger.warning(f'message_datasets: {message_datasets}' + f' samples_each_epoch: {samples_each_epoch}.') + self.message_datasets = None + self.samples_each_epoch = 0 + return None assert message_datasets is not None self.message_type = message_type self.sample_strategy = sample_strategy @@ -165,6 +171,14 @@ def yield_in_batch(self): mes_sequence.append(sequence) if len(mes_sequence) == self.samples_each_epoch: break + # TODO, len(mes_sequence) < self.samples_each_epoch, + # tmp: random sample from chosen data + if len(mes_sequence) < self.samples_each_epoch: + missed = self.samples_each_epoch - len(mes_sequence) + logger.warning( + f'[MES_ITER] {self.message_type} {missed} dirty data ...') + for i in range(missed): + mes_sequence.append(mes_sequence[i]) assert len(mes_sequence) == self.samples_each_epoch logger.info(f'[Epoch {self.epoch_index}] sample ' diff --git a/xtuner/rlhf/envs/txt_env.py b/xtuner/rlhf/envs/txt_env.py index 56aa3ebc1..c42d6e68b 100644 --- a/xtuner/rlhf/envs/txt_env.py +++ b/xtuner/rlhf/envs/txt_env.py @@ -32,7 +32,7 @@ def __init__( self.prompt_mes_iter = iter(prompt_mes_iter) self.pretrain_mes_iter = iter( - pretrain_mes_iter) if pretrain_mes_iter else None + pretrain_mes_iter) if pretrain_mes_iter.message_datasets else None self.max_new_tokens = max_new_tokens self.policy_micro_bs = policy_micro_bs @@ -43,10 +43,10 @@ def __init__( def rollout(self, display=True): while self.resume_step > 0: - logger.info(f"[Resume] {self.resume_step} consuming data...") - trained_prompt_datas = next(self.prompt_mes_iter) + logger.info(f'[Resume] {self.resume_step} consuming data...') + next(self.prompt_mes_iter) if self.pretrain_mes_iter is not None: - trained_pretrain_datas = next(self.pretrain_mes_iter) + next(self.pretrain_mes_iter) self.resume_step -= 1 prompt_datas = deepcopy(next(self.prompt_mes_iter)) prompt_input_messages = [] diff --git a/xtuner/rlhf/main.py b/xtuner/rlhf/main.py index e89556220..0d9a69c67 100644 --- a/xtuner/rlhf/main.py +++ b/xtuner/rlhf/main.py @@ -83,7 +83,7 @@ def validate_config(config: Config): prompt_dataset_config = config['prompt_dataset_config'] prompt_mes_iter = MessageIter( tokenizer=ref_model.tokenizer, **prompt_dataset_config) - pretrain_dataset_config = config['pretrain_dataset_config'] + pretrain_dataset_config = config.get('pretrain_dataset_config', {}) 
pretrain_mes_iter = MessageIter( tokenizer=ref_model.tokenizer, **pretrain_dataset_config) @@ -93,7 +93,7 @@ def validate_config(config: Config): policy_model=policy_model, reward_model=reward_model, prompt_mes_iter=prompt_mes_iter, - pretrain_mes_iter=pretrain_mes_iter, + pretrain_mes_iter=pretrain_mes_iter, # None **rollout_config, ) # init repeater @@ -113,6 +113,8 @@ def validate_config(config: Config): save_interval = train_config['save_interval'] max_train_step = train_config.get('max_train_step', float('inf')) resume_step = train_config.get('resume_step', -1) + critic_warmup_step = min(critic_warmup_step, + critic_warmup_step - resume_step) step = max(0, resume_step) while step <= max_train_step: @@ -131,7 +133,8 @@ def validate_config(config: Config): ppo_loss, pt_loss = None, None if critic_warmup_step <= 0: ppo_loss, pt_loss = ppo.policy_learn(trajectories) - logger_train.info(f'[Policy Train] Step: {step}, ' + logger_train.info( + f'[Policy Train] Step: {step}, ' f'ppo loss: {ppo_loss}, pretrain loss: {pt_loss}') # critic_loss = ppo.critic_learn_get(critic_loss_ref) diff --git a/xtuner/rlhf/model_backend/hf_model_runner.py b/xtuner/rlhf/model_backend/hf_model_runner.py index cfe62b8a5..4786d1038 100644 --- a/xtuner/rlhf/model_backend/hf_model_runner.py +++ b/xtuner/rlhf/model_backend/hf_model_runner.py @@ -1,3 +1,4 @@ +import glob import os import socket from typing import Optional, Union @@ -55,7 +56,7 @@ def initialize(self): parallel: dict = self.model_config['parallel'] assert parallel['tensor']['size'] == 1 # TODO: support TP assert parallel['pipeline']['size'] == 1 # TODO: support PP - self.step = 0 + self.update_step = 0 self.zero_stage = 1 mixed_precision = self.model_config.get('mixed_precision', None) if parallel['data'].get('mode') == ENGINE_PLUGIN_FSDP: @@ -135,6 +136,10 @@ def initialize(self): self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( # noqa: E501 self.model, self.optimizer, self.lr_scheduler) + # resume optimizer, lr_scheduler + if bool(len(glob.glob(os.path.join(model_path, '*.step')))): + self._resume_load_pretrained(model_path=model_path) + # Others self.device = self.accelerator.device set_seed(self.model_config.get('seed')) @@ -149,6 +154,14 @@ def initialize(self): f'[{self.model_type}] __init__() done with optimizer {self.optimizer.optimizer}.' 
# noqa: E501 ) + def _resume_load_pretrained(self, model_path): + _, step_pt = os.path.split( + glob.glob(os.path.join(model_path, '*.step'))[0]) + self.update_step = int(step_pt.split('.step')[0]) + logger.info(f'Resume train step {self.update_step} from {model_path}') + assert os.path.exists(os.path.join(model_path, 'saved_state')) + self.accelerator.load_state(os.path.join(model_path, 'saved_state')) + def compute_loss( self, input_ids: torch.Tensor, @@ -206,8 +219,8 @@ def compute_loss( def parameter_update(self, step_interval=1): self.info_rank0(f'[{self.model_type}] self.parameter_update()') - self.step += 1 - if self.step % step_interval == 0: + self.update_step += 1 + if self.update_step % step_interval == 0: self.accelerator.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm) self.optimizer.step() @@ -567,8 +580,8 @@ def save(self, path): return else: path = os.path.normpath(path) - logger.info(f'[Train step {self.step}] ' - f'Saving {self.model_type} to {path} ...') + logger.info(f'[Train step {self.update_step}] ' + f'Saving {self.model_type} to {path} ...') # save model unwrapped_model = self.accelerator.unwrap_model(self.model) unwrapped_model.save_pretrained( @@ -580,18 +593,14 @@ def save(self, path): # save tokenizer if self.tokenizer is not None: self.tokenizer.save_pretrained(path) - # step - torch.save(self.step, os.path.join(path, f'{self.step}.step')) + torch.save(self.update_step, + os.path.join(path, f'{self.update_step}.step')) logger.info(f'{self.model_type} saved.') def info_rank0(self, content): if self.accelerator.is_main_process: logger.info(content) - def info_rank0(self, content): - if self.accelerator.is_main_process: - logger.info(content) - # Adapted from https://github.com/OpenLLMAI/OpenRLHF/blob/v0.2.5/openrlhf/trainer/ray/ppo_actor.py # noqa: E501 class HfModelRunnerRayActor(HfModelRunner, RayActorMixin): From 89bfe9d24d799c4a2151d4ecccf397ddf404a456 Mon Sep 17 00:00:00 2001 From: lishuaibin Date: Fri, 21 Jun 2024 16:26:29 +0800 Subject: [PATCH 11/37] fix tokenizer bug --- xtuner/rlhf/coordinator.py | 6 ++++-- xtuner/rlhf/main.py | 8 +++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/xtuner/rlhf/coordinator.py b/xtuner/rlhf/coordinator.py index 60c214c44..baa1345c2 100644 --- a/xtuner/rlhf/coordinator.py +++ b/xtuner/rlhf/coordinator.py @@ -15,9 +15,10 @@ class Coordinator: - def __init__(self, cluster_address: str, model_configs: dict): + def __init__(self, cluster_address: str, configs: dict): self.cluster_address = cluster_address - self.model_configs = model_configs + self.model_configs = configs['model_configs'] + self.tokenizer_config = configs.get('tokenizer_config', {}) self.model_dict = dict() self.context_type: str = None # "client" or "server" self.context: ray._private.workers.BaseContext = None @@ -58,6 +59,7 @@ def create_models(self) -> dict[str, BaseModelServer]: self.model_dict = {} for model_name, model_config in self.model_configs.items(): model_type = model_config['model_type'] + model_config['tokenizer_config'] = self.tokenizer_config if model_type == MODEL_TYPE_POLICY: self.model_dict[model_name] = PolicyModelServer( model_name, model_config) diff --git a/xtuner/rlhf/main.py b/xtuner/rlhf/main.py index 0d9a69c67..d51082be9 100644 --- a/xtuner/rlhf/main.py +++ b/xtuner/rlhf/main.py @@ -1,6 +1,7 @@ import argparse import json import os +import shutil import time from loguru import logger @@ -54,14 +55,15 @@ def validate_config(config: Config): work_dir = os.path.abspath(work_dir) logger.info(f'using 
work_dir: {work_dir}') os.makedirs(work_dir, exist_ok=True) + # save original config + shutil.copy2(args.config, f'{work_dir}/{os.path.basename(args.config)}') logger.add( f'{work_dir}/train_rlhf.log', filter=lambda record: record['extra'].get('name') == 'train') logger_train = logger.bind(name='train') - configs_path = args.config - config = Config.from_file(configs_path) + config = Config.from_file(args.config) logger.info('#################### CONFIG BGN ####################') for k, v in config.items(): logger.info(f'{k}: {v}') @@ -72,7 +74,7 @@ def validate_config(config: Config): if cluster_address != 'auto': cluster_address = f'ray://{cluster_address}:10001' logger.info(f'cluster_address={cluster_address}') - coordinator = Coordinator(cluster_address, config['model_configs']) + coordinator = Coordinator(cluster_address, config) model_dict = coordinator.create_models() ref_model = model_dict['reference'] policy_model = model_dict['policy'] From 82a3366b4908810c295b4744f36b86dacb33cc7d Mon Sep 17 00:00:00 2001 From: qzweng Date: Tue, 25 Jun 2024 08:28:46 +0000 Subject: [PATCH 12/37] async learn --- xtuner/rlhf/main.py | 11 ++++++++--- xtuner/rlhf/repeaters/kl_gae.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/xtuner/rlhf/main.py b/xtuner/rlhf/main.py index d51082be9..9d692f6bf 100644 --- a/xtuner/rlhf/main.py +++ b/xtuner/rlhf/main.py @@ -117,6 +117,7 @@ def validate_config(config: Config): resume_step = train_config.get('resume_step', -1) critic_warmup_step = min(critic_warmup_step, critic_warmup_step - resume_step) + async_learn = train_config.get('async_learn', False) step = max(0, resume_step) while step <= max_train_step: @@ -129,8 +130,10 @@ def validate_config(config: Config): trajectories = ppo_repeater.process(trajectories) # critic & policy learn - critic_loss = ppo.critic_learn(trajectories) - # critic_loss_ref = ppo.critic_learn_async(trajectories) + if async_learn: + critic_loss_ref = ppo.critic_learn_async(trajectories) + else: + critic_loss = ppo.critic_learn(trajectories) ppo_loss, pt_loss = None, None if critic_warmup_step <= 0: @@ -139,7 +142,9 @@ def validate_config(config: Config): f'[Policy Train] Step: {step}, ' f'ppo loss: {ppo_loss}, pretrain loss: {pt_loss}') - # critic_loss = ppo.critic_learn_get(critic_loss_ref) + if async_learn: + critic_loss = ppo.critic_learn_get(critic_loss_ref) + logger_train.info( f'[Critic Train] step: {step}, critic loss: {critic_loss}') logger_train.info(f'rewards: {trajectories.rewards.mean()}') diff --git a/xtuner/rlhf/repeaters/kl_gae.py b/xtuner/rlhf/repeaters/kl_gae.py index a1304ac53..5ab69611a 100644 --- a/xtuner/rlhf/repeaters/kl_gae.py +++ b/xtuner/rlhf/repeaters/kl_gae.py @@ -90,7 +90,7 @@ def _get_kl_rewards(self, trajectories: PolicyOutput): attention_mask=trajectories.attention_mask, output_logits=False, output_logprobs=True) - with Timer('ref_model.infer_get'): + with Timer('policy_model.infer_get'): policy_output = self.policy_model.infer_get(policy_output) with Timer('ref_model.infer_get'): ref_output = self.ref_model.infer_get(ref_output) From 6531b5db8a12cd9d3167231462febe5aa57b0893 Mon Sep 17 00:00:00 2001 From: qzweng Date: Tue, 25 Jun 2024 08:29:17 +0000 Subject: [PATCH 13/37] add and rename ppo configs --- ....py => internlm2_chat_1_8b_ppo_ds_8gpu.py} | 0 ...> internlm2_chat_1_8b_ppo_ds_vllm_8gpu.py} | 0 examples/rlhf/llama2_7b_ppo_ds_vllm_16gpu.py | 244 ++++++++++++++++++ examples/rlhf/quick_start.md | 2 +- 4 files changed, 245 insertions(+), 1 deletion(-) rename 
examples/rlhf/{four_model_8gpu.py => internlm2_chat_1_8b_ppo_ds_8gpu.py} (100%) rename examples/rlhf/{four_model_vllm_8gpu.py => internlm2_chat_1_8b_ppo_ds_vllm_8gpu.py} (100%) create mode 100644 examples/rlhf/llama2_7b_ppo_ds_vllm_16gpu.py diff --git a/examples/rlhf/four_model_8gpu.py b/examples/rlhf/internlm2_chat_1_8b_ppo_ds_8gpu.py similarity index 100% rename from examples/rlhf/four_model_8gpu.py rename to examples/rlhf/internlm2_chat_1_8b_ppo_ds_8gpu.py diff --git a/examples/rlhf/four_model_vllm_8gpu.py b/examples/rlhf/internlm2_chat_1_8b_ppo_ds_vllm_8gpu.py similarity index 100% rename from examples/rlhf/four_model_vllm_8gpu.py rename to examples/rlhf/internlm2_chat_1_8b_ppo_ds_vllm_8gpu.py diff --git a/examples/rlhf/llama2_7b_ppo_ds_vllm_16gpu.py b/examples/rlhf/llama2_7b_ppo_ds_vllm_16gpu.py new file mode 100644 index 000000000..65b445630 --- /dev/null +++ b/examples/rlhf/llama2_7b_ppo_ds_vllm_16gpu.py @@ -0,0 +1,244 @@ +####################################################################### +# Settings # +####################################################################### +RESUME_STEP = -1 +MAX_PROMPT_LEN = 1024 +MAX_ANSWER_LEN = 1024 +MAX_PRETRAIN_LEN = 8192 + +PROMPT_BATCH_SIZE = 512 +PRETRAIN_BATCH_SIZE = 0 + +GENERATE_MICRO_BATCH_SIZE = 16 +INFER_MICRO_BATCH_SIZE = 16 +TRAIN_MICRO_BATCH_SIZE = 4 +REF_INFER_MICRO_BATCH_SIZE = 26 + +ZERO_STAGE = 3 +POLICY_DP_SIZE = 8 +CRITIC_DP_SIZE = 4 +POLICY_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE + ) // POLICY_DP_SIZE // TRAIN_MICRO_BATCH_SIZE +CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE # noqa: E501 + +# checkout generate config +assert PROMPT_BATCH_SIZE % GENERATE_MICRO_BATCH_SIZE == 0 +assert PROMPT_BATCH_SIZE % POLICY_DP_SIZE == 0 +# checkout infer config +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * POLICY_DP_SIZE) == 0 +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 +# checkout learn config +assert (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * + POLICY_DP_SIZE) == 0 +assert (PROMPT_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 + +import torch # noqa: E402 + +MODEL_DTYPE = torch.float16 + +POLICY_MODEL_PATH = 'meta-llama/Llama-2-7b-chat-hf' +REWARD_MODEL_PATH = 'meta-llama/Llama-2-7b-chat-hf' # better using a well-trained reward model # noqa: E501 + +tokenizer_config = dict( + pad_token_id=2, + eos_token_id=2, + padding_side='left', + chat_template= # noqa: E251 + "{% for message in messages %}{% if message['role'] == 'user' %}{{'Human:\n' + message['content'] + '\n'}}{% elif message['role'] == 'assistant' %}{{'Assistant:\n' + message['content'] + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:\n' }}{% endif %}", # noqa: E501 +) + +rollout_config = dict( + policy_micro_bs=GENERATE_MICRO_BATCH_SIZE, + reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, + max_new_tokens=MAX_ANSWER_LEN, + write_to_file=False, + resume_step=RESUME_STEP, + generate_kwargs={ + 'do_sample': True, + 'temperature': 1.0, + 'top_k': 0, + 'top_p': 0.9, + 'min_new_tokens': 1, + 'num_beams': 1, + 'early_stopping': True, + 'eos_token_id': 92542, + 'pad_token_id': 0, + }, +) + +repeater_config = dict( + policy_micro_bs=INFER_MICRO_BATCH_SIZE, + critic_micro_bs=INFER_MICRO_BATCH_SIZE, + ref_micro_bs=REF_INFER_MICRO_BATCH_SIZE, + kl_coeff=0.01, + gamma=1.0, + gae_lambda=0.99, + clip_reward_min=-5, + clip_reward_max=5, + norm_rewards=True, +) + +train_config = dict( + policy_micro_bs=TRAIN_MICRO_BATCH_SIZE, + 
critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, + ppo_loss_weight=1.0, + pretrain_loss_weight=0.5, + critic_warmup_step=0, + save_interval=40, + max_train_step=400, + resume_step=RESUME_STEP, + async_learn=True, +) + +model_configs = dict( + policy=dict( + model_path=POLICY_MODEL_PATH, + model_type='policy', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=False, + train_kwargs=dict( + micro_bsz=1, + lr=1e-6, + total_steps=1e9, + lr_decay_rate=1, + ), + parallel=dict( + data=dict(size=POLICY_DP_SIZE, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + 'zero_optimization': { + 'stage': ZERO_STAGE, + 'offload_param': { + 'device': 'none' + }, + 'reduce_bucket_size': 'auto', + 'zero_quantized_weights': False, + 'zero_quantized_gradients': False, + 'stage3_gather_16bit_weights_on_model_save': True, + }, + 'bf16': { + 'enabled': True if MODEL_DTYPE == torch.bfloat16 else False + }, + 'fp16': { + 'enabled': True if MODEL_DTYPE == torch.float16 else False + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': POLICY_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE, + }, + ), + generator_config=dict( + shared_with_trainer=False, + generator_type='vllm', + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), + critic=dict( + model_path=REWARD_MODEL_PATH, + model_type='critic', + head_name='value_head', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=False, + train_kwargs=dict( + micro_bsz=1, + lr=5e-6, + total_steps=1e9, + lr_decay_rate=1, + ), + parallel=dict( + data=dict(size=CRITIC_DP_SIZE, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + 'zero_optimization': { + 'stage': ZERO_STAGE, + 'offload_param': { + 'device': 'none' + }, + 'reduce_bucket_size': 'auto', + 'zero_quantized_weights': False, + 'zero_quantized_gradients': False, + 'stage3_gather_16bit_weights_on_model_save': True, + }, + 'bf16': { + 'enabled': True if MODEL_DTYPE == torch.bfloat16 else False + }, + 'fp16': { + 'enabled': True if MODEL_DTYPE == torch.float16 else False + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': CRITIC_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE, + }, + ), + ), + reference=dict( + model_path=POLICY_MODEL_PATH, + model_type='reference', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=2, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), + reward=dict( + model_path=REWARD_MODEL_PATH, + model_type='reward', + head_name='value_head', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + 
data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), +) + +prompt_dataset_config = dict( + samples_each_epoch=PROMPT_BATCH_SIZE, + max_len=MAX_PROMPT_LEN, + message_type='prompt', + random_seed=1024, + sample_strategy='in_batch', # 'in_data' + message_datasets=[ + '[HF]Anthropic/hh-rlhf/helpful-base::0.5[RM_PROMPT]:default', + ]) diff --git a/examples/rlhf/quick_start.md b/examples/rlhf/quick_start.md index bbd9b9d26..ab1fce0f7 100644 --- a/examples/rlhf/quick_start.md +++ b/examples/rlhf/quick_start.md @@ -34,5 +34,5 @@ pip install cupy-cuda11x==12.1 python -m cupyx.tools.install_library --library nccl --cuda 11.x # 启动任务,首次启动建议添加 HF_ENDPOINT=https://hf-mirror.com 方便数据集加载 -HF_ENDPOINT=https://hf-mirror.com xtuner rlhf -c examples/rlhf/four_model_vllm_8gpu.py +HF_ENDPOINT=https://hf-mirror.com xtuner rlhf -c examples/rlhf/internlm2_chat_1_8b_ppo_ds_vllm_8gpu.py ``` From 7f7bab86187a913639f686002ebbd912a6c5948a Mon Sep 17 00:00:00 2001 From: lishuaibin Date: Wed, 26 Jun 2024 07:16:10 +0000 Subject: [PATCH 14/37] fix rm/sys prompt --- xtuner/rlhf/envs/txt_env.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/xtuner/rlhf/envs/txt_env.py b/xtuner/rlhf/envs/txt_env.py index c42d6e68b..2fcce30c5 100644 --- a/xtuner/rlhf/envs/txt_env.py +++ b/xtuner/rlhf/envs/txt_env.py @@ -113,15 +113,18 @@ def get_reward_async(self, prompt_datas, policyout): for i in range(len(prompt_datas)): if prompt_datas[i].mes_type != 'prompt': continue - if prompt_datas[i].rm_prompt != 'default': + if (prompt_datas[i].rm_prompt != + 'default') or (prompt_datas[i].sys_prompt != 'default'): # Conditional Reward Model # for queries from different domains, use appropriate conditional system prompts # noqa: E501 # From Alignment section of the InternLM2 Technical Report: # https://arxiv.org/pdf/2403.17297 + if prompt_datas[i].rm_prompt != 'default': + prompt = prompt_datas[i].rm_prompt + else: + prompt = prompt_datas[i].sys_prompt cur_rm_data = [ - dict( - role='system', - content=SYSTEM_PROMPT[prompt_datas[i].rm_prompt]) + dict(role='system', content=SYSTEM_PROMPT[prompt]) ] + prompt_datas[i].message + [ dict( role='assistant', content=policyout.output_ans_str[i]) From ea9c67f8807da0e91eda4ed4ad964972a02313c6 Mon Sep 17 00:00:00 2001 From: Zhu Zhihao Date: Wed, 26 Jun 2024 12:07:08 +0000 Subject: [PATCH 15/37] fix vllm dp size>1 --- xtuner/rlhf/model_backend/generate_utils.py | 26 ++++++++++++++++--- .../rlhf/model_backend/vllm_model_runner.py | 18 ++++--------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/xtuner/rlhf/model_backend/generate_utils.py b/xtuner/rlhf/model_backend/generate_utils.py index 095dba9f9..ea5ee0c4a 100644 --- a/xtuner/rlhf/model_backend/generate_utils.py +++ b/xtuner/rlhf/model_backend/generate_utils.py @@ -56,16 +56,26 @@ def partition_by_micro_batch_size( num_splits = int(batch_size // micro_batch_size) + ( batch_size % micro_batch_size > 0) + max_inputs_length = None if isinstance(input_ids, torch.Tensor): input_ids_split = torch.split(input_ids, micro_batch_size, dim=0) + attention_mask_split = ( + torch.split(attention_mask, micro_batch_size, dim=0) if + attention_mask is not None else [None for _ in range(num_splits)]) else: + max_inputs_length = get_longest_list_length(input_ids) input_ids_split = [ input_ids[i:i + micro_batch_size] for i in range(0, len(input_ids), micro_batch_size) ] - attention_mask_split = ( - torch.split(attention_mask, 
micro_batch_size, dim=0) - if attention_mask is not None else [None for _ in range(num_splits)]) + attention_mask_split = [ + attention_mask[i:i + micro_batch_size] if attention_mask + is not None else [None for _ in range(num_splits)] for i in range( + 0, + len(attention_mask + ) if attention_mask is not None else num_splits * + micro_batch_size, micro_batch_size) + ] position_ids_split = ( torch.split(position_ids, micro_batch_size, dim=0) if position_ids is not None else [None for _ in range(num_splits)]) @@ -79,6 +89,7 @@ def partition_by_micro_batch_size( micro_batch['attention_mask'] = attention_mask_split[i] micro_batch['position_ids'] = position_ids_split[i] micro_batch['labels'] = labels_split[i] + micro_batch['max_inputs_length'] = max_inputs_length micro_batches.append(micro_batch) return micro_batches @@ -172,3 +183,12 @@ def get_answer_str( clean_up_tokenization_spaces=False, ) return answer_str + + +def get_longest_list_length(list_of_lists): + max_length = 0 + for int_list in list_of_lists: + current_length = len(int_list) + if current_length > max_length: + max_length = current_length + return max_length diff --git a/xtuner/rlhf/model_backend/vllm_model_runner.py b/xtuner/rlhf/model_backend/vllm_model_runner.py index 6acda480b..70f7c76e9 100644 --- a/xtuner/rlhf/model_backend/vllm_model_runner.py +++ b/xtuner/rlhf/model_backend/vllm_model_runner.py @@ -103,6 +103,7 @@ def get_sampling_params_from_dict(generate_kwargs: dict) -> SamplingParams: def generate( self, inputs: Union[torch.Tensor, str, list[str]], + max_inputs_length: int, step=-1, output_str=True, output_logits=False, @@ -149,16 +150,6 @@ def generate( req_outputs = self.llm.generate( prompt_token_ids=prompt, sampling_params=sp) - def get_longest_list_length(list_of_lists): - max_length = 0 - for int_list in list_of_lists: - current_length = len(int_list) - if current_length > max_length: - max_length = current_length - return max_length - - _max_length = get_longest_list_length(prompt) - def pad_list_with_pad_token(int_list, max_length, pad_token_id): if len(int_list) < max_length: num_pad_token_to_add = max_length - len(int_list) @@ -171,7 +162,7 @@ def pad_list_with_pad_token(int_list, max_length, pad_token_id): for _, req_output in enumerate(req_outputs): output = PolicyOutput() input_ids = [item for item in req_output.prompt_token_ids] - input_ids = pad_list_with_pad_token(input_ids, _max_length, + input_ids = pad_list_with_pad_token(input_ids, max_inputs_length, self.tokenizer.pad_token_id) output_token_ids = [ item for item in req_output.outputs[0].token_ids @@ -191,8 +182,8 @@ def pad_list_with_pad_token(int_list, max_length, pad_token_id): ) output[ 'attention_mask'] = output.question_mask + output.answer_mask # noqa: E501 - output['action_mask'] = output['attention_mask'][:, _max_length - - 1:-1] + output['action_mask'] = output[ + 'attention_mask'][:, max_inputs_length - 1:-1] if output_logits: raise NotImplementedError('TODO: output_logits') if output_attentions: @@ -310,6 +301,7 @@ def generate_async(self, input_ids, attention_mask, *args, **kwargs): return [ self.ray_actors[index].generate.remote( inputs=micro_batch['input_ids'], + max_inputs_length=micro_batch['max_inputs_length'], attention_mask=micro_batch['attention_mask'], *args, **kwargs, From 7d6e1b3324f4e884c3db69028877b15578e54c4a Mon Sep 17 00:00:00 2001 From: lishuaibin Date: Thu, 27 Jun 2024 03:41:53 +0000 Subject: [PATCH 16/37] fix/add data_map_fn --- xtuner/rlhf/dataset/base.py | 29 +++-- xtuner/rlhf/dataset/utils/__init__.py | 14 ++- 
xtuner/rlhf/dataset/utils/from_hf.py | 148 ++++++++++++++++++++++---- xtuner/rlhf/dataset/utils/map_fns.py | 99 ++++++++++++++++- 4 files changed, 255 insertions(+), 35 deletions(-) diff --git a/xtuner/rlhf/dataset/base.py b/xtuner/rlhf/dataset/base.py index f870ae5f2..68bddb695 100644 --- a/xtuner/rlhf/dataset/base.py +++ b/xtuner/rlhf/dataset/base.py @@ -28,7 +28,13 @@ class InfiniteDataset(IterableDataset): """Load infinite data from original dataset with shuffle.""" def __init__(self, dataset, rng=None): - self.data = list(iter(dataset)) + logger.info(f'init [InfiniteDataset] for {dataset} ...') + self.data = list( + iter(dataset)) if dataset.data_list is None else dataset.data_list + self.tokenizer = dataset.tokenizer + self.sys_prompt = dataset.sys_prompt + self.rm_prompt = dataset.rm_prompt + self.indices = list(range(len(self.data))) if rng is None: rng = random.Random() @@ -38,7 +44,20 @@ def __iter__(self): while True: self.rng.shuffle(self.indices) for i in self.indices: - yield self.data[i] + if isinstance(self.data[i], dict): + yield self.data[i] + elif isinstance(self.data[i], list): + try: + self.tokenizer.apply_chat_template( + self.data[i], tokenize=True) + except Exception: + logger.info('[data tokenize check] ' + f'skip dirty data: {self.data[i]}') + continue + yield dict( + data=self.data[i], + sys_prompt=self.sys_prompt, + rm_prompt=self.rm_prompt) class IterDataset(IterableDataset): @@ -66,8 +85,6 @@ def __iter__(self): logger.info( f'[data tokenize check] skip dirty data: {data}') continue - if data is None: - continue yield dict( data=data, sys_prompt=self.sys_prompt, @@ -82,8 +99,6 @@ def __iter__(self): logger.info( f'[data tokenize check] skip dirty data: {data}') continue - if data is None: - continue yield dict( data=data, sys_prompt=self.sys_prompt, @@ -130,6 +145,7 @@ def __init__(self, task_groups, tokenizer=None, random_seed=1024): logger.info(f'Loading {hf_dir} from huggingface ...') dataset = load_from_hf(hf_dir, tokenizer=tokenizer) task['dataset'] = IterDataset( + filename=hf_dir, data_list=dataset['conversation'], tokenizer=tokenizer, sys_prompt=task['sys_prompt'], @@ -257,6 +273,7 @@ def __init__(self, task_groups, tokenizer=None, random_seed=1024): logger.info(f'Loading {hf_dir} with huggingface format ...') dataset = load_from_hf(hf_dir, tokenizer=tokenizer) task['dataset'] = JsonDataset( + filename=hf_dir, data_list=dataset['conversation'], tokenizer=tokenizer, sys_prompt=task['sys_prompt'], diff --git a/xtuner/rlhf/dataset/utils/__init__.py b/xtuner/rlhf/dataset/utils/__init__.py index ec03048b9..b16c4da71 100644 --- a/xtuner/rlhf/dataset/utils/__init__.py +++ b/xtuner/rlhf/dataset/utils/__init__.py @@ -1,7 +1,15 @@ from .collate_fns import message_data_collator, messages_collate_fn -from .map_fns import H4_summarize_map_fn, hhrlhf_map_fn +from .map_fns import (FW_fineweb_edu_map_fn, H4_hhh_alignment_map_fn, + H4_summarize_map_fn, argilla_prompt_map_fn, + default_map_fn, hhrlhf_map_fn, nvidia_HelpSteer_map_fn, + nvidia_OpenMathInstruct_map_fn, + nvidia_sft_datablend_v1_map_fn, + stingning_ultrachat_map_fn) __all__ = [ - 'message_data_collator', 'messages_collate_fn', 'hhrlhf_map_fn', - 'H4_summarize_map_fn' + 'message_data_collator', 'messages_collate_fn', 'default_map_fn', + 'hhrlhf_map_fn', 'H4_summarize_map_fn', 'H4_hhh_alignment_map_fn', + 'stingning_ultrachat_map_fn', 'nvidia_HelpSteer_map_fn', + 'nvidia_OpenMathInstruct_map_fn', 'nvidia_sft_datablend_v1_map_fn', + 'argilla_prompt_map_fn', 'FW_fineweb_edu_map_fn' ] diff --git 
a/xtuner/rlhf/dataset/utils/from_hf.py b/xtuner/rlhf/dataset/utils/from_hf.py index e4791d148..6242d4252 100644 --- a/xtuner/rlhf/dataset/utils/from_hf.py +++ b/xtuner/rlhf/dataset/utils/from_hf.py @@ -1,14 +1,26 @@ from datasets import load_dataset +from loguru import logger from xtuner.dataset import process_hf_dataset from xtuner.dataset.map_fns import template_map_fn_factory -from xtuner.rlhf.dataset.utils import H4_summarize_map_fn, hhrlhf_map_fn +# yapf: disable +from xtuner.rlhf.dataset.utils import (FW_fineweb_edu_map_fn, + H4_hhh_alignment_map_fn, + H4_summarize_map_fn, + argilla_prompt_map_fn, default_map_fn, + hhrlhf_map_fn, nvidia_HelpSteer_map_fn, + nvidia_OpenMathInstruct_map_fn, + nvidia_sft_datablend_v1_map_fn, + stingning_ultrachat_map_fn) +# yapf: enable from xtuner.utils import PROMPT_TEMPLATE def read_hf_dataset(tokenizer, path: str = None, data_dir: str = None, + name: str = None, + data_files: dict = None, dataset_map_fn=None, max_length=8192, split='train', @@ -16,9 +28,14 @@ def read_hf_dataset(tokenizer, remove_unused_columns=False, shuffle_before_pack=False, pack_to_max_length=False): - # https://huggingface.co/datasets/Anthropic/hh-rlhf template_map_fn = template_map_fn_factory(template=prompt_template) - dataset_org = load_dataset(path, data_dir=data_dir, trust_remote_code=True) + dataset_org = load_dataset( + path, + name=name, + data_dir=data_dir, + data_files=data_files, + trust_remote_code=True) + logger.info(f'load_dataset {path}, {dataset_org}') dataset = process_hf_dataset( dataset=dataset_org, tokenizer=tokenizer, @@ -34,21 +51,14 @@ def read_hf_dataset(tokenizer, def load_from_hf(hf_dir, tokenizer, data_dir=None): if 'Anthropic/hh-rlhf' in hf_dir: - # train: Dataset({ - # features: ['chosen', 'rejected'], - # num_rows: 160800 - # }) - # test: Dataset({ - # features: ['chosen', 'rejected'], - # num_rows: 8552 - # }) if data_dir is not None: data_dir = data_dir elif 'helpful-base' in hf_dir: data_dir = 'helpful-base' elif 'harmless-base' in hf_dir: data_dir = 'harmless-base' - + logger.info(f'loading from `Anthropic/hh-rlhf`, data_dir={data_dir},' + ' split=`train`, map_fn=hhrlhf_map_fn...') dataset = read_hf_dataset( tokenizer=tokenizer, path='Anthropic/hh-rlhf', @@ -56,20 +66,112 @@ def load_from_hf(hf_dir, tokenizer, data_dir=None): max_length=8192, split='train', dataset_map_fn=hhrlhf_map_fn) - if 'summarize_from_feedback' in hf_dir: - # train_prefs: Dataset({ - # features: ['prompt', 'chosen', 'rejected'], - # num_rows: 92858 - # }) - # train_sft: Dataset({ - # features: ['prompt', 'chosen', 'rejected'], - # num_rows: 92858 - # }) + elif 'HuggingFaceH4' in hf_dir: + if 'summarize_from_feedback' in hf_dir: + H4_path = 'HuggingFaceH4/summarize_from_feedback' + H4_map_fn = H4_summarize_map_fn + elif 'hhh_alignment': + H4_path = 'HuggingFaceH4/hhh_alignment' + H4_map_fn = H4_hhh_alignment_map_fn + else: + logger.warning(f'Please specify your dataset_map_fn for {hf_dir}') + H4_path = hf_dir + H4_map_fn = default_map_fn + logger.info(f'loading {H4_path}, data_dir={data_dir}, ' + f'split=`train_prefs`, map_fn={H4_map_fn}...') dataset = read_hf_dataset( tokenizer=tokenizer, - path='HuggingFaceH4/summarize_from_feedback', + path=H4_path, data_dir=data_dir, max_length=8192, split='train_prefs', - dataset_map_fn=H4_summarize_map_fn) + dataset_map_fn=H4_map_fn) + elif 'ultrachat' in hf_dir: + logger.info( + f'loading from `stingning/ultrachat`, data_dir={data_dir}, ' + 'split=`train`, map_fn=stingning_ultrachat_map_fn...') + dataset = read_hf_dataset( + 
tokenizer=tokenizer, + path='stingning/ultrachat', + data_dir=data_dir, + max_length=8192, + split='train', + dataset_map_fn=stingning_ultrachat_map_fn) + elif 'nvidia' in hf_dir: + if 'HelpSteer' in hf_dir: + nvidia_map_fn = nvidia_HelpSteer_map_fn + elif 'OpenMathInstruct' in hf_dir: + nvidia_map_fn = nvidia_OpenMathInstruct_map_fn + elif 'sft_datablend_v1' in hf_dir: + nvidia_map_fn = nvidia_sft_datablend_v1_map_fn + else: + logger.warning(f'Please specify your dataset_map_fn for {hf_dir}') + nvidia_map_fn = default_map_fn + logger.info(f'loading from {hf_dir}, data_dir={data_dir}, ' + f'split=`train`, map_fn={nvidia_map_fn}...') + dataset = read_hf_dataset( + tokenizer=tokenizer, + path=hf_dir, + data_dir=data_dir, + max_length=8192, + split='train', + dataset_map_fn=nvidia_map_fn) + elif 'argilla' in hf_dir: + if 'prompt-collective' in hf_dir: + argilla_path = 'argilla/prompt-collective' + argilla_map_fn = argilla_prompt_map_fn + else: + logger.warning(f'Please specify your dataset_map_fn for {hf_dir}') + argilla_path = hf_dir + argilla_map_fn = default_map_fn + logger.info(f'loading from {argilla_path}, data_dir={data_dir}, ' + f'split=`train`, map_fn={argilla_map_fn}...') + dataset = read_hf_dataset( + tokenizer=tokenizer, + path=argilla_path, + data_dir=data_dir, + max_length=8192, + split='train', + dataset_map_fn=argilla_map_fn) + elif 'HuggingFaceFW' in hf_dir: + if 'fineweb-edu' in hf_dir: + FW_path = 'HuggingFaceFW/fineweb-edu' + FW_name = 'CC-MAIN-2024-10' + FW_data_files = { + 'train': [ + 'data/CC-MAIN-2024-10/train-00000-of-00020.parquet', + ] + } + FW_map_fn = FW_fineweb_edu_map_fn + else: + logger.warning(f'Please specify your dataset_map_fn for {hf_dir}') + FW_path = hf_dir + FW_map_fn = default_map_fn + logger.info(f'loading from {FW_path}, name={FW_name}, ' + f'data_files={FW_data_files}, data_dir={data_dir}, ' + f'split=`train`, map_fn={FW_map_fn}...') + dataset = read_hf_dataset( + tokenizer=tokenizer, + path=FW_path, + name=FW_name, + data_files=FW_data_files, + data_dir=data_dir, + max_length=8192, + split='train', + dataset_map_fn=FW_map_fn) + else: + try: + logger.warning(f'Please specify your dataset_map_fn with {hf_dir}') + dataset = read_hf_dataset( + tokenizer=tokenizer, + path=hf_dir, + data_dir=data_dir, + max_length=8192, + split='train', + dataset_map_fn=default_map_fn) + except Exception as e: + logger.error(f'{e}') + logger.error(f'Cannot load {hf_dir}, ' + 'checkout your datapath or dataset_map_fn...') + logger.info(f'Loaded {hf_dir}, {dataset}') return dataset diff --git a/xtuner/rlhf/dataset/utils/map_fns.py b/xtuner/rlhf/dataset/utils/map_fns.py index f66dd3ec7..4bb68c4a3 100644 --- a/xtuner/rlhf/dataset/utils/map_fns.py +++ b/xtuner/rlhf/dataset/utils/map_fns.py @@ -1,24 +1,117 @@ import re +def default_map_fn(example): + return example + + def hhrlhf_map_fn(example): string = example['chosen'] - pattern = r'(\n\nHuman|\n\nAssistant)(.+?)(?=(\n\nHuman|\n\nAssistant|$))' # noqa: E501 + pattern = r'(\n\nHuman|\n\nAssistant)(.+?)(?=(\n\nHuman|\n\nAssistant|$))' matches = re.findall(pattern, string, re.DOTALL) messages = [] for match in matches: role, content = match[0].strip(), match[1].strip() if role == 'Human': - messages.append({'role': 'user', 'content': content[2:]}) + messages.append(dict(role='user', content=content[2:])) elif role == 'Assistant': - messages.append({'role': 'assistant', 'content': content[2:]}) + messages.append(dict(role='assistant', content=content[2:])) else: raise NotImplementedError('role must in Human or Assistant') 
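+    # e.g. an hh-rlhf 'chosen' string "\n\nHuman: <question>\n\nAssistant: <answer>..."
+    # is split turn by turn by the regex above; each captured turn still begins with
+    # ': ', which content[2:] strips before the role/content dicts are built.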
return {'conversation': messages} +def H4_hhh_alignment_map_fn(example): + input = example['input'] + choices = example['targets']['choices'] + labels = example['targets']['labels'] + for label, choice in zip(labels, choices): + if label == 1: + chosen = choice + messages = [ + dict(role='user', content=input), + dict(role='assistant', content=chosen) + ] + return {'conversation': messages} + + def H4_summarize_map_fn(example): # prompt = example['prompt'] chosen = example['chosen'] # rejected = example['rejected'] return {'conversation': chosen} + + +def stingning_ultrachat_map_fn(example): + # id = example['id'] + data = example['data'] + messages = [] + for i, d in enumerate(data): + if i % 2 == 0: + role = 'user' + else: + role = 'assistant' + messages.append(dict(role=role, content=d)) + + return {'conversation': messages} + + +def nvidia_HelpSteer_map_fn(example): + prompt = example['prompt'] + response = example['response'] + messages = [ + dict(role='user', content=prompt), + dict(role='assistant', content=response) + ] + + return {'conversation': messages} + + +def nvidia_OpenMathInstruct_map_fn(example): + question = example['question'] + # expected_answer = example['expected_answer'] + generated_solution = example['generated_solution'] + messages = [ + dict(role='user', content=question), + dict(role='assistant', content=generated_solution) + ] + + return {'conversation': messages} + + +def nvidia_sft_datablend_v1_map_fn(example): + conversations = example['conversations'] + # system = example['system'] + messages = [] + for conv in conversations: + if conv['from'] == 'User': + role = 'user' + elif conv['from'] == 'Assistant': + role = 'assistant' + messages.append(dict(role=role, content=conv['value'])) + + return {'conversation': messages} + + +def argilla_prompt_map_fn(example): + prompt = example['prompt'] + messages = [dict(role='user', content=prompt)] + return {'conversation': messages} + + +def dibt_prompt_map_fn(example): + prompt = example['prompt'] + messages = [dict(role='user', content=prompt)] + return {'conversation': messages} + + +def FW_fineweb_edu_map_fn(example): + question = '' + answer = example['text'] + token_count = example['token_count'] + messages = [ + dict(role='user', content=question), + dict(role='assistant', content=answer) + ] + + return {'conversation': messages, 'token_count': token_count} From 1a284ff21c39d03bf24573dca73dc01c842e1bda Mon Sep 17 00:00:00 2001 From: Zhu Zhihao Date: Thu, 27 Jun 2024 11:07:08 +0000 Subject: [PATCH 17/37] fix: vllm dp&tp size --- xtuner/rlhf/model_backend/generate_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xtuner/rlhf/model_backend/generate_utils.py b/xtuner/rlhf/model_backend/generate_utils.py index ea5ee0c4a..15fa0b669 100644 --- a/xtuner/rlhf/model_backend/generate_utils.py +++ b/xtuner/rlhf/model_backend/generate_utils.py @@ -40,6 +40,8 @@ def partition_by_micro_batch_size( labels: Optional[Union[list[torch.Tensor], torch.Tensor, dict[str, torch.Tensor]]] = None, ) -> list[dict[str, torch.Tensor]]: + max_inputs_length = get_longest_list_length(input_ids) if isinstance( + input_ids, list) else None micro_batches: list[dict[str, torch.Tensor]] = [] batch_size = input_ids.shape[0] if isinstance( input_ids, torch.Tensor) else len(input_ids) @@ -49,6 +51,7 @@ def partition_by_micro_batch_size( micro_batch['attention_mask'] = attention_mask micro_batch['position_ids'] = position_ids micro_batch['labels'] = labels + micro_batch['max_inputs_length'] = max_inputs_length 
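+        # max_inputs_length is measured over the *whole* batch before any split, so every
+        # micro batch (and hence every vLLM dp worker) left-pads its prompts to the same
+        # length and the later action_mask slice [:, max_inputs_length - 1:-1] stays aligned.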
micro_batches.append(micro_batch) return micro_batches if micro_batch_size > batch_size: @@ -56,14 +59,12 @@ def partition_by_micro_batch_size( num_splits = int(batch_size // micro_batch_size) + ( batch_size % micro_batch_size > 0) - max_inputs_length = None if isinstance(input_ids, torch.Tensor): input_ids_split = torch.split(input_ids, micro_batch_size, dim=0) attention_mask_split = ( torch.split(attention_mask, micro_batch_size, dim=0) if attention_mask is not None else [None for _ in range(num_splits)]) else: - max_inputs_length = get_longest_list_length(input_ids) input_ids_split = [ input_ids[i:i + micro_batch_size] for i in range(0, len(input_ids), micro_batch_size) From be26ad6da4c8714cd9a9d9862ad1adf4cea50b09 Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Tue, 23 Jul 2024 20:25:34 +0800 Subject: [PATCH 18/37] update for base profile --- .gitignore | 3 + examples/rlhf/internlm2_1_8b_test_8gpu.py | 279 +++++++++++++++++++++ examples/rlhf/internlm2_20b_test_32gpu.py | 280 ++++++++++++++++++++++ scripts/gpu_info.sh | 25 ++ scripts/train_1node.sh | 30 +++ scripts/train_ray.sh | 96 ++++++++ tools/count_gpu.py | 27 +++ tools/count_time.py | 48 ++++ xtuner/rlhf/dataset/message_iter.py | 6 +- xtuner/rlhf/main.py | 19 ++ 10 files changed, 810 insertions(+), 3 deletions(-) create mode 100644 examples/rlhf/internlm2_1_8b_test_8gpu.py create mode 100644 examples/rlhf/internlm2_20b_test_32gpu.py create mode 100644 scripts/gpu_info.sh create mode 100644 scripts/train_1node.sh create mode 100644 scripts/train_ray.sh create mode 100644 tools/count_gpu.py create mode 100644 tools/count_time.py diff --git a/.gitignore b/.gitignore index c13320a73..beb600a20 100644 --- a/.gitignore +++ b/.gitignore @@ -123,3 +123,6 @@ rlhf_trainlog*/ # srun *.out batchscript-* + +# custom +logs/ diff --git a/examples/rlhf/internlm2_1_8b_test_8gpu.py b/examples/rlhf/internlm2_1_8b_test_8gpu.py new file mode 100644 index 000000000..dccd8af82 --- /dev/null +++ b/examples/rlhf/internlm2_1_8b_test_8gpu.py @@ -0,0 +1,279 @@ +####################################################################### +# Settings # +####################################################################### +RESUME_STEP = -1 +MAX_PROMPT_LEN = 1024 +MAX_ANSWER_LEN = 1024 +MAX_PRETRAIN_LEN = 8192 + +PROMPT_BATCH_SIZE = 128 +PRETRAIN_BATCH_SIZE = 0 # 32 + +GENERATE_MICRO_BATCH_SIZE = 16 +INFER_MICRO_BATCH_SIZE = 8 +TRAIN_MICRO_BATCH_SIZE = 2 + +ZERO_STAGE = 3 +POLICY_DP_SIZE = 2 +CRITIC_DP_SIZE = 2 +POLICY_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE + ) // POLICY_DP_SIZE // TRAIN_MICRO_BATCH_SIZE +CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE # noqa: E501 + +# checkout generate config +assert PROMPT_BATCH_SIZE % GENERATE_MICRO_BATCH_SIZE == 0 +assert PROMPT_BATCH_SIZE % POLICY_DP_SIZE == 0 +# checkout infer config +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * POLICY_DP_SIZE) == 0 +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 +# checkout learn config +assert (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * + POLICY_DP_SIZE) == 0 +assert (PROMPT_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 + +MODEL_DTYPE = 'auto' + +tokenizer_config = dict( + pad_token_id=0, + eos_token_id=92542, + padding_side='left', +) + +rollout_config = dict( + policy_micro_bs=GENERATE_MICRO_BATCH_SIZE, + reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, + max_new_tokens=MAX_ANSWER_LEN, + write_to_file=False, + resume_step=RESUME_STEP, + generate_kwargs={ 
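+        # eos_token_id 92542 below is the InternLM2 chat-template <|im_end|> id, so
+        # sampling stops at the end of an assistant turn rather than at the base eos token.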
+ 'do_sample': True, + 'temperature': 1.0, + 'top_k': 0, + 'top_p': 0.9, + 'min_new_tokens': 1, + 'num_beams': 1, + 'early_stopping': True, + 'eos_token_id': 92542, + 'pad_token_id': 0, + }, + async_reward=True, +) + +repeater_config = dict( + policy_micro_bs=INFER_MICRO_BATCH_SIZE, + critic_micro_bs=INFER_MICRO_BATCH_SIZE, + ref_micro_bs=INFER_MICRO_BATCH_SIZE, + reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, + kl_coeff=0.01, + gamma=1.0, + gae_lambda=0.99, + clip_reward_min=-5, + clip_reward_max=5, + norm_rewards=True, +) + +train_config = dict( + policy_micro_bs=TRAIN_MICRO_BATCH_SIZE, + critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, + ppo_loss_weight=1.0, + pretrain_loss_weight=0.5, + critic_warmup_step=0, + save_interval=40, + max_train_step=400, + resume_step=RESUME_STEP, + async_learn=False, +) + +model_configs = dict( + policy=dict( + # model_path='internlm/internlm2-chat-1_8b-sft', + model_path='/mnt/afs_2/wangxinjiang/models/internlm2_1_8b_sft/', + model_type='policy', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=False, + train_kwargs=dict( + micro_bsz=1, + lr=1e-6, + total_steps=1e9, + lr_decay_rate=1, + ), + parallel=dict( + data=dict(size=POLICY_DP_SIZE, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + 'zero_optimization': { + 'stage': ZERO_STAGE, + 'offload_param': { + 'device': 'none' + }, + 'reduce_bucket_size': 'auto', + 'zero_hpz_partition_size': 1, + 'zero_quantized_weights': False, + 'zero_quantized_gradients': False, + 'stage3_gather_16bit_weights_on_model_save': True, + }, + 'bf16': { + 'enabled': True + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': POLICY_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE, + }, + ), + generator_config=dict( + shared_with_trainer=False, + generator_type='vllm', + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=2, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), + critic=dict( + # model_path=None, + model_path='/mnt/afs_2/wangxinjiang/models/internlm2_chat_1_8b_reward_full_varlenattn_jsonl_dataset', + model_type='critic', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=False, + train_kwargs=dict( + micro_bsz=1, + lr=5e-6, + total_steps=1e9, + lr_decay_rate=1, + ), + parallel=dict( + data=dict(size=CRITIC_DP_SIZE, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + 'zero_optimization': { + 'stage': ZERO_STAGE, + 'offload_param': { + 'device': 'none' + }, + 'reduce_bucket_size': 'auto', + 'zero_hpz_partition_size': 1, + 'zero_quantized_weights': False, + 'zero_quantized_gradients': False + }, + 'bf16': { + 'enabled': True + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': CRITIC_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE, + }, + ), + ), + reference=dict( + # 
model_path='internlm/internlm2-chat-1_8b-sft', + model_path='/mnt/afs_2/wangxinjiang/models/internlm2_1_8b_sft/', + model_type='reference', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), + reward=dict( + # model_path=None, + model_path='/mnt/afs_2/wangxinjiang/models/internlm2_chat_1_8b_reward_full_varlenattn_jsonl_dataset', + model_type='reward', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), +) + +# prompt_dataset_config = dict( +# samples_each_epoch=PROMPT_BATCH_SIZE, +# max_len=MAX_PROMPT_LEN, +# message_type='prompt', +# random_seed=1024, +# sample_strategy='in_batch', # 'in_data' +# message_datasets=[ +# './examples/rlhf/demo_datas/prompt_data.json::0.01[SYS_PROMPT]:summarization', # noqa: E501 +# '[HF]Anthropic/hh-rlhf/helpful-base::0.5[RM_PROMPT]:default', +# '[HF]HuggingFaceH4/summarize_from_feedback::0.5', +# ]) + +# pretrain_dataset_config = dict( +# samples_each_epoch=PRETRAIN_BATCH_SIZE, +# max_len=MAX_PRETRAIN_LEN, +# message_type='pretrain', +# random_seed=1024, +# sample_strategy='in_batch', # 'in_data' +# message_datasets=[ +# './examples/rlhf/demo_datas/pretrain_data.json::0.01', +# '[HF]Anthropic/hh-rlhf/helpful-base::0.5', +# '[HF]HuggingFaceH4/summarize_from_feedback::0.5', +# ], +# ) + + +prompt_dataset_config = dict( + samples_each_epoch=PROMPT_BATCH_SIZE, + max_len=MAX_PROMPT_LEN, + message_type='prompt', + random_seed=1024, + sample_strategy='in_batch', # 'in_data' + message_datasets=[ + "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo/HelpSteer_internlm2.jsonl::1.0", + # "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo/gpt-4-llm-en-internlm2.jsonl::1.0", + # "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo/orca_dpo_pairs_hf4_internlm2.jsonl::1.0", + # "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo/ultrafeedback_cleaned_internlm2.jsonl::1.0", + ]) + +# pretrain_dataset_config = dict( +# samples_each_epoch=PRETRAIN_BATCH_SIZE, +# max_len=MAX_PRETRAIN_LEN, +# message_type='pretrain', +# random_seed=1024, +# sample_strategy='in_batch', # 'in_data' +# message_datasets=[ +# "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo_pt/HelpSteer_internlm2.jsonl::1.0", +# "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo_pt/gpt-4-llm-en-internlm2.jsonl::1.0", +# "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo_pt/orca_dpo_pairs_hf4_internlm2.jsonl::1.0", +# "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo_pt/ultrafeedback_cleaned_internlm2.jsonl::1.0", +# "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo_pt/zhuxian_60k_split0.jsonl::1.0" +# ], +# ) \ No newline at end of file diff --git a/examples/rlhf/internlm2_20b_test_32gpu.py b/examples/rlhf/internlm2_20b_test_32gpu.py new file mode 100644 index 000000000..076b32575 --- /dev/null +++ b/examples/rlhf/internlm2_20b_test_32gpu.py @@ -0,0 +1,280 @@ +####################################################################### +# Settings # +####################################################################### +RESUME_STEP = -1 
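+# A quick sanity check of how the batch-size knobs defined below fit together:
+#   POLICY_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE) // POLICY_DP_SIZE // TRAIN_MICRO_BATCH_SIZE
+#                            = (128 + 0) // 8 // 1 = 16
+# so DeepSpeed's train_batch_size = 128 = 8 (dp) * 1 (micro bs) * 16 (acc steps).
+# Assuming one GPU per dp rank, the trainers take 8 + 8 + 4 + 4 = 24 GPUs and the
+# 8-way tensor-parallel vLLM generator takes the remaining 8, i.e. 32 GPUs in total.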
+MAX_PROMPT_LEN = 1536 +MAX_ANSWER_LEN = 512 +MAX_PRETRAIN_LEN = 4096 + +PROMPT_BATCH_SIZE = 128 +PRETRAIN_BATCH_SIZE = 0 # 0 + +GENERATE_MICRO_BATCH_SIZE = 8 +INFER_MICRO_BATCH_SIZE = 2 +TRAIN_MICRO_BATCH_SIZE = 1 + +ZERO_STAGE = 3 +POLICY_DP_SIZE = 8 +CRITIC_DP_SIZE = 8 +REF_DP_SIZE = 4 +REWARD_DP_SIZE = 4 +VLLM_TP_SIZE=8 +POLICY_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE + ) // POLICY_DP_SIZE // TRAIN_MICRO_BATCH_SIZE +CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE # noqa: E501 + +# checkout generate config +assert PROMPT_BATCH_SIZE % GENERATE_MICRO_BATCH_SIZE == 0 +assert PROMPT_BATCH_SIZE % POLICY_DP_SIZE == 0 +# checkout infer config +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * POLICY_DP_SIZE) == 0 +assert PROMPT_BATCH_SIZE % (INFER_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 +# checkout learn config +assert (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * + POLICY_DP_SIZE) == 0 +assert (PROMPT_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 + +MODEL_DTYPE = 'auto' + +tokenizer_config = dict( + pad_token_id=0, + eos_token_id=92542, + padding_side='left', +) + +rollout_config = dict( + policy_micro_bs=GENERATE_MICRO_BATCH_SIZE, + reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, + max_new_tokens=MAX_ANSWER_LEN, + # write_to_file=True, + write_to_file=False, ## Debug-Only + resume_step=RESUME_STEP, + generate_kwargs={ + 'do_sample': True, + 'temperature': 1.0, + 'top_k': 0, + 'top_p': 0.9, + 'min_new_tokens': 1, + 'num_beams': 1, + 'early_stopping': True, + 'eos_token_id': 92542, + 'pad_token_id': 0, + }, + async_reward=True, +) + +repeater_config = dict( + policy_micro_bs=INFER_MICRO_BATCH_SIZE, + critic_micro_bs=INFER_MICRO_BATCH_SIZE, + # ref_micro_bs=INFER_MICRO_BATCH_SIZE, + ref_micro_bs=8, ## Optimize + kl_coeff=0.01, + gamma=1.0, + gae_lambda=0.99, + clip_reward_min=-5, + clip_reward_max=5, + norm_rewards=True, +) + +train_config = dict( + policy_micro_bs=TRAIN_MICRO_BATCH_SIZE, + critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, + ppo_loss_weight=1.0, + pretrain_loss_weight=0.5, + # critic_warmup_step=40, + critic_warmup_step=0, ## Debug-Only + save_interval=200, + max_train_step=800, + resume_step=RESUME_STEP, + async_learn=True, ## Optimize +) + +model_configs = dict( + policy=dict( + model_path="/mnt/afs_2/wangxinjiang/models/internlm2_20b_gauss_mg_pack_public_ckpt_data_merge_0413_32k_hf_impack/", + model_type='policy', + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=True, + train_kwargs=dict( + micro_bsz=1, + lr=5e-7, + total_steps=1e9, + lr_decay_rate=1, + ), + parallel=dict( + data=dict(size=POLICY_DP_SIZE, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + "zero_optimization": { + "stage": 3, + "overlap_comm": True, + "stage3_gather_16bit_weights_on_model_save": True + }, + 'bf16': { + 'enabled': True + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': POLICY_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE, + }, + ), + generator_config=dict( + shared_with_trainer=False, + generator_type='vllm', + parallel=dict( + data=dict(size=1, mode='ddp'), + tensor=dict(size=VLLM_TP_SIZE, mode='1d'), + 
pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + ), + ), + critic=dict( + model_path="/mnt/afs_2/wangxinjiang/models/internlm2_chat_20b_reward_full_varlenattn_ultrafeedback_1node/", + model_type="critic", + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + gradient_checkpointing=True, + train_kwargs=dict( + micro_bsz=1, + lr=9e-6, + total_steps=1e9, + lr_decay_rate=1, + loss_type="per_seq", + ), + parallel=dict( + data=dict(size=CRITIC_DP_SIZE, mode='deepspeed'), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + "zero_optimization": { + "stage": 3, + "overlap_comm": True, + "stage3_gather_16bit_weights_on_model_save": True + }, + 'bf16': { + 'enabled': True + }, + 'gradient_clipping': 1.0, + 'prescale_gradients': False, + 'wall_clock_breakdown': False, + 'data_types': { + 'grad_accum_dtype': 'fp32' + }, + 'train_micro_batch_size_per_gpu': TRAIN_MICRO_BATCH_SIZE, + 'gradient_accumulation_steps': CRITIC_GRADIENT_ACC_STEP, + 'train_batch_size': PROMPT_BATCH_SIZE, + }, + ), + ), + reference=dict( + model_path="/mnt/afs_2/wangxinjiang/models/internlm2_20b_gauss_mg_pack_public_ckpt_data_merge_0413_32k_hf_impack/", + model_type="reference", + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=REF_DP_SIZE, mode="deepspeed"), + tensor=dict(size=1, mode="1d"), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + "zero_optimization": { + # "stage": 3, + "stage": 0, ## Optimize + "overlap_comm": True, + "stage3_gather_16bit_weights_on_model_save": True + }, + "bf16": { + "enabled": True + }, + "gradient_clipping": 1.0, + "prescale_gradients": False, + "wall_clock_breakdown": False, + "data_types": { + "grad_accum_dtype": "fp32" + }, + "train_micro_batch_size_per_gpu": 2 + }, + ), + ), + reward=dict( + model_path="/mnt/afs_2/wangxinjiang/models/internlm2_chat_20b_reward_full_varlenattn_ultrafeedback_1node/", + model_type="reward", + trainer_config=dict( + torch_dtype=MODEL_DTYPE, + trainer_type='huggingface', + use_flash_attn=True, + parallel=dict( + data=dict(size=REWARD_DP_SIZE, mode="deepspeed"), + tensor=dict(size=1, mode='1d'), + pipeline=dict(size=1, interleaved_overlap=False), + sequence=False, + ), + deepspeed_config={ + "zero_optimization": { + # "stage": 3, + "stage": 0, ## Optimize + "overlap_comm": True, + "stage3_gather_16bit_weights_on_model_save": True + }, + "bf16": { + "enabled": True + }, + "gradient_clipping": 1.0, + "prescale_gradients": False, + "wall_clock_breakdown": False, + "data_types": { + "grad_accum_dtype": "fp32" + }, + "train_micro_batch_size_per_gpu": 2 + }, + ), + ), +) + +prompt_dataset_config = dict( + samples_each_epoch=PROMPT_BATCH_SIZE, + max_len=MAX_PROMPT_LEN, + message_type='prompt', + random_seed=1024, + sample_strategy='in_batch', # 'in_data' + message_datasets=[ + "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo/HelpSteer_internlm2.jsonl::1.0", + "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo/gpt-4-llm-en-internlm2.jsonl::1.0", + "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo/orca_dpo_pairs_hf4_internlm2.jsonl::1.0", + "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo/ultrafeedback_cleaned_internlm2.jsonl::1.0", + ]) + +pretrain_dataset_config = dict( + samples_each_epoch=PRETRAIN_BATCH_SIZE, + 
max_len=MAX_PRETRAIN_LEN, + message_type='pretrain', + random_seed=1024, + sample_strategy='in_batch', # 'in_data' + message_datasets=[ + "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo_pt/HelpSteer_internlm2.jsonl::1.0", + "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo_pt/gpt-4-llm-en-internlm2.jsonl::1.0", + "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo_pt/orca_dpo_pairs_hf4_internlm2.jsonl::1.0", + "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo_pt/ultrafeedback_cleaned_internlm2.jsonl::1.0", + "/mnt/afs_2/wangxinjiang/data/rlhf/reward_data/zl0602/reformat_ppo_pt/zhuxian_60k_split0.jsonl::1.0" + ], +) diff --git a/scripts/gpu_info.sh b/scripts/gpu_info.sh new file mode 100644 index 000000000..26ff1afef --- /dev/null +++ b/scripts/gpu_info.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# 日志文件路径 +#LOG_FILE="gpu_vllm.log" +LOG_FILE=$1 + +# 每次循环的时间间隔(秒) +INTERVAL=1 + +while true; do + # 获取当前时间 + timestamp=$(date '+%Y-%m-%d %H:%M:%S') + + # 获取所有GPU的利用率 + utilization=$(nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader,nounits | tr '\n' ', ' | sed 's/, $//') + + # 获取所有GPU的显存使用情况 + memory=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | tr '\n' ', ' | sed 's/, $//') + + # 将结果写入文件 + echo "$timestamp, Utilization: $utilization; Memory: $memory" >> $LOG_FILE + + # 等待指定的时间间隔 + sleep $INTERVAL +done diff --git a/scripts/train_1node.sh b/scripts/train_1node.sh new file mode 100644 index 000000000..28e7ea3f6 --- /dev/null +++ b/scripts/train_1node.sh @@ -0,0 +1,30 @@ +#!/bin/bash +set -ex + +export XTERM=linux + +# export NCCL_DEBUG=INFO +export NCCL_IB_TIMEOUT=22 +export NCCL_IB_RETRY_CNT=13 +export NCCL_IB_AR_THRESHOLD=0 +export NCCL_P2P_LEVEL=NVL + +export PYTHONPATH=$PWD:$PYTHONPATH +export HF_HOME=$(realpath $PWD/../../cache/huggingface/) + +config_file=$1 +work_dirs=$2 +start_time=$(date +%Y%m%d%H%M) + +# config_file must exist +if [ ! -f $config_file ]; then + echo "Config file $config_file does not exist" + exit 1 +fi +mkdir -p $work_dirs + +#python xtuner/rlhf/main.py -c $config_file -w $work_dirs > $work_dirs/debug.log 2>&1 & +# python xtuner/rlhf/main.py -c $config_file -w $work_dirs 2>&1 | tee $work_dirs/main-$start_time.log +# python xtuner/rlhf/test_actor.py -c $config_file -w $work_dirs 2>&1 | tee $work_dirs/main-$start_time.log +python xtuner/rlhf/test_actor_background.py -c $config_file -w $work_dirs +# 2>&1 | tee $work_dirs/main-$start_time.log diff --git a/scripts/train_ray.sh b/scripts/train_ray.sh new file mode 100644 index 000000000..43429dc4a --- /dev/null +++ b/scripts/train_ray.sh @@ -0,0 +1,96 @@ +#!/bin/bash +set -ex + +# export NCCL_DEBUG=INFO +export NCCL_IB_TIMEOUT=22 +export NCCL_IB_RETRY_CNT=13 +export NCCL_IB_AR_THRESHOLD=0 + +echo $MASTER_ADDR +echo $MASTER_PORT +echo $WORLD_SIZE +echo $RANK + +export NCCL_P2P_LEVEL=NVL +export PYTHONPATH=$PWD:$PYTHONPATH +export HF_HOME=$(realpath $PWD/../../cache/huggingface/) + +config_file=$1 +work_dirs=$2 +num_nodes=$3 +start_time=$(date +%Y%m%d%H%M) + +# config_file must exist +if [ ! 
-f $config_file ]; then + echo "Config file $config_file does not exist" + exit 1 +fi +# num_nodes has to be at least 1 +if [ $num_nodes -lt 1 ]; then + echo "Number of nodes must be at least 1" + exit 1 +fi +mkdir -p $work_dirs + +# if HOST contains "master", then this is the head node +if [[ $RANK -eq 0 ]]; then + node_role="master" +else + node_role="worker" +fi +head_node_ip=$MASTER_ADDR + +wait_time=10 +if [ "$node_role" == "master" ]; then + echo "Starting Ray head node..." + # Start Ray on this node as the head node and extract its address + # The `ray start --head` command outputs information that includes the address, + # but here we're assuming it's known or statically assigned for simplicity. + ray start --head --dashboard-host 0.0.0.0 --port=6379 --resources '{"COMPUTE": 100000000000000.0, "HEAD": 100000000000000.0}' + sleep $wait_time +elif [ "$node_role" == "worker" ]; then + sleep $wait_time + attempt=1 + echo "Starting Ray worker node and attempting to connect to the head node at $head_node_ip:6379" + while true; do + # Attempt to start Ray and connect to the head node + ray start --address="$head_node_ip:6379" --resources '{"COMPUTE": 100000000000000.0, "virtual_cluster_default": 100000000000000.0}' && break || { + if [ $attempt -le 5 ]; then + echo "Ray worker start attempt $attempt failed. Retrying in $wait_time seconds..." + ((attempt++)) + sleep $wait_time + else + echo "Failed to connect to the head node after $wait_time attempts. Exiting." + exit 1 + fi + } + done +fi +# run the training script once Ray has been started on all nodes +sleep $wait_time +if [ "$node_role" == "master" ]; then + num_active_ray_nodes=$(ray list nodes | grep ALIVE | wc -l) + echo "Number of active Ray nodes: $num_active_ray_nodes" + if [ $num_active_ray_nodes -lt $num_nodes ]; then + echo "Waiting for all Ray nodes to start..." + attempt=1 + while true; do + num_active_ray_nodes=$(ray list nodes | grep ALIVE | wc -l) + if [ $num_active_ray_nodes -eq $num_nodes ]; then + break + elif [ $attempt -le 5 ]; then + echo "python command attempt $attempt failed. Retrying in $wait_time seconds..." + ((attempt++)) + sleep $wait_time + else + echo "Failed to connect to the head node after $wait_time attempts. Exiting." 
+ exit 1 + fi + done + fi + # python xtuner/rlhf/main.py -c $config_file -w $work_dirs 2>&1 | tee $work_dirs/main-$start_time.log + python -u xtuner/rlhf/main.py -c $config_file -w $work_dirs > $work_dirs/main-$start_time.log 2>&1 & +else + sleep infinity 2>&1 & +fi + diff --git a/tools/count_gpu.py b/tools/count_gpu.py new file mode 100644 index 000000000..2253b010e --- /dev/null +++ b/tools/count_gpu.py @@ -0,0 +1,27 @@ +import sys +import numpy as np + +def calculate_averages(file_path): + utilization_values = [] + + with open(file_path, 'r') as file: + for line in file: + parts = line.split('Utilization:')[1].split(';')[0].split(',') + values = [int(val) for val in parts if val] + utilization_values.append(values) + + # Convert the list of lists into a NumPy array for easier manipulation + utilization_array = np.array(utilization_values) + + # Calculate the mean for each column + #mean_values = np.mean(utilization_array) + mean_values = np.mean(utilization_array[:,4:]) + + return mean_values + +# Usage +#file_path = 'gpu_ref_reward.log' +file_path = sys.argv[1] +averages = calculate_averages(file_path) +print("Average values for each position:", averages) + diff --git a/tools/count_time.py b/tools/count_time.py new file mode 100644 index 000000000..1e75a630d --- /dev/null +++ b/tools/count_time.py @@ -0,0 +1,48 @@ +import json + +def calculate_statistics(filename): + query_tokens_mean_list = [] + resp_tokens_mean_list = [] + time_fields = ["generate_time", "forward_time", "training_time"] + time_sums = {field: 0 for field in time_fields} + time_counts = {field: 0 for field in time_fields} + + count = 0 + with open(filename, 'r') as file: + for line in file: + count += 1 + + data = json.loads(line) + query_tokens_mean_list.append(data['query_tokens_mean']) + resp_tokens_mean_list.append(data['resp_tokens_mean']) + + for field in time_fields: + if field in data: + time_sums[field] += data[field] + time_counts[field] += 1 + + if count > 18: + break + + query_tokens_mean_avg = sum(query_tokens_mean_list) / len(query_tokens_mean_list) if query_tokens_mean_list else 0 + resp_tokens_mean_avg = sum(resp_tokens_mean_list) / len(resp_tokens_mean_list) if resp_tokens_mean_list else 0 + + time_averages = {field: (time_sums[field] / time_counts[field]) if time_counts[field] > 0 else 0 for field in time_fields} + + total_time_sum = sum(time_sums.values()) + total_time_avg= sum(time_averages.values()) + time_proportions = {field: (time_sums[field] / total_time_sum) if total_time_sum > 0 else 0 for field in time_fields} + + return { + "query_tokens_mean_avg": query_tokens_mean_avg, + "resp_tokens_mean_avg": resp_tokens_mean_avg, + **time_averages, + "total_time_avg": total_time_avg, + "time_proportions": time_proportions + } + +# 使用示例 +filename = 'logs/internlm2_20b_train_async/train_rlhf.log.jsonl' # 替换为你的jsonl文件路径 +statistics = calculate_statistics(filename) +print(statistics) + diff --git a/xtuner/rlhf/dataset/message_iter.py b/xtuner/rlhf/dataset/message_iter.py index 94c78b83a..59033decf 100644 --- a/xtuner/rlhf/dataset/message_iter.py +++ b/xtuner/rlhf/dataset/message_iter.py @@ -214,9 +214,9 @@ def _postprocess_sequence(self, message): if (token_ids.shape[-1] <= 4) or (token_ids.shape[-1] > self.max_len): # TODO truncation?? 
- logger.warning( - f'[MES_ITER] {self.message_type} message {message} ' - 'is too short or long, skipped.') + # logger.warning( + # f'[MES_ITER] {self.message_type} message {message} ' + # 'is too short or long, skipped.') return None elif self.message_type == 'pretrain': for _ in reversed(range(len(message_data))): diff --git a/xtuner/rlhf/main.py b/xtuner/rlhf/main.py index 9d692f6bf..3ba06f474 100644 --- a/xtuner/rlhf/main.py +++ b/xtuner/rlhf/main.py @@ -124,26 +124,36 @@ def validate_config(config: Config): s_t = time.time() with Timer(f'step {step}: end_to_end'): # generate trajectories + gen_start = time.time() trajectories = txt_env.rollout(display=True) + gen_time = time.time() - gen_start # deal with trajectories + fwd_start = time.time() trajectories = ppo_repeater.process(trajectories) + fwd_time = time.time() - fwd_start + train_start = time.time() # critic & policy learn if async_learn: critic_loss_ref = ppo.critic_learn_async(trajectories) else: + critic_train_start = time.time() critic_loss = ppo.critic_learn(trajectories) + critic_train_time = time.time() - critic_train_start ppo_loss, pt_loss = None, None if critic_warmup_step <= 0: ppo_loss, pt_loss = ppo.policy_learn(trajectories) + logger_train.info( f'[Policy Train] Step: {step}, ' f'ppo loss: {ppo_loss}, pretrain loss: {pt_loss}') if async_learn: critic_loss = ppo.critic_learn_get(critic_loss_ref) + train_time = time.time() - train_start + total_time = time.time() - s_t logger_train.info( f'[Critic Train] step: {step}, critic loss: {critic_loss}') @@ -172,6 +182,15 @@ def validate_config(config: Config): policy_loss=ppo_loss, pretrain_loss=pt_loss, critic_loss=critic_loss, + + query_tokens_mean=trajectories.question_mask.sum( + -1).float().mean().item(), + resp_tokens_mean=trajectories.answer_mask.sum( + -1).float().mean().item(), + generate_time=gen_time, + forward_time=fwd_time, + training_time=train_time, + total_time=total_time, ) with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: f.write(json.dumps(summaries) + '\n') From c7ce9a3f1f1aabc9f0dd3f23a35bb6ddf05b3da1 Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Tue, 23 Jul 2024 20:28:11 +0800 Subject: [PATCH 19/37] support stage pipeline --- xtuner/rlhf/envs/txt_env.py | 133 ++++--- xtuner/rlhf/model_backend/hf_model_runner.py | 119 +++++- .../rlhf/model_backend/vllm_model_runner.py | 169 +++++++++ .../rlhf/model_server/policy_model_server.py | 31 ++ xtuner/rlhf/repeaters/kl_gae.py | 101 ++++- xtuner/rlhf/test_actor.py | 333 +++++++++++++++++ xtuner/rlhf/test_actor_background.py | 329 +++++++++++++++++ xtuner/rlhf/test_policy_ref_pipe.py | 346 ++++++++++++++++++ xtuner/rlhf/test_ref.py | 218 +++++++++++ xtuner/rlhf/trainer/ppo.py | 27 +- 10 files changed, 1735 insertions(+), 71 deletions(-) create mode 100644 xtuner/rlhf/test_actor.py create mode 100644 xtuner/rlhf/test_actor_background.py create mode 100644 xtuner/rlhf/test_policy_ref_pipe.py create mode 100644 xtuner/rlhf/test_ref.py diff --git a/xtuner/rlhf/envs/txt_env.py b/xtuner/rlhf/envs/txt_env.py index 2fcce30c5..5d17d6842 100644 --- a/xtuner/rlhf/envs/txt_env.py +++ b/xtuner/rlhf/envs/txt_env.py @@ -9,6 +9,7 @@ from .base import EnvBase from .utils import SYSTEM_PROMPT +import time class TxtEnv(EnvBase): """A generic RL environment to generate textual sequences.""" @@ -17,8 +18,8 @@ def __init__( self, policy_model: BaseModelServer, reward_model: BaseModelServer, - prompt_mes_iter: Iterable, - pretrain_mes_iter: Iterable = None, + # prompt_mes_iter: Iterable, + # pretrain_mes_iter: Iterable = 
None, max_new_tokens: int = 1024, policy_micro_bs: int = 32, reward_micro_bs: int = 32, @@ -30,9 +31,9 @@ def __init__( self.policy_model = policy_model self.reward_model = reward_model - self.prompt_mes_iter = iter(prompt_mes_iter) - self.pretrain_mes_iter = iter( - pretrain_mes_iter) if pretrain_mes_iter.message_datasets else None + # self.prompt_mes_iter = iter(prompt_mes_iter) + # self.pretrain_mes_iter = iter( + # pretrain_mes_iter) if pretrain_mes_iter.message_datasets else None self.max_new_tokens = max_new_tokens self.policy_micro_bs = policy_micro_bs @@ -41,25 +42,28 @@ def __init__( self.generate_kwargs: dict = generate_kwargs self.resume_step = resume_step - def rollout(self, display=True): - while self.resume_step > 0: - logger.info(f'[Resume] {self.resume_step} consuming data...') - next(self.prompt_mes_iter) - if self.pretrain_mes_iter is not None: - next(self.pretrain_mes_iter) - self.resume_step -= 1 - prompt_datas = deepcopy(next(self.prompt_mes_iter)) - prompt_input_messages = [] - for data in prompt_datas: - assert data.mes_type == 'prompt' - if data.sys_prompt != 'default': - message = deepcopy([ - dict( - role='system', content=SYSTEM_PROMPT[data.sys_prompt]) - ] + data.message) - else: - message = deepcopy(data.message) - prompt_input_messages.append(message) + def rollout(self, prompt_datas, prompt_input_messages, display=True): + # while self.resume_step > 0: + # logger.info(f'[Resume] {self.resume_step} consuming data...') + # next(self.prompt_mes_iter) + # if self.pretrain_mes_iter is not None: + # next(self.pretrain_mes_iter) + # self.resume_step -= 1 + # prompt_datas = deepcopy(next(self.prompt_mes_iter)) + # prompt_input_messages = [] + # for data in prompt_datas: + # assert data.mes_type == 'prompt' + # if data.sys_prompt != 'default': + # message = deepcopy([ + # dict( + # role='system', content=SYSTEM_PROMPT[data.sys_prompt]) + # ] + data.message) + # else: + # message = deepcopy(data.message) + # prompt_input_messages.append(message) + + start = time.time() + # prompt data if display: logger.info( @@ -73,39 +77,64 @@ def rollout(self, display=True): generate_kwargs=self.generate_kwargs) logger.info(f'[Generate] len: {len(prompt_input_messages)}') - if self.async_reward: - reward_output_ref = self.get_reward_async(prompt_datas, - trajectories) - trajectories['reward_output_ref'] = reward_output_ref - else: - rewards = self.get_reward(prompt_datas, trajectories) - trajectories['rewards'] = rewards + end = time.time() + trajectories['stage1_time'] = end - start + + # if self.async_reward: + # reward_output_ref = self.get_reward_async(prompt_datas, + # trajectories) + # trajectories['reward_output_ref'] = reward_output_ref + # else: + # rewards = self.get_reward(prompt_datas, trajectories) + # trajectories['rewards'] = rewards # pretrain data - if self.pretrain_mes_iter is not None: - pretrain_datas = deepcopy(next(self.pretrain_mes_iter)) - pretrain_input_messages = [] - for data in pretrain_datas: - assert data.mes_type == 'pretrain' - pretrain_input_messages.append(message) - - from xtuner.rlhf.tokenizer import encode_inputs - pt_input_ids, pt_attention_mask = encode_inputs( - pretrain_input_messages, self.policy_model.tokenizer) - pretrain_labels = torch.nn.functional.pad( - pt_input_ids[:, 1:], (0, 1), mode='constant', value=-100) - - trajectories.pretrain_data = { - 'input_ids': pt_input_ids, - 'labels': pretrain_labels, - 'attention_mask': pt_attention_mask - } - logger.info(f'[TxtEnv] gets {pt_input_ids.shape} pretrain data.') - else: - 
trajectories.pretrain_data = None + # if self.pretrain_mes_iter is not None: + # pretrain_datas = deepcopy(next(self.pretrain_mes_iter)) + # pretrain_input_messages = [] + # for data in pretrain_datas: + # assert data.mes_type == 'pretrain' + # pretrain_input_messages.append(message) + + # from xtuner.rlhf.tokenizer import encode_inputs + # pt_input_ids, pt_attention_mask = encode_inputs( + # pretrain_input_messages, self.policy_model.tokenizer) + # pretrain_labels = torch.nn.functional.pad( + # pt_input_ids[:, 1:], (0, 1), mode='constant', value=-100) + + # trajectories.pretrain_data = { + # 'input_ids': pt_input_ids, + # 'labels': pretrain_labels, + # 'attention_mask': pt_attention_mask + # } + # logger.info(f'[TxtEnv] gets {pt_input_ids.shape} pretrain data.') + # else: + # trajectories.pretrain_data = None + trajectories.pretrain_data = None return trajectories + def rollout_background(self, prompt_input_messages): + with Timer('policy_model.generate'): + ref = self.policy_model.generate_background( + inputs=prompt_input_messages, + micro_batch_size=self.policy_micro_bs, + step=self.max_new_tokens, + output_str=True, + generate_kwargs=self.generate_kwargs) + logger.info(f'[Generate] len: {len(prompt_input_messages)}') + return ref + + def get_generate_finish(self, num): + start = time.time() + + trajectories = self.policy_model.get_generate_finish(num) + trajectories.pretrain_data = None + + end = time.time() + trajectories['stage1_time'] = end - start + return trajectories + # default get_reward() is blocking. # get_reward_async() needs to call get_reward_collect() def get_reward_async(self, prompt_datas, policyout): diff --git a/xtuner/rlhf/model_backend/hf_model_runner.py b/xtuner/rlhf/model_backend/hf_model_runner.py index 4786d1038..52dd14d8b 100644 --- a/xtuner/rlhf/model_backend/hf_model_runner.py +++ b/xtuner/rlhf/model_backend/hf_model_runner.py @@ -12,7 +12,7 @@ from ray.util.placement_group import remove_placement_group from torch.nn.modules.loss import _Loss from torch.optim.lr_scheduler import _LRScheduler -from transformers import AutoModelForCausalLM, PreTrainedModel +from transformers import AutoModelForCausalLM, PreTrainedModel, AutoConfig from transformers import get_scheduler as transformers_get_scheduler from transformers.dynamic_module_utils import init_hf_modules from transformers.generation.utils import GenerateDecoderOnlyOutput @@ -90,6 +90,13 @@ def initialize(self): if use_flash_attn else None, ) + # Note: Debug Only, random init model weight + # config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + # self.model: PreTrainedModel = model_class.from_config( + # config, + # trust_remote_code=True, + # ) + # Graident checkpointing gradient_checkpointing = self.model_config.get( 'gradient_checkpointing', False) @@ -241,6 +248,7 @@ def train( # None means using the entire input as one batch micro_batch_size: Optional[Union[list[int], int]] = None, debug=False, + update_param=True, **_ignored, ): if isinstance(input_ids, torch.Tensor): @@ -305,7 +313,11 @@ def train( set_seed(1234) loss_list[index] = sum(loss_entry) / len(loss_entry) - self.parameter_update(step_interval) + # self.parameter_update(step_interval) + if update_param: + self.parameter_update(step_interval) + + loss_list = [loss.cpu() for loss in loss_list] return loss_list if len(loss_list) > 1 else loss_list[0] # Inference @@ -414,6 +426,83 @@ def infer( # Concatenate the policy outputs from each micro batch and return the result # noqa: E501 return 
concat_policy_outputs(policy_outputs) + + @torch.no_grad() + def infer_from_future( + self, + # input_ids: torch.Tensor, + # attention_mask=None, + object_refs, + micro_batch_size: Optional[ + int] = -1, # -1: use the entire input as one batch + tokenizer=None, # Only used for reward models + output_logprobs=False, + output_logits=True, + output_attentions=False, + output_hidden_states=False, + infer_kwargs: Optional[dict] = {}, + debug=False, + **_ignored, + ) -> PolicyOutput: + self.info_rank0( + f'[{self.model_type}] self.infer() kwargs: {infer_kwargs}') + + ########## + outputs = ray.get(object_refs, timeout=None) + padding_token_map = { + 'output_ids': 0, + } + trajectories = concat_policy_outputs(outputs, padding_token_map) + input_ids = trajectories.output_ids + attention_mask = trajectories.attention_mask + ########## + + input_ids = input_ids.to(self.device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.device) + # returns entire-input-as-one-batch inference results + if micro_batch_size < 0: + self.info_rank0( + f'[{self.model_type}] infer() input_ids.shape: {input_ids.shape}' # noqa: E501 + ) + return self._infer( + input_ids, + attention_mask, + output_logprobs, + output_logits, + output_attentions, + output_hidden_states, + infer_kwargs, + ) + + # Otherwise, partition the input into micro batches and run inference on each micro batch separately # noqa: E501 + micro_batches = partition_by_micro_batch_size(input_ids, + micro_batch_size, + attention_mask) + policy_outputs = [] + for index, micro_batch in enumerate(micro_batches): + input_ids_mb = micro_batch['input_ids'] + attention_mask_mb = micro_batch['attention_mask'] + if index == 0: + self.info_rank0( + f'[{self.model_type}] will infer() input_ids_mb.shape: {input_ids_mb.shape} * {len(micro_batches)} times' # noqa: E501 + ) + policy_output_mb = self._infer( + input_ids_mb, + attention_mask_mb, + output_logprobs, + output_logits, + output_attentions, + output_hidden_states, + infer_kwargs, + ) + policy_outputs.append(policy_output_mb) + if debug: + self.set_seed(1234) + # Concatenate the policy outputs from each micro batch and return the result # noqa: E501 + return concat_policy_outputs(policy_outputs) + + # Generate @torch.no_grad() def _generate( @@ -674,6 +763,8 @@ class HfModelRunnerRayActorGroup(RayActorGroup): init_hf_modules() def __init__(self, name: str, config: dict): + init_hf_modules() + super().__init__(name, config) self.released = True num_gpus = get_gpu_requirement(config) @@ -704,6 +795,10 @@ def __init__(self, name: str, config: dict): world_size=len(self.ray_actors), ) for rank, actor in enumerate(self.ray_actors) ]) + + for actor in self.ray_actors: + logger.info(ray.get(actor.get_metadata.remote())) + self.initialize_ref = [ actor.initialize.remote() for actor in self.ray_actors ] @@ -814,6 +909,26 @@ def infer(self, *args, **kwargs): object_refs = self.infer_async(*args, **kwargs) return self.infer_get(object_refs) + # Inference + def infer_from_future(self, object_refs, *args, **kwargs): + # micro_batch_size = input_ids.shape[0] // self.dp_size + ( + # input_ids.shape[0] % self.dp_size > 0 + # ) # round up division, i.e., math.ceil(a / b) + # micro_batches = partition_by_micro_batch_size(input_ids, + # micro_batch_size, + # attention_mask) + + # assert len(micro_batches) == self.dp_size + return [ + self.ray_actors[0].infer_from_future.remote( + # input_ids=micro_batch['input_ids'], + # attention_mask=micro_batch['attention_mask'], + object_refs, + *args, + **kwargs, + ) 
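+            # NOTE: the micro-batch / dp partitioning above is left commented out, so all
+            # generation object_refs are resolved and inferred on ray_actors[0] alone.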
+ ] + # Generation def generate_async(self, input_ids, attention_mask, *args, **kwargs): micro_batch_size = input_ids.shape[0] // self.dp_size + ( diff --git a/xtuner/rlhf/model_backend/vllm_model_runner.py b/xtuner/rlhf/model_backend/vllm_model_runner.py index 70f7c76e9..31d6ffd70 100644 --- a/xtuner/rlhf/model_backend/vllm_model_runner.py +++ b/xtuner/rlhf/model_backend/vllm_model_runner.py @@ -18,6 +18,9 @@ from .ray_actor_mixin import RayActorMixin from .ray_utils import DEFAULT_NUM_CPUS, DEFAULT_NUM_GPUS, set_runtime_env +import threading +import queue + VLLM_DEFAULT_DEVICE = 'cuda' @@ -67,6 +70,7 @@ def _set_cuda_visible_devices(device_ids: list[int]): swap_space=0, tensor_parallel_size=tensor_parallel_size, device=VLLM_DEFAULT_DEVICE, + # load_format='dummy', ) self.tokenizer = self.llm.get_tokenizer() tokenizer_config = self.model_config.get('tokenizer_config', {}) @@ -207,6 +211,130 @@ def pad_list_with_pad_token(int_list, max_length, pad_token_id): padding_token_map) return concated_policy_out + def generate_background( + self, + inputs: Union[torch.Tensor, str, list[str]], + max_inputs_length: int, + step=-1, + output_str=True, + output_logits=False, + output_attentions=False, + output_hidden_states=False, + generate_kwargs: Optional[dict] = {}, + **_ignored, + ) -> list[tuple[list[int], str]]: + sp = VllmGenerator.get_sampling_params_from_dict(generate_kwargs) + sp.max_tokens = step if step > 0 else None + logger.info( + f'[{self.__class__.__name__}] self.generate() SamplingParams: {sp}' + ) + + if isinstance(inputs, torch.Tensor): + if len(inputs.shape) == 2: # e.g., [batch_size, seq_len] + prompt = self.tokenizer.batch_decode( + inputs, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + elif len(inputs.shape) == 1: # e.g., [seq_len] + prompt = self.tokenizer.decode( + inputs, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + else: + raise ValueError( + f'Unsupported tensor inputs of shape({inputs.shape})') + + elif isinstance(inputs, str): + prompt = inputs # str + elif isinstance(inputs, list): + if isinstance(inputs[0], list): + prompt = inputs # list[int] + else: + raise ValueError( + f'Unsupported inputs[0] with type({type(inputs[0])})') + else: + raise ValueError(f'Unsupported inputs with type({type(inputs)})') + + self.max_inputs_length = max_inputs_length + self.output_str = output_str + self.output_logits = output_logits + self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states + self.generate_kwargs = generate_kwargs + self.queue = queue.Queue() + + self.generate_thread = threading.Thread(target=self.llm.generate_to_queue, + kwargs={'prompt_token_ids':prompt, + 'sampling_params':sp, + 'queue':self.queue}) + self.generate_thread.start() + + def get_generate_finish(self, num): + req_outputs = [] + while num > 0: + req_outputs.append(self.queue.get()) + self.queue.task_done() + num -= 1 + + def pad_list_with_pad_token(int_list, max_length, pad_token_id): + if len(int_list) < max_length: + num_pad_token_to_add = max_length - len(int_list) + padded_list = [pad_token_id] * num_pad_token_to_add + int_list + return padded_list + else: + return int_list + + policy_outputs = [] + for _, req_output in enumerate(req_outputs): + output = PolicyOutput() + input_ids = [item for item in req_output.prompt_token_ids] + input_ids = pad_list_with_pad_token(input_ids, self.max_inputs_length, + self.tokenizer.pad_token_id) + output_token_ids = [ + item for item in req_output.outputs[0].token_ids + ] + 
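For reference, a minimal self-contained sketch of the producer/consumer pattern that generate_background and get_generate_finish rely on, using only the standard library. produce_results below is a stand-in for llm.generate_to_queue (which is assumed, not shown in this patch) and the names are illustrative, not the runner's API:

import queue
import threading

def produce_results(prompts, out_queue):
    # Stand-in for the background generation call: push one finished
    # result per prompt as soon as it completes (order not guaranteed).
    for p in prompts:
        out_queue.put(f'completion for: {p}')

def consume_finished(out_queue, num):
    # Mirrors get_generate_finish: block until `num` finished results
    # are available, marking each as done so queue.join() can return.
    results = []
    for _ in range(num):
        results.append(out_queue.get())
        out_queue.task_done()
    return results

q = queue.Queue()
worker = threading.Thread(target=produce_results, args=(['a', 'b', 'c', 'd'], q))
worker.start()
first_batch = consume_finished(q, 2)   # consume results as they arrive
second_batch = consume_finished(q, 2)
worker.join()   # the thread is joined before the next background generation
q.join()        # returns once every put item has been marked task_done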
output_ids = input_ids + output_token_ids # concat + output['input_ids'] = torch.Tensor(input_ids).to( + torch.long).unsqueeze(0) + output['output_ids'] = torch.tensor(output_ids).to( + torch.long).unsqueeze(0) + + output['question_mask'], output[ + 'answer_mask'] = get_question_answer_mask( + output['input_ids'], + output['output_ids'], + tokenizer_pad_token_id=self.tokenizer.pad_token_id, + generate_pad_token_id=self.generate_kwargs.get('pad_token_id'), + ) + output[ + 'attention_mask'] = output.question_mask + output.answer_mask # noqa: E501 + output['action_mask'] = output[ + 'attention_mask'][:, self.max_inputs_length - 1:-1] + if self.output_logits: + raise NotImplementedError('TODO: output_logits') + if self.output_attentions: + raise NotImplementedError('TODO: output_attentions') + if self.output_hidden_states: + raise NotImplementedError('TODO: output_hidden_states') + if self.output_str: # return list[str] + output['output_ans_str'] = [req_output.outputs[0].text] + output_str = self.tokenizer.decode( + output_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, + ) + output['output_str'] = [output_str] + output.to('cpu') + + policy_outputs.append(output) + + padding_token_map = {'output_ids': self.tokenizer.pad_token_id} + concated_policy_out = concat_policy_outputs(policy_outputs, + padding_token_map) + return concated_policy_out + class VllmGeneratorRayActor(VllmGenerator, RayActorMixin): @@ -230,6 +358,9 @@ def update_weight(self, name, dtype, shape, empty_cache=False): class VllmGeneratorRayActorGroup(RayActorGroup): def __init__(self, name: str, config: dict): + from transformers.dynamic_module_utils import init_hf_modules + init_hf_modules() + import uuid self.released = True self.config = config @@ -268,6 +399,9 @@ def __init__(self, name: str, config: dict): runtime_env=set_runtime_env(), ).remote(config)) + # for actor in self.ray_actors: + # logger.info(ray.get(actor.get_metadata.remote())) + self.released = False self.initialize_ref = [ actor.initialize.remote() for actor in self.ray_actors @@ -285,6 +419,41 @@ def initialize_get(self): 'self.initialize_ref is None when calling initialize_get()') self.initialize_ref = None + def generate_background(self, input_ids, attention_mask, *args, **kwargs): + assert ( + len(input_ids) >= self.dp_size + ), f'The length of input_ids({len(input_ids)}) must not be less than dp_size({self.dp_size}).' 
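To make the mask bookkeeping above concrete, here is a small hedged example of the assumed layout: prompts are left-padded to max_inputs_length, generated tokens are appended, attention_mask is the union of the question and answer masks, and action_mask is the attention mask shifted by one position so it lines up with the logits that predict each generated token. The numbers are illustrative and not taken from the patch:

import torch

pad_id, max_inputs_length = 0, 4
prompt_ids = [11, 12]                   # real prompt tokens
generated = [21, 22, 23]                # tokens sampled by the generator

input_ids = [pad_id] * (max_inputs_length - len(prompt_ids)) + prompt_ids
output_ids = torch.tensor([input_ids + generated])        # shape [1, 7]

question_mask = torch.tensor([[0, 0, 1, 1, 0, 0, 0]])     # prompt positions
answer_mask = torch.tensor([[0, 0, 0, 0, 1, 1, 1]])       # generated positions
attention_mask = question_mask + answer_mask              # every real token
# Positions max_inputs_length-1 .. -2 are the ones whose next-token
# predictions correspond to the generated tokens.
action_mask = attention_mask[:, max_inputs_length - 1:-1]  # shape [1, 3]
assert action_mask.shape[1] == len(generated)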
# noqa: E501 + micro_batch_size = len(input_ids) // self.dp_size + ( + len(input_ids) % self.dp_size > 0 + ) # round up division, i.e., math.ceil(a / b) + micro_batches = partition_by_micro_batch_size(input_ids, + micro_batch_size, + attention_mask) + assert len(micro_batches + ) == self.dp_size, f'{len(micro_batches)}, :{self.dp_size}' + object_refs = [ + self.ray_actors[index].generate_background.remote( + inputs=micro_batch['input_ids'], + max_inputs_length=micro_batch['max_inputs_length'], + attention_mask=micro_batch['attention_mask'], + *args, + **kwargs, + ) for index, micro_batch in enumerate(micro_batches) + ] + return ray.get(object_refs) + + def get_generate_finish(self, num, timeout=None): + object_refs = [ + self.ray_actors[index].get_generate_finish.remote( + num + ) for index in range(len(self.ray_actors)) + ] + outputs = ray.get(object_refs, timeout=timeout) + padding_token_map = { + 'output_ids': self.config.tokenizer_config['pad_token_id'] + } + return concat_policy_outputs(outputs, padding_token_map) + # Generation def generate_async(self, input_ids, attention_mask, *args, **kwargs): assert ( diff --git a/xtuner/rlhf/model_server/policy_model_server.py b/xtuner/rlhf/model_server/policy_model_server.py index bbe819347..6e70d163e 100644 --- a/xtuner/rlhf/model_server/policy_model_server.py +++ b/xtuner/rlhf/model_server/policy_model_server.py @@ -46,6 +46,37 @@ def initialize_get(self): self.is_initialized = True logger.info(f'{self.model_name} has been initialized. ') + # Generation + def generate_background(self, + inputs, + attention_mask=None, + *args, + **generate_kwargs): + if isinstance(inputs, torch.Tensor): + input_ids = inputs + elif isinstance(inputs, list): + if not self.generator_eq_trainer: + input_ids, attention_mask = encode_inputs( + inputs, + self.tokenizer, + return_tensors=None, + padding=False, + add_generation_prompt=True) + else: + input_ids, attention_mask = encode_inputs( + inputs, self.tokenizer, add_generation_prompt=True) + else: + raise NotImplementedError(f'unknown inputs: {inputs}') + + return self.generator.generate_background( + input_ids=input_ids, + attention_mask=attention_mask, + *args, + **generate_kwargs) + + def get_generate_finish(self, num): + return self.generator.get_generate_finish(num) + # Generation def generate_async(self, inputs, diff --git a/xtuner/rlhf/repeaters/kl_gae.py b/xtuner/rlhf/repeaters/kl_gae.py index 5ab69611a..a97b9e6c5 100644 --- a/xtuner/rlhf/repeaters/kl_gae.py +++ b/xtuner/rlhf/repeaters/kl_gae.py @@ -6,6 +6,9 @@ from .base import RepeaterBase from .utils import RunningStates +from loguru import logger +from xtuner.rlhf.envs.utils import SYSTEM_PROMPT +import time class KLGAERepeater(RepeaterBase): @@ -14,9 +17,11 @@ def __init__( ref_model: BaseModelServer, policy_model: BaseModelServer, critic_model: BaseModelServer, + reward_model: BaseModelServer, policy_micro_bs: int = 8, ref_micro_bs: int = 8, critic_micro_bs: int = 32, + reward_micro_bs: int = 8, kl_coeff=0.01, gamma=1.0, gae_lambda=0.99, @@ -31,10 +36,12 @@ def __init__( self.ref_model = ref_model self.policy_model = policy_model self.critic_model = critic_model + self.reward_model = reward_model self.policy_micro_bs = policy_micro_bs self.ref_micro_bs = ref_micro_bs self.critic_micro_bs = critic_micro_bs + self.reward_micro_bs = reward_micro_bs self.kl_coeff = kl_coeff self.gamma = gamma self.gae_lambda = gae_lambda @@ -49,12 +56,13 @@ def __init__( # only used for async reward model.infer_get() in _get_kl_rewards self.env = env - def process(self, 
trajectories: PolicyOutput): + def process(self, prompt_datas, trajectories: PolicyOutput): + start = time.time() critic_output_ref = self._get_values_async(trajectories) action_mask = trajectories['action_mask'] num_actions = action_mask.size(1) (kl_rewards, entropy, kl_distance, policy_logprobs, - ref_logprobs) = self._get_kl_rewards(trajectories) + ref_logprobs) = self._get_kl_rewards(prompt_datas, trajectories) trajectories['kl'] = (kl_distance * action_mask).sum( axis=-1) / action_mask.sum(axis=-1) trajectories['entropy'] = entropy @@ -72,10 +80,20 @@ def process(self, trajectories: PolicyOutput): trajectories['advantages'] = advantages trajectories['returns'] = returns trajectories['old_values'] = old_values - + end = time.time() + trajectories['stage2_time'] = end - start return trajectories - def _get_kl_rewards(self, trajectories: PolicyOutput): + # for _ in range(10): + # critic_output_ref = self._get_values_async(trajectories) + # action_mask = trajectories['action_mask'] + # num_actions = action_mask.size(1) + # values = self._get_values_collect(critic_output_ref) + # old_values = values[:, -num_actions:] + # trajectories['old_values'] = old_values + # return trajectories + + def _get_kl_rewards(self, prompt_datas, trajectories: PolicyOutput): with Timer('policy_model.infer_async'): policy_output = self.policy_model.infer_async( inputs=trajectories.output_ids, @@ -90,18 +108,27 @@ def _get_kl_rewards(self, trajectories: PolicyOutput): attention_mask=trajectories.attention_mask, output_logits=False, output_logprobs=True) + + reward_output_ref = self.get_reward_async(prompt_datas, + trajectories) + rewards = self.get_reward_collect(reward_output_ref) + trajectories['rewards'] = rewards + with Timer('policy_model.infer_get'): policy_output = self.policy_model.infer_get(policy_output) with Timer('ref_model.infer_get'): ref_output = self.ref_model.infer_get(ref_output) - # Experimental - if self.env.async_reward: - rewards = self.env.get_reward_collect( - trajectories['reward_output_ref']) - trajectories['reward_output_ref'] = None - trajectories['rewards'] = rewards - # Experimental + # # Experimental + # if self.env.async_reward: + # rewards = self.env.get_reward_collect( + # trajectories['reward_output_ref']) + # trajectories['reward_output_ref'] = None + # trajectories['rewards'] = ray.get(rewards) + # else: + # rewards = trajectories['rewards'] + # # Experimental + # rewards = trajectories['rewards'] clipped_rewards = torch.clamp( rewards, min=self.clip_reward_min, max=self.clip_reward_max) @@ -206,3 +233,55 @@ def get_advantages_and_returns( advantages = torch.stack(advantages_reversed[::-1], dim=1) returns = advantages + values return advantages.detach(), returns + + # default get_reward() is blocking. 
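Only the tail of get_advantages_and_returns appears in this hunk, so a compact sketch of the standard GAE recursion it implements may help. gamma and gae_lambda correspond to the repeater's config fields; the bootstrap of the final value as zero and the tensor shapes here are illustrative assumptions, not a copy of the patch's implementation:

import torch

def gae(rewards, values, gamma=1.0, lam=0.99):
    # rewards, values: [batch, T]; values[:, t] approximates V(s_t).
    # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), with V(s_T) taken as 0.
    # A_t = delta_t + gamma * lam * A_{t+1}; returns_t = A_t + V(s_t).
    T = rewards.size(1)
    advantages = torch.zeros_like(rewards)
    last_adv = torch.zeros(rewards.size(0))
    for t in reversed(range(T)):
        next_value = values[:, t + 1] if t + 1 < T else torch.zeros_like(last_adv)
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        last_adv = delta + gamma * lam * last_adv
        advantages[:, t] = last_adv
    returns = advantages + values
    return advantages.detach(), returns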
+ # get_reward_async() needs to call get_reward_collect() + def get_reward_async(self, prompt_datas, policyout): + # rm_input_messages = [] + # for i in range(len(prompt_datas)): + # if prompt_datas[i].mes_type != 'prompt': + # continue + # if (prompt_datas[i].rm_prompt != + # 'default') or (prompt_datas[i].sys_prompt != 'default'): + # # Conditional Reward Model + # # for queries from different domains, use appropriate conditional system prompts # noqa: E501 + # # From Alignment section of the InternLM2 Technical Report: + # # https://arxiv.org/pdf/2403.17297 + # if prompt_datas[i].rm_prompt != 'default': + # prompt = prompt_datas[i].rm_prompt + # else: + # prompt = prompt_datas[i].sys_prompt + # cur_rm_data = [ + # dict(role='system', content=SYSTEM_PROMPT[prompt]) + # ] + prompt_datas[i].message + [ + # dict( + # role='assistant', content=policyout.output_ans_str[i]) + # ] + # else: + # cur_rm_data = prompt_datas[i].message + [ + # dict( + # role='assistant', content=policyout.output_ans_str[i]) + # ] + # rm_input_messages.append(cur_rm_data) + + # logger.info(f'[For Reward]: {rm_input_messages[0]}') + + # with Timer('reward_model.infer_async'): + # reward_output_ref = self.reward_model.infer_async( + # rm_input_messages, + # output_logprobs=False, + # micro_batch_size=self.reward_micro_bs) + + with Timer('reward_model.infer_async'): + reward_output_ref = self.reward_model.infer_async( + inputs=policyout.output_ids, + attention_mask=policyout.attention_mask, + output_logprobs=False, + micro_batch_size=self.reward_micro_bs) + return reward_output_ref + + def get_reward_collect(self, reward_output_ref): + with Timer('reward_model.infer_get'): + rm_out = self.reward_model.infer_get(reward_output_ref) + rewards = rm_out.logits.squeeze(-1) + return rewards \ No newline at end of file diff --git a/xtuner/rlhf/test_actor.py b/xtuner/rlhf/test_actor.py new file mode 100644 index 000000000..30d0cc2f5 --- /dev/null +++ b/xtuner/rlhf/test_actor.py @@ -0,0 +1,333 @@ +import argparse +import json +import os +import shutil +import time + +from loguru import logger + +from xtuner.rlhf.config.config import Config +from xtuner.rlhf.coordinator import Coordinator +from xtuner.rlhf.dataset import MessageIter +from xtuner.rlhf.envs import TxtEnv +from xtuner.rlhf.repeaters import KLGAERepeater +from xtuner.rlhf.timer import Timer +from xtuner.rlhf.trainer import PPOTrainer + +import ray +from copy import deepcopy +from xtuner.rlhf.envs.utils import SYSTEM_PROMPT +from policy_output import (PolicyOutput, concat_policy_outputs, + logprobs_from_logits) + +def parse_args(): + parser = argparse.ArgumentParser(description='Train LLM') + parser.add_argument( + '-c', + '--config', + help='config file name or path.', + type=str, + default='examples/rlhf/four_model_vllm_8gpu.py') + parser.add_argument( + '-w', + '--work_dir', + help='the dir to save logs and models', + type=str, + default=None) + parser.add_argument( + '-a', '--address', help='ray head address', type=str, default='auto') + args = parser.parse_args() + return args + + +def validate_config(config: Config): + assert config['model_configs'] is not None + assert config['model_configs']['policy'] is not None + assert config['model_configs']['policy']['model_path'] is not None + assert config['dataset_config'] is not None + assert config['rollout_config'] is not None + assert config['rollout_config']['generate_kwargs'] is not None + assert config['rollout_config']['max_new_tokens'] is not None + + +def flatten_list(nested_list): + flattened = [] + for item 
in nested_list: + if isinstance(item, list): + flattened.extend(flatten_list(item)) + else: + flattened.append(item) + return flattened + + +if __name__ == '__main__': + args = parse_args() + assert args.config is not None, 'config should not be None' + work_dir = args.work_dir + if work_dir is None: + work_dir = os.getcwd() + '/rlhf_trainlog_' + time.strftime( + '%Y-%m-%d-%H:%M:%S') + work_dir = os.path.abspath(work_dir) + logger.info(f'using work_dir: {work_dir}') + os.makedirs(work_dir, exist_ok=True) + # save original config + shutil.copy2(args.config, f'{work_dir}/{os.path.basename(args.config)}') + + logger.add( + f'{work_dir}/train_rlhf.log', + filter=lambda record: record['extra'].get('name') == 'train') + logger_train = logger.bind(name='train') + + config = Config.from_file(args.config) + logger.info('#################### CONFIG BGN ####################') + for k, v in config.items(): + logger.info(f'{k}: {v}') + logger.info('#################### CONFIG END ####################') + + # init model + cluster_address = args.address + if cluster_address != 'auto': + cluster_address = f'ray://{cluster_address}:10001' + logger.info(f'cluster_address={cluster_address}') + coordinator = Coordinator(cluster_address, config) + model_dict = coordinator.create_models() + ref_model = model_dict['reference'] + policy_model = model_dict['policy'] + reward_model = model_dict['reward'] + critic_model = model_dict['critic'] + + # init prompt & pretrain dataset + prompt_dataset_config = config['prompt_dataset_config'] + prompt_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, **prompt_dataset_config) + pretrain_dataset_config = config.get('pretrain_dataset_config', {}) + pretrain_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, **pretrain_dataset_config) + + prompt_mes_iter = iter(prompt_mes_iter) + pretrain_mes_iter = iter( + pretrain_mes_iter) if pretrain_mes_iter.message_datasets else None + + # init txt env + rollout_config = config.get('rollout_config', {}) + txt_env = ray.remote(TxtEnv).remote( + policy_model=policy_model, + reward_model=reward_model, + # prompt_mes_iter=prompt_mes_iter, + # pretrain_mes_iter=pretrain_mes_iter, # None + **rollout_config, + ) + # init repeater + repeater_config = config.get('repeater_config', {}) + ppo_repeater = ray.remote(KLGAERepeater).remote( + ref_model=ref_model, + policy_model=policy_model, + critic_model=critic_model, + reward_model=reward_model, + env=txt_env, + **repeater_config, + ) + # init trainer + train_config = config.get('train_config', {}) + ppo = ray.remote(PPOTrainer).options(max_concurrency=2).remote( + policy_model=policy_model, critic_model=critic_model, **train_config) + critic_warmup_step = train_config['critic_warmup_step'] + save_interval = train_config['save_interval'] + max_train_step = train_config.get('max_train_step', float('inf')) + resume_step = train_config.get('resume_step', -1) + critic_warmup_step = min(critic_warmup_step, + critic_warmup_step - resume_step) + async_learn = train_config.get('async_learn', False) + + step = max(0, resume_step) + while step <= max_train_step: + s_t = time.time() + with Timer(f'step {step}: end_to_end'): + + prompt_datas = deepcopy(next(prompt_mes_iter)) + prompt_input_messages = [] + for data in prompt_datas: + assert data.mes_type == 'prompt' + if data.sys_prompt != 'default': + message = deepcopy([ + dict( + role='system', content=SYSTEM_PROMPT[data.sys_prompt]) + ] + data.message) + else: + message = deepcopy(data.message) + prompt_input_messages.append(message) + + # 
critic_loss_refs, ppo_loss_refs, trajectories_refs = [], [], [] + # micro_bs = 32 + # # for start in range(0, len(prompt_input_messages), micro_bs): + # for idx in range(0, len(prompt_input_messages)//micro_bs): + # # breakpoint() + # # update param in last micro batch + # if idx == len(prompt_input_messages) // micro_bs - 1: + # update_param = True + # else: + # update_param = False + + # # generate trajectories + # trajectories_ref = txt_env.rollout.remote(prompt_datas[idx*micro_bs : (idx+1)*micro_bs], + # prompt_input_messages[idx*micro_bs : (idx+1)*micro_bs], + # display=True) + # # trajectories = ray.get(trajectories_ref) + + # # deal with trajectories + # trajectories_ref = ppo_repeater.process.remote(trajectories_ref) + # # trajectories = ray.get(trajectories_ref) + + # # critic & policy learn + # critic_loss_ref = ppo.critic_learn.remote(trajectories_ref, update_param) + + # # ppo_loss, pt_loss = None, None + # if critic_warmup_step <= 0: + # # ppo_loss, pt_loss = ppo.policy_learn.remote(trajectories) + # ppo_loss_ref = ppo.policy_learn.remote(trajectories_ref, update_param) + # pt_loss = None + + # # logger_train.info( + # # f'[Policy Train] Step: {step}, ' + # # f'ppo loss: {ppo_loss}, pretrain loss: {pt_loss}') + # critic_loss_refs.append(critic_loss_ref) + # ppo_loss_refs.append(ppo_loss_ref) + # trajectories_refs.append(trajectories_ref) + + # ppo_losses = flatten_list(ray.get(ppo_loss_refs)) + # critic_losses = flatten_list(ray.get(critic_loss_refs)) + # trajectories = ray.get(trajectories_refs) + # # trajectories = concat_policy_outputs(trajectories) + # ray.get(ppo.sync_model.remote()) + + + critic_loss_refs, ppo_loss_refs, trajectories_refs = [], [], [] + micro_bs = 128 + num_batches = len(prompt_input_messages) // micro_bs + + # Create placeholder lists to manage intermediate results + trajectories_refs_stage1 = [None] * num_batches + trajectories_refs_stage2 = [None] * num_batches + critic_loss_refs_stage3 = [None] * num_batches + critic_time_refs_stage3 = [None] * num_batches + ppo_loss_refs_stage3 = [None] * num_batches + ppo_time_refs_stage3 = [None] * num_batches + + # Stage 1: Generate trajectories + for idx in range(num_batches): + trajectories_ref = txt_env.rollout.remote( + prompt_datas[idx * micro_bs: (idx + 1) * micro_bs], + prompt_input_messages[idx * micro_bs: (idx + 1) * micro_bs], + display=True + ) + trajectories_refs_stage1[idx] = trajectories_ref + + # breakpoint() + # Stage 2: Process trajectories + for idx in range(num_batches): + trajectories_ref = trajectories_refs_stage1[idx] + trajectories_ref = ppo_repeater.process.remote(prompt_datas[idx * micro_bs: (idx + 1) * micro_bs], + trajectories_ref) + trajectories_refs_stage2[idx] = trajectories_ref + + # Stage 3: Critic & Policy learn + for idx in range(num_batches): + update_param = idx == num_batches - 1 + trajectories_ref = trajectories_refs_stage2[idx] + critic_loss_ref, critic_time_ref = ppo.critic_learn.options(num_returns=2).remote(trajectories_ref, update_param) + critic_loss_refs_stage3[idx] = critic_loss_ref + critic_time_refs_stage3[idx] = critic_time_ref + + if critic_warmup_step <= 0: + ppo_loss_ref, ppo_time_ref = ppo.policy_learn.options(num_returns=2).remote(trajectories_ref, update_param) + ppo_loss_refs_stage3[idx] = ppo_loss_ref + ppo_time_refs_stage3[idx] = ppo_time_ref + + # # # Stage 4: Policy learn + # for idx in range(num_batches): + # update_param = idx == num_batches - 1 + # if critic_warmup_step <= 0: + # trajectories_ref = trajectories_refs_stage2[idx] + # ppo_loss_ref = 
ppo.policy_learn.remote(trajectories_ref, update_param) + # ppo_loss_refs_stage4[idx] = ppo_loss_ref + + # Collect results + ppo_losses = flatten_list(ray.get([ref for ref in ppo_loss_refs_stage3 if ref is not None])) + critic_losses = flatten_list(ray.get(critic_loss_refs_stage3)) + trajectories = ray.get(trajectories_refs_stage2) + ray.get(ppo.sync_model.remote()) + critic_times = ray.get(critic_time_refs_stage3) + policy_times = ray.get(ppo_time_refs_stage3) + + total_time = time.time() - s_t + + # logger_train.info( + # f'[Critic Train] step: {step}, critic loss: {critic_loss}') + # logger_train.info(f'rewards: {trajectories.rewards.mean()}') + # critic_warmup_step -= 1 + + # if config['rollout_config'].get('write_to_file', True): + # if not os.path.exists(f'{work_dir}/rollouts'): + # os.makedirs(f'{work_dir}/rollouts') + # with open(f'{work_dir}/rollouts/step{step}_rollout.log', + # 'a') as file: + # for output_s, r in zip(trajectories.output_str, + # trajectories.rewards): + # file.write(output_s + '\n' + 'Reward: ' + str(r.item()) + + # '\n' + '=' * 30 + '\n') + + # breakpoint() + # critic_loss_avg = sum(critic_losses) / len(critic_losses) + # ppo_loss_avg = sum(ppo_losses) / len(ppo_losses) + + # breakpoint() + + query_tokens = [] + resp_tokens = [] + stage1_times = [] + stage2_times = [] + for traj in trajectories: + query_tokens.append(traj.question_mask.sum(-1).float().mean().item()) + resp_tokens.append(traj.answer_mask.sum(-1).float().mean().item()) + stage1_times.append(traj['stage1_time']) + stage2_times.append(traj['stage2_time']) + query_tokens_mean = sum(query_tokens) / len(query_tokens) + resp_tokens_mean = sum(resp_tokens) / len(resp_tokens) + + summaries = dict( + # reward_mean=trajectories.rewards.mean().item(), + # reward_std=trajectories.rewards.std().item(), + # new_tokens_mean=trajectories.action_mask.sum( + # -1).float().mean().item(), + # new_tokens_std=trajectories.action_mask.sum( + # -1).float().std().item(), + # kl=trajectories.kl.mean().item(), + # entropy=trajectories.entropy.mean().item(), + step=step, + # policy_loss=ppo_loss_avg, + # pretrain_loss=pt_loss, + # critic_loss=critic_loss_avg, + + # query_tokens_mean=trajectories.question_mask.sum( + # -1).float().mean().item(), + # resp_tokens_mean=trajectories.answer_mask.sum( + # -1).float().mean().item(), + # generate_time=gen_time, + # forward_time=fwd_time, + # training_time=train_time, + stage1_time=stage1_times, + stage2_time=stage2_times, + critic_time=critic_times, + policy_time=policy_times, + total_time=total_time, + query_tokens_mean=query_tokens_mean, + resp_tokens_mean=resp_tokens_mean, + ) + with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: + f.write(json.dumps(summaries) + '\n') + logger_train.info(f'[end to end] duration: {time.time() - s_t} s') + + step += 1 + if (step % save_interval == 0) or (step == max_train_step): + policy_model.save(f'{work_dir}/ckpt/policy_model/{step}') + critic_model.save(f'{work_dir}/ckpt/critic_model/{step}') diff --git a/xtuner/rlhf/test_actor_background.py b/xtuner/rlhf/test_actor_background.py new file mode 100644 index 000000000..1a618bc9d --- /dev/null +++ b/xtuner/rlhf/test_actor_background.py @@ -0,0 +1,329 @@ +import argparse +import json +import os +import shutil +import time + +from loguru import logger + +from xtuner.rlhf.config.config import Config +from xtuner.rlhf.coordinator import Coordinator +from xtuner.rlhf.dataset import MessageIter +from xtuner.rlhf.envs import TxtEnv +from xtuner.rlhf.repeaters import KLGAERepeater +from 
xtuner.rlhf.timer import Timer +from xtuner.rlhf.trainer import PPOTrainer + +import ray +from copy import deepcopy +from xtuner.rlhf.envs.utils import SYSTEM_PROMPT +from policy_output import (PolicyOutput, concat_policy_outputs, + logprobs_from_logits) + +def parse_args(): + parser = argparse.ArgumentParser(description='Train LLM') + parser.add_argument( + '-c', + '--config', + help='config file name or path.', + type=str, + default='examples/rlhf/four_model_vllm_8gpu.py') + parser.add_argument( + '-w', + '--work_dir', + help='the dir to save logs and models', + type=str, + default=None) + parser.add_argument( + '-a', '--address', help='ray head address', type=str, default='auto') + args = parser.parse_args() + return args + + +def validate_config(config: Config): + assert config['model_configs'] is not None + assert config['model_configs']['policy'] is not None + assert config['model_configs']['policy']['model_path'] is not None + assert config['dataset_config'] is not None + assert config['rollout_config'] is not None + assert config['rollout_config']['generate_kwargs'] is not None + assert config['rollout_config']['max_new_tokens'] is not None + + +def flatten_list(nested_list): + flattened = [] + for item in nested_list: + if isinstance(item, list): + flattened.extend(flatten_list(item)) + else: + flattened.append(item) + return flattened + + +if __name__ == '__main__': + args = parse_args() + assert args.config is not None, 'config should not be None' + work_dir = args.work_dir + if work_dir is None: + work_dir = os.getcwd() + '/rlhf_trainlog_' + time.strftime( + '%Y-%m-%d-%H:%M:%S') + work_dir = os.path.abspath(work_dir) + logger.info(f'using work_dir: {work_dir}') + os.makedirs(work_dir, exist_ok=True) + # save original config + shutil.copy2(args.config, f'{work_dir}/{os.path.basename(args.config)}') + import pdb;pdb.set_trace() + logger.add( + f'{work_dir}/train_rlhf.log', + filter=lambda record: record['extra'].get('name') == 'train') + logger_train = logger.bind(name='train') + + config = Config.from_file(args.config) + logger.info('#################### CONFIG BGN ####################') + for k, v in config.items(): + logger.info(f'{k}: {v}') + logger.info('#################### CONFIG END ####################') + + # init model + cluster_address = args.address + if cluster_address != 'auto': + cluster_address = f'ray://{cluster_address}:10001' + logger.info(f'cluster_address={cluster_address}') + coordinator = Coordinator(cluster_address, config) + model_dict = coordinator.create_models() + ref_model = model_dict['reference'] + policy_model = model_dict['policy'] + reward_model = model_dict['reward'] + critic_model = model_dict['critic'] + + # init prompt & pretrain dataset + prompt_dataset_config = config['prompt_dataset_config'] + prompt_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, **prompt_dataset_config) + pretrain_dataset_config = config.get('pretrain_dataset_config', {}) + pretrain_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, **pretrain_dataset_config) + + prompt_mes_iter = iter(prompt_mes_iter) + pretrain_mes_iter = iter( + pretrain_mes_iter) if pretrain_mes_iter.message_datasets else None + + # init txt env + rollout_config = config.get('rollout_config', {}) + txt_env = ray.remote(TxtEnv).remote( + policy_model=policy_model, + reward_model=reward_model, + # prompt_mes_iter=prompt_mes_iter, + # pretrain_mes_iter=pretrain_mes_iter, # None + **rollout_config, + ) + # init repeater + repeater_config = config.get('repeater_config', {}) + ppo_repeater 
= ray.remote(KLGAERepeater).remote( + ref_model=ref_model, + policy_model=policy_model, + critic_model=critic_model, + reward_model=reward_model, + env=txt_env, + **repeater_config, + ) + # init trainer + train_config = config.get('train_config', {}) + ppo = ray.remote(PPOTrainer).options(max_concurrency=2).remote( + policy_model=policy_model, critic_model=critic_model, **train_config) + critic_warmup_step = train_config['critic_warmup_step'] + save_interval = train_config['save_interval'] + max_train_step = train_config.get('max_train_step', float('inf')) + resume_step = train_config.get('resume_step', -1) + critic_warmup_step = min(critic_warmup_step, + critic_warmup_step - resume_step) + async_learn = train_config.get('async_learn', False) + + step = max(0, resume_step) + while step <= max_train_step: + s_t = time.time() + with Timer(f'step {step}: end_to_end'): + + prompt_datas = deepcopy(next(prompt_mes_iter)) + prompt_input_messages = [] + for data in prompt_datas: + assert data.mes_type == 'prompt' + if data.sys_prompt != 'default': + message = deepcopy([ + dict( + role='system', content=SYSTEM_PROMPT[data.sys_prompt]) + ] + data.message) + else: + message = deepcopy(data.message) + prompt_input_messages.append(message) + + # critic_loss_refs, ppo_loss_refs, trajectories_refs = [], [], [] + # micro_bs = 32 + # # for start in range(0, len(prompt_input_messages), micro_bs): + # for idx in range(0, len(prompt_input_messages)//micro_bs): + # # breakpoint() + # # update param in last micro batch + # if idx == len(prompt_input_messages) // micro_bs - 1: + # update_param = True + # else: + # update_param = False + + # # generate trajectories + # trajectories_ref = txt_env.rollout.remote(prompt_datas[idx*micro_bs : (idx+1)*micro_bs], + # prompt_input_messages[idx*micro_bs : (idx+1)*micro_bs], + # display=True) + # # trajectories = ray.get(trajectories_ref) + + # # deal with trajectories + # trajectories_ref = ppo_repeater.process.remote(trajectories_ref) + # # trajectories = ray.get(trajectories_ref) + + # # critic & policy learn + # critic_loss_ref = ppo.critic_learn.remote(trajectories_ref, update_param) + + # # ppo_loss, pt_loss = None, None + # if critic_warmup_step <= 0: + # # ppo_loss, pt_loss = ppo.policy_learn.remote(trajectories) + # ppo_loss_ref = ppo.policy_learn.remote(trajectories_ref, update_param) + # pt_loss = None + + # # logger_train.info( + # # f'[Policy Train] Step: {step}, ' + # # f'ppo loss: {ppo_loss}, pretrain loss: {pt_loss}') + # critic_loss_refs.append(critic_loss_ref) + # ppo_loss_refs.append(ppo_loss_ref) + # trajectories_refs.append(trajectories_ref) + + # ppo_losses = flatten_list(ray.get(ppo_loss_refs)) + # critic_losses = flatten_list(ray.get(critic_loss_refs)) + # trajectories = ray.get(trajectories_refs) + # # trajectories = concat_policy_outputs(trajectories) + # ray.get(ppo.sync_model.remote()) + + + critic_loss_refs, ppo_loss_refs, trajectories_refs = [], [], [] + micro_bs = 128 + num_batches = len(prompt_input_messages) // micro_bs + + # Create placeholder lists to manage intermediate results + trajectories_refs_stage1 = [None] * num_batches + trajectories_refs_stage2 = [None] * num_batches + critic_loss_refs_stage3 = [None] * num_batches + policy_loss_refs_stage3 = [None] * num_batches + + critic_time_refs_stage3 = [None] * num_batches + policy_time_refs_stage3 = [None] * num_batches + + # breakpoint() + ref = txt_env.rollout_background.remote(prompt_input_messages) + ray.get(ref) + + # Stage 1: Generate trajectories + for idx in range(num_batches): + 
trajectories_ref = txt_env.get_generate_finish.remote(micro_bs) + trajectories_refs_stage1[idx] = trajectories_ref + + # breakpoint() + # Stage 2: Process trajectories + for idx in range(num_batches): + trajectories_ref = trajectories_refs_stage1[idx] + trajectories_ref = ppo_repeater.process.remote(prompt_datas[idx * micro_bs: (idx + 1) * micro_bs], + trajectories_ref) + trajectories_refs_stage2[idx] = trajectories_ref + + # Stage 3: Critic & Policy learn + for idx in range(num_batches): + update_param = idx == num_batches - 1 + trajectories_ref = trajectories_refs_stage2[idx] + critic_loss_ref, critic_time_ref = ppo.critic_learn.options(num_returns=2).remote(trajectories_ref, update_param) + critic_loss_refs_stage3[idx] = critic_loss_ref + critic_time_refs_stage3[idx] = critic_time_ref + + if critic_warmup_step <= 0: + policy_loss_ref, policy_time_ref = ppo.policy_learn.options(num_returns=2).remote(trajectories_ref, update_param) + policy_loss_refs_stage3[idx] = policy_loss_ref + policy_time_refs_stage3[idx] = policy_time_ref + + # Collect results + policy_losses = flatten_list(ray.get([ref for ref in policy_loss_refs_stage3 if ref is not None])) + critic_losses = flatten_list(ray.get(critic_loss_refs_stage3)) + trajectories = ray.get(trajectories_refs_stage2) + ray.get(ppo.sync_model.remote()) + critic_times = ray.get(critic_time_refs_stage3) + policy_times = ray.get(policy_time_refs_stage3) + + total_time = time.time() - s_t + + # logger_train.info( + # f'[Critic Train] step: {step}, critic loss: {critic_loss}') + # logger_train.info(f'rewards: {trajectories.rewards.mean()}') + # critic_warmup_step -= 1 + + # if config['rollout_config'].get('write_to_file', True): + # if not os.path.exists(f'{work_dir}/rollouts'): + # os.makedirs(f'{work_dir}/rollouts') + # with open(f'{work_dir}/rollouts/step{step}_rollout.log', + # 'a') as file: + # for output_s, r in zip(trajectories.output_str, + # trajectories.rewards): + # file.write(output_s + '\n' + 'Reward: ' + str(r.item()) + + # '\n' + '=' * 30 + '\n') + + query_tokens = [] + resp_tokens = [] + rewards = [] + stage1_times = [] + stage2_times = [] + for traj in trajectories: + query_tokens.append(traj.question_mask.sum(-1).float().mean().item()) + resp_tokens.append(traj.answer_mask.sum(-1).float().mean().item()) + rewards.append(traj.rewards.mean().item()) + stage1_times.append(traj['stage1_time']) + stage2_times.append(traj['stage2_time']) + + query_tokens_mean = sum(query_tokens) / len(query_tokens) + resp_tokens_mean = sum(resp_tokens) / len(resp_tokens) + reward_mean = sum(rewards) / len(rewards) + + policy_loss_mean = sum(policy_losses) / len(policy_losses) + critic_loss_mean = sum(critic_losses) / len(critic_losses) + + summaries = dict( + # reward_mean=trajectories.rewards.mean().item(), + # reward_std=trajectories.rewards.std().item(), + # new_tokens_mean=trajectories.action_mask.sum( + # -1).float().mean().item(), + # new_tokens_std=trajectories.action_mask.sum( + # -1).float().std().item(), + # kl=trajectories.kl.mean().item(), + # entropy=trajectories.entropy.mean().item(), + step=step, + policy_loss=policy_loss_mean, + # pretrain_loss=pt_loss, + critic_loss=critic_loss_mean, + + # query_tokens_mean=trajectories.question_mask.sum( + # -1).float().mean().item(), + # resp_tokens_mean=trajectories.answer_mask.sum( + # -1).float().mean().item(), + # generate_time=gen_time, + # forward_time=fwd_time, + # training_time=train_time, + stage1_time=stage1_times, + stage2_time=stage2_times, + critic_time=critic_times, + 
policy_time=policy_times, + total_time=total_time, + reward_mean=reward_mean, + query_tokens=query_tokens_mean, + resp_tokens=resp_tokens_mean, + ) + with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: + f.write(json.dumps(summaries) + '\n') + f.flush() + logger_train.info(f'[end to end] duration: {time.time() - s_t} s') + + step += 1 + if (step % save_interval == 0) or (step == max_train_step): + policy_model.save(f'{work_dir}/ckpt/policy_model/{step}') + critic_model.save(f'{work_dir}/ckpt/critic_model/{step}') diff --git a/xtuner/rlhf/test_policy_ref_pipe.py b/xtuner/rlhf/test_policy_ref_pipe.py new file mode 100644 index 000000000..9ee0e7420 --- /dev/null +++ b/xtuner/rlhf/test_policy_ref_pipe.py @@ -0,0 +1,346 @@ +import argparse +import json +import os +import shutil +import time + +from loguru import logger + +from xtuner.rlhf.config.config import Config +from xtuner.rlhf.coordinator import Coordinator +from xtuner.rlhf.dataset import MessageIter +from xtuner.rlhf.envs import TxtEnv +from xtuner.rlhf.repeaters import KLGAERepeater +from xtuner.rlhf.timer import Timer +from xtuner.rlhf.trainer import PPOTrainer + +import ray + +def parse_args(): + parser = argparse.ArgumentParser(description='Train LLM') + parser.add_argument( + '-c', + '--config', + help='config file name or path.', + type=str, + default='examples/rlhf/four_model_vllm_8gpu.py') + parser.add_argument( + '-w', + '--work_dir', + help='the dir to save logs and models', + type=str, + default=None) + parser.add_argument( + '-a', '--address', help='ray head address', type=str, default='auto') + args = parser.parse_args() + return args + + +def validate_config(config: Config): + assert config['model_configs'] is not None + assert config['model_configs']['policy'] is not None + assert config['model_configs']['policy']['model_path'] is not None + assert config['dataset_config'] is not None + assert config['rollout_config'] is not None + assert config['rollout_config']['generate_kwargs'] is not None + assert config['rollout_config']['max_new_tokens'] is not None + + +from policy_output import (PolicyOutput, concat_policy_outputs, + logprobs_from_logits) + +@ray.remote +def generate_async(policy_model, prompt_input_messages, policy_micro_bs, max_new_tokens, generate_kwargs): + breakpoint() + with Timer('policy_model.generate'): + trajectories = policy_model.generate( + inputs=prompt_input_messages, + micro_batch_size=policy_micro_bs, + step=max_new_tokens, + output_str=True, + generate_kwargs=generate_kwargs) + logger.info(f'[Generate] len: {len(prompt_input_messages)}') + return trajectories + +@ray.remote +def ref_model_async(trajectories_refs, ref_model, ref_micro_bs): + # trajectories = ray.get(trajectories_refs, timeout=None) + outputs = ray.get(trajectories_refs, timeout=None) + # breakpoint() + padding_token_map = { + 'output_ids': 0, + } + trajectories = concat_policy_outputs(outputs, padding_token_map) + + with Timer('ref_model.infer'): + ref_output = ref_model.infer( + inputs=trajectories.output_ids, + micro_batch_size=ref_micro_bs, + attention_mask=trajectories.attention_mask, + output_logits=False, + output_logprobs=True) + + return ref_output + +if __name__ == '__main__': + args = parse_args() + assert args.config is not None, 'config should not be None' + work_dir = args.work_dir + if work_dir is None: + work_dir = os.getcwd() + '/rlhf_trainlog_' + time.strftime( + '%Y-%m-%d-%H:%M:%S') + work_dir = os.path.abspath(work_dir) + logger.info(f'using work_dir: {work_dir}') + os.makedirs(work_dir, exist_ok=True) + 
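A note on the ref_model_async helper defined above: Ray resolves ObjectRefs passed as top-level task arguments before the task runs, but refs nested inside a container (such as the list returned by generate_async) arrive unresolved, so the task must call ray.get itself on the worker. That is what lets generation and reference inference overlap across micro-batches. A minimal hedged sketch of the pattern, with illustrative function names rather than the project's API:

import ray

@ray.remote
def generate_chunk(prompts):
    # Stand-in for one micro-batch of asynchronous generation.
    return [p + ' -> completion' for p in prompts]

@ray.remote
def score_chunk(trajectory_refs):
    # The refs arrive unresolved because they are nested in a list,
    # so resolution happens here on the worker, not on the driver.
    trajectories = ray.get(trajectory_refs)
    return sum(len(t) for t in trajectories)

ray.init(ignore_reinit_error=True)
gen_refs = [generate_chunk.remote(['a', 'b']), generate_chunk.remote(['c', 'd'])]
score_ref = score_chunk.remote(gen_refs)   # scheduled without blocking the driver
print(ray.get(score_ref))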
# save original config + shutil.copy2(args.config, f'{work_dir}/{os.path.basename(args.config)}') + + logger.add( + f'{work_dir}/train_rlhf.log', + filter=lambda record: record['extra'].get('name') == 'train') + logger_train = logger.bind(name='train') + + config = Config.from_file(args.config) + logger.info('#################### CONFIG BGN ####################') + for k, v in config.items(): + logger.info(f'{k}: {v}') + logger.info('#################### CONFIG END ####################') + + # init model + cluster_address = args.address + if cluster_address != 'auto': + cluster_address = f'ray://{cluster_address}:10001' + logger.info(f'cluster_address={cluster_address}') + coordinator = Coordinator(cluster_address, config) + model_dict = coordinator.create_models() + ref_model = model_dict['reference'] + policy_model = model_dict['policy'] + # reward_model = model_dict['reward'] + # critic_model = model_dict['critic'] + + + ref_micro_bs = config['repeater_config']['ref_micro_bs'] + + ## test generate + max_new_tokens = config['rollout_config']['max_new_tokens'] + policy_micro_bs = config['rollout_config']['policy_micro_bs'] + generate_kwargs = config['rollout_config']['generate_kwargs'] + + from copy import deepcopy + from xtuner.rlhf.envs.utils import SYSTEM_PROMPT + + # init prompt & pretrain dataset + prompt_dataset_config = config['prompt_dataset_config'] + prompt_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, **prompt_dataset_config) + + prompt_mes_iter = iter(prompt_mes_iter) + prompt_datas = deepcopy(next(prompt_mes_iter)) + prompt_input_messages = [] + for data in prompt_datas: + assert data.mes_type == 'prompt' + if data.sys_prompt != 'default': + message = deepcopy([ + dict( + role='system', content=SYSTEM_PROMPT[data.sys_prompt]) + ] + data.message) + else: + message = deepcopy(data.message) + prompt_input_messages.append(message) + + + # while True: + # with Timer('end to end'): + # with Timer('policy_model.generate'): + # trajectories = policy_model.generate( + # inputs=prompt_input_messages, + # micro_batch_size=policy_micro_bs, + # step=max_new_tokens, + # output_str=True, + # generate_kwargs=generate_kwargs) + # logger.info(f'[Generate] len: {len(prompt_input_messages)}') + + # with Timer('ref_model.infer'): + # ref_output = ref_model.infer( + # inputs=trajectories.output_ids, + # micro_batch_size=ref_micro_bs, + # attention_mask=trajectories.attention_mask, + # output_logits=False, + # output_logprobs=True) + + + + while True: + with Timer('end to end'): + with Timer('policy_model.generate'): + object_refs_list = [] + for start in range(0, len(prompt_input_messages), 64): + input_messages = prompt_input_messages[start : start+64] + # ref = generate_async.remote(policy_model, + # prompt_input_messages, + # policy_micro_bs, + # max_new_tokens, + # generate_kwargs) + + # trajectories_refs.append( + # generate_async.remote(policy_model, + # prompt_input_messages, + # policy_micro_bs, + # max_new_tokens, + # generate_kwargs) + # ) + + object_refs_list.append(policy_model.generate_async( + inputs=input_messages, + micro_batch_size=policy_micro_bs, + step=max_new_tokens, + output_str=True, + generate_kwargs=generate_kwargs)) + + second_object_refs = [] + with Timer('ref_model.infer'): + for object_refs in object_refs_list: + # second_object_refs.extend(ref_model.infer_from_future( + # object_refs, + # micro_batch_size=ref_micro_bs, + # output_logits=False, + # output_logprobs=True)) + second_object_refs.append(ref_model_async.remote(object_refs, ref_model, 
ref_micro_bs)) + # breakpoint() + output = ray.get(second_object_refs) + + # import torch + # global_bs = 128 + # micro_bs = config['repeater_config']['ref_micro_bs'] + # output_ids = torch.randint(low=0, high=100, size=(global_bs, 2048), dtype=torch.int64, device='cpu') + # attention_mask = torch.zeros((global_bs, 2048), dtype=torch.int32, device='cpu') + # attention_mask[:, -512:] = 1 + # while True: + # with Timer('ref_model.infer'): + # ref_output = ref_model.infer( + # inputs=output_ids, + # micro_batch_size=micro_bs, + # attention_mask=attention_mask, + # output_logits=False, + # output_logprobs=True) + + + # init prompt & pretrain dataset + prompt_dataset_config = config['prompt_dataset_config'] + prompt_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, **prompt_dataset_config) + pretrain_dataset_config = config.get('pretrain_dataset_config', {}) + pretrain_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, **pretrain_dataset_config) + + # init txt env + rollout_config = config.get('rollout_config', {}) + txt_env = TxtEnv( + policy_model=policy_model, + reward_model=reward_model, + prompt_mes_iter=prompt_mes_iter, + pretrain_mes_iter=pretrain_mes_iter, # None + **rollout_config, + ) + # init repeater + repeater_config = config.get('repeater_config', {}) + ppo_repeater = KLGAERepeater( + ref_model=ref_model, + policy_model=policy_model, + critic_model=critic_model, + env=txt_env, + **repeater_config, + ) + # init trainer + train_config = config.get('train_config', {}) + ppo = PPOTrainer( + policy_model=policy_model, critic_model=critic_model, **train_config) + critic_warmup_step = train_config['critic_warmup_step'] + save_interval = train_config['save_interval'] + max_train_step = train_config.get('max_train_step', float('inf')) + resume_step = train_config.get('resume_step', -1) + critic_warmup_step = min(critic_warmup_step, + critic_warmup_step - resume_step) + async_learn = train_config.get('async_learn', False) + + step = max(0, resume_step) + while step <= max_train_step: + s_t = time.time() + with Timer(f'step {step}: end_to_end'): + # generate trajectories + gen_start = time.time() + trajectories = txt_env.rollout(display=True) + gen_time = time.time() - gen_start + + # deal with trajectories + fwd_start = time.time() + trajectories = ppo_repeater.process(trajectories) + fwd_time = time.time() - fwd_start + + train_start = time.time() + # critic & policy learn + if async_learn: + critic_loss_ref = ppo.critic_learn_async(trajectories) + else: + critic_train_start = time.time() + critic_loss = ppo.critic_learn(trajectories) + critic_train_time = time.time() - critic_train_start + + ppo_loss, pt_loss = None, None + if critic_warmup_step <= 0: + ppo_loss, pt_loss = ppo.policy_learn(trajectories) + + logger_train.info( + f'[Policy Train] Step: {step}, ' + f'ppo loss: {ppo_loss}, pretrain loss: {pt_loss}') + + if async_learn: + critic_loss = ppo.critic_learn_get(critic_loss_ref) + train_time = time.time() - train_start + total_time = time.time() - s_t + + logger_train.info( + f'[Critic Train] step: {step}, critic loss: {critic_loss}') + logger_train.info(f'rewards: {trajectories.rewards.mean()}') + critic_warmup_step -= 1 + + if config['rollout_config'].get('write_to_file', True): + if not os.path.exists(f'{work_dir}/rollouts'): + os.makedirs(f'{work_dir}/rollouts') + with open(f'{work_dir}/rollouts/step{step}_rollout.log', + 'a') as file: + for output_s, r in zip(trajectories.output_str, + trajectories.rewards): + file.write(output_s + '\n' + 'Reward: ' + str(r.item()) 
+ + '\n' + '=' * 30 + '\n') + summaries = dict( + reward_mean=trajectories.rewards.mean().item(), + reward_std=trajectories.rewards.std().item(), + new_tokens_mean=trajectories.action_mask.sum( + -1).float().mean().item(), + new_tokens_std=trajectories.action_mask.sum( + -1).float().std().item(), + kl=trajectories.kl.mean().item(), + entropy=trajectories.entropy.mean().item(), + step=step, + policy_loss=ppo_loss, + pretrain_loss=pt_loss, + critic_loss=critic_loss, + + query_tokens_mean=trajectories.question_mask.sum( + -1).float().mean().item(), + resp_tokens_mean=trajectories.answer_mask.sum( + -1).float().mean().item(), + generate_time=gen_time, + forward_time=fwd_time, + training_time=train_time, + total_time=total_time, + ) + with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: + f.write(json.dumps(summaries) + '\n') + logger_train.info(f'[end to end] duration: {time.time() - s_t} s') + + step += 1 + if (step % save_interval == 0) or (step == max_train_step): + policy_model.save(f'{work_dir}/ckpt/policy_model/{step}') + critic_model.save(f'{work_dir}/ckpt/critic_model/{step}') diff --git a/xtuner/rlhf/test_ref.py b/xtuner/rlhf/test_ref.py new file mode 100644 index 000000000..f1e9d95ef --- /dev/null +++ b/xtuner/rlhf/test_ref.py @@ -0,0 +1,218 @@ +import argparse +import json +import os +import shutil +import time + +from loguru import logger + +from xtuner.rlhf.config.config import Config +from xtuner.rlhf.coordinator import Coordinator +from xtuner.rlhf.dataset import MessageIter +from xtuner.rlhf.envs import TxtEnv +from xtuner.rlhf.repeaters import KLGAERepeater +from xtuner.rlhf.timer import Timer +from xtuner.rlhf.trainer import PPOTrainer + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train LLM') + parser.add_argument( + '-c', + '--config', + help='config file name or path.', + type=str, + default='examples/rlhf/four_model_vllm_8gpu.py') + parser.add_argument( + '-w', + '--work_dir', + help='the dir to save logs and models', + type=str, + default=None) + parser.add_argument( + '-a', '--address', help='ray head address', type=str, default='auto') + args = parser.parse_args() + return args + + +def validate_config(config: Config): + assert config['model_configs'] is not None + assert config['model_configs']['policy'] is not None + assert config['model_configs']['policy']['model_path'] is not None + assert config['dataset_config'] is not None + assert config['rollout_config'] is not None + assert config['rollout_config']['generate_kwargs'] is not None + assert config['rollout_config']['max_new_tokens'] is not None + + +if __name__ == '__main__': + args = parse_args() + assert args.config is not None, 'config should not be None' + work_dir = args.work_dir + if work_dir is None: + work_dir = os.getcwd() + '/rlhf_trainlog_' + time.strftime( + '%Y-%m-%d-%H:%M:%S') + work_dir = os.path.abspath(work_dir) + logger.info(f'using work_dir: {work_dir}') + os.makedirs(work_dir, exist_ok=True) + # save original config + shutil.copy2(args.config, f'{work_dir}/{os.path.basename(args.config)}') + + logger.add( + f'{work_dir}/train_rlhf.log', + filter=lambda record: record['extra'].get('name') == 'train') + logger_train = logger.bind(name='train') + + config = Config.from_file(args.config) + logger.info('#################### CONFIG BGN ####################') + for k, v in config.items(): + logger.info(f'{k}: {v}') + logger.info('#################### CONFIG END ####################') + + # init model + cluster_address = args.address + if cluster_address != 'auto': + 
cluster_address = f'ray://{cluster_address}:10001' + logger.info(f'cluster_address={cluster_address}') + coordinator = Coordinator(cluster_address, config) + model_dict = coordinator.create_models() + ref_model = model_dict['reference'] + policy_model = model_dict['policy'] + reward_model = model_dict['reward'] + critic_model = model_dict['critic'] + + import torch + global_bs = 128 + micro_bs = config['repeater_config']['ref_micro_bs'] + output_ids = torch.randint(low=0, high=100, size=(global_bs, 2048), dtype=torch.int64, device='cpu') + attention_mask = torch.zeros((global_bs, 2048), dtype=torch.int32, device='cpu') + attention_mask[:, -512:] = 1 + while True: + with Timer('ref_model.infer'): + ref_output = ref_model.infer( + inputs=output_ids, + micro_batch_size=micro_bs, + attention_mask=attention_mask, + output_logits=False, + output_logprobs=True) + + + # init prompt & pretrain dataset + prompt_dataset_config = config['prompt_dataset_config'] + prompt_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, **prompt_dataset_config) + pretrain_dataset_config = config.get('pretrain_dataset_config', {}) + pretrain_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, **pretrain_dataset_config) + + # init txt env + rollout_config = config.get('rollout_config', {}) + txt_env = TxtEnv( + policy_model=policy_model, + reward_model=reward_model, + prompt_mes_iter=prompt_mes_iter, + pretrain_mes_iter=pretrain_mes_iter, # None + **rollout_config, + ) + # init repeater + repeater_config = config.get('repeater_config', {}) + ppo_repeater = KLGAERepeater( + ref_model=ref_model, + policy_model=policy_model, + critic_model=critic_model, + env=txt_env, + **repeater_config, + ) + # init trainer + train_config = config.get('train_config', {}) + ppo = PPOTrainer( + policy_model=policy_model, critic_model=critic_model, **train_config) + critic_warmup_step = train_config['critic_warmup_step'] + save_interval = train_config['save_interval'] + max_train_step = train_config.get('max_train_step', float('inf')) + resume_step = train_config.get('resume_step', -1) + critic_warmup_step = min(critic_warmup_step, + critic_warmup_step - resume_step) + async_learn = train_config.get('async_learn', False) + + step = max(0, resume_step) + while step <= max_train_step: + s_t = time.time() + with Timer(f'step {step}: end_to_end'): + # generate trajectories + gen_start = time.time() + trajectories = txt_env.rollout(display=True) + gen_time = time.time() - gen_start + + # deal with trajectories + fwd_start = time.time() + trajectories = ppo_repeater.process(trajectories) + fwd_time = time.time() - fwd_start + + train_start = time.time() + # critic & policy learn + if async_learn: + critic_loss_ref = ppo.critic_learn_async(trajectories) + else: + critic_train_start = time.time() + critic_loss = ppo.critic_learn(trajectories) + critic_train_time = time.time() - critic_train_start + + ppo_loss, pt_loss = None, None + if critic_warmup_step <= 0: + ppo_loss, pt_loss = ppo.policy_learn(trajectories) + + logger_train.info( + f'[Policy Train] Step: {step}, ' + f'ppo loss: {ppo_loss}, pretrain loss: {pt_loss}') + + if async_learn: + critic_loss = ppo.critic_learn_get(critic_loss_ref) + train_time = time.time() - train_start + total_time = time.time() - s_t + + logger_train.info( + f'[Critic Train] step: {step}, critic loss: {critic_loss}') + logger_train.info(f'rewards: {trajectories.rewards.mean()}') + critic_warmup_step -= 1 + + if config['rollout_config'].get('write_to_file', True): + if not 
os.path.exists(f'{work_dir}/rollouts'): + os.makedirs(f'{work_dir}/rollouts') + with open(f'{work_dir}/rollouts/step{step}_rollout.log', + 'a') as file: + for output_s, r in zip(trajectories.output_str, + trajectories.rewards): + file.write(output_s + '\n' + 'Reward: ' + str(r.item()) + + '\n' + '=' * 30 + '\n') + summaries = dict( + reward_mean=trajectories.rewards.mean().item(), + reward_std=trajectories.rewards.std().item(), + new_tokens_mean=trajectories.action_mask.sum( + -1).float().mean().item(), + new_tokens_std=trajectories.action_mask.sum( + -1).float().std().item(), + kl=trajectories.kl.mean().item(), + entropy=trajectories.entropy.mean().item(), + step=step, + policy_loss=ppo_loss, + pretrain_loss=pt_loss, + critic_loss=critic_loss, + + query_tokens_mean=trajectories.question_mask.sum( + -1).float().mean().item(), + resp_tokens_mean=trajectories.answer_mask.sum( + -1).float().mean().item(), + generate_time=gen_time, + forward_time=fwd_time, + training_time=train_time, + total_time=total_time, + ) + with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: + f.write(json.dumps(summaries) + '\n') + logger_train.info(f'[end to end] duration: {time.time() - s_t} s') + + step += 1 + if (step % save_interval == 0) or (step == max_train_step): + policy_model.save(f'{work_dir}/ckpt/policy_model/{step}') + critic_model.save(f'{work_dir}/ckpt/critic_model/{step}') diff --git a/xtuner/rlhf/trainer/ppo.py b/xtuner/rlhf/trainer/ppo.py index 655c91f94..545230ed9 100644 --- a/xtuner/rlhf/trainer/ppo.py +++ b/xtuner/rlhf/trainer/ppo.py @@ -4,6 +4,7 @@ from ..model_server.base_model_server import BaseModelServer from ..timer import Timer +import time class PPOTrainer: @@ -44,7 +45,9 @@ def __init__( self.critic_criterion = critic_criterion - def policy_learn(self, trajectories): + def policy_learn(self, trajectories, update_param=True): + start = time.time() + if self.policy_minibatch is None: self.policy_minibatch = len(trajectories.output_ids) assert len(trajectories.output_ids) % self.policy_minibatch == 0 @@ -101,7 +104,8 @@ def policy_learn(self, trajectories): # position_ids=train_position_ids, criterion=train_criterion, loss_weights=loss_weights, - micro_batch_size=micro_batch_size) + micro_batch_size=micro_batch_size, + update_param=update_param) if isinstance(p_loss, list): ppo_loss.append(p_loss[0].item()) pretrain_loss.append(p_loss[1].item()) @@ -114,11 +118,19 @@ def policy_learn(self, trajectories): f'[Policy Train] prompt data: {train_input_ids[0].shape}, ppo loss: {p_loss.item()}' # noqa: E501 ) + # with Timer('policy_model.sync_model'): + # self.policy_model.sync_model() + + # return ppo_loss, pretrain_loss + end = time.time() + return ppo_loss, end-start + + def sync_model(self): with Timer('policy_model.sync_model'): self.policy_model.sync_model() - return ppo_loss, pretrain_loss - def critic_learn(self, trajectories): + def critic_learn(self, trajectories, update_param=True): + start = time.time() if self.critic_minibatch is None: self.critic_minibatch = len(trajectories.output_ids) assert len(trajectories.output_ids) % self.critic_minibatch == 0 @@ -139,11 +151,13 @@ def critic_learn(self, trajectories): attention_mask=critic_batch_inputs['attention_mask'], criterion=self.critic_criterion, micro_batch_size=self.critic_micro_bs, + update_param=update_param, ) logger.info(f'[Critic train] {self.critic_minibatch} batch, ' f'critic loss: {v_loss.item()}') critic_loss.append(v_loss.item()) - return critic_loss + end = time.time() + return critic_loss, end-start def 
_critic_learn_prepare(self, step_i, learn_i, trajectories, critic_updates): @@ -165,7 +179,7 @@ def _critic_learn_prepare(self, step_i, learn_i, trajectories, ) return critic_batch_inputs, labels - def critic_learn_async(self, trajectories): + def critic_learn_async(self, trajectories, update_param=True): if self.critic_minibatch is None: self.critic_minibatch = len(trajectories.output_ids) assert len(trajectories.output_ids) % self.critic_minibatch == 0 @@ -182,6 +196,7 @@ def critic_learn_async(self, trajectories): attention_mask=critic_batch_inputs['attention_mask'], criterion=self.critic_criterion, micro_batch_size=self.critic_micro_bs, + update_param=update_param, ) logger.info(f'[critic train] {self.critic_minibatch} batch') critic_loss.append(v_loss_ref) From 931526dc05dc7254a330348b1e6385bf98e291c4 Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Wed, 24 Jul 2024 09:57:50 +0800 Subject: [PATCH 20/37] update ownership of train_1node --- scripts/train_1node.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train_1node.sh b/scripts/train_1node.sh index 28e7ea3f6..637b435cb 100644 --- a/scripts/train_1node.sh +++ b/scripts/train_1node.sh @@ -27,4 +27,4 @@ mkdir -p $work_dirs # python xtuner/rlhf/main.py -c $config_file -w $work_dirs 2>&1 | tee $work_dirs/main-$start_time.log # python xtuner/rlhf/test_actor.py -c $config_file -w $work_dirs 2>&1 | tee $work_dirs/main-$start_time.log python xtuner/rlhf/test_actor_background.py -c $config_file -w $work_dirs -# 2>&1 | tee $work_dirs/main-$start_time.log +2>&1 | tee $work_dirs/main-$start_time.log From be56c66d1cf483ef6fba3e0f89d480e761e1856a Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Fri, 26 Jul 2024 09:57:13 +0800 Subject: [PATCH 21/37] add thread & queue join in generate background --- .../rlhf/model_backend/vllm_model_runner.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/xtuner/rlhf/model_backend/vllm_model_runner.py b/xtuner/rlhf/model_backend/vllm_model_runner.py index 31d6ffd70..4c589fd4a 100644 --- a/xtuner/rlhf/model_backend/vllm_model_runner.py +++ b/xtuner/rlhf/model_backend/vllm_model_runner.py @@ -28,6 +28,8 @@ class VllmGenerator: def __init__(self, model_config) -> None: self.model_config: dict = model_config + self.generate_thread = None + self.queue = queue.Queue() # Adapted from https://github.com/OpenLLMAI/OpenRLHF/blob/v0.2.5/openrlhf/trainer/ray/vllm_engine.py # noqa: E501 def initialize(self) -> None: @@ -257,13 +259,18 @@ def generate_background( else: raise ValueError(f'Unsupported inputs with type({type(inputs)})') + # self.max_inputs_length = 1024 self.max_inputs_length = max_inputs_length self.output_str = output_str self.output_logits = output_logits self.output_attentions = output_attentions self.output_hidden_states = output_hidden_states self.generate_kwargs = generate_kwargs - self.queue = queue.Queue() + self.batch_size = len(prompt) + + if self.generate_thread is not None: + self.generate_thread.join() + self.queue.join() self.generate_thread = threading.Thread(target=self.llm.generate_to_queue, kwargs={'prompt_token_ids':prompt, @@ -285,6 +292,14 @@ def pad_list_with_pad_token(int_list, max_length, pad_token_id): return padded_list else: return int_list + + # def pad_list_with_pad_token_right(int_list, max_length, pad_token_id): + # if len(int_list) < max_length: + # num_pad_token_to_add = max_length - len(int_list) + # padded_list = int_list + [pad_token_id] * num_pad_token_to_add + # return padded_list + # else: + # return 
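# --- Illustrative sketch (not part of the patch): the generate_background change in
# this commit keeps a single worker thread that pushes finished vLLM requests onto a
# queue and joins the previous thread/queue before a new rollout starts. A minimal
# standalone model of that producer/consumer pattern, with hypothetical names
# (BackgroundGenerator, the fake generation body):
import queue
import threading


class BackgroundGenerator:

    def __init__(self):
        self.queue = queue.Queue()
        self.thread = None

    def start(self, prompts):
        # Drain any previous rollout before reusing the thread and queue.
        if self.thread is not None:
            self.thread.join()
            self.queue.join()
        self.thread = threading.Thread(
            target=self._produce, args=(prompts, ), daemon=True)
        self.thread.start()

    def _produce(self, prompts):
        for p in prompts:
            self.queue.put(f'generated<{p}>')  # stand-in for real generation

    def get_finished(self, num):
        # Consumers pull `num` finished samples; task_done() lets queue.join() return.
        outputs = []
        for _ in range(num):
            outputs.append(self.queue.get())
            self.queue.task_done()
        return outputs


if __name__ == '__main__':
    gen = BackgroundGenerator()
    gen.start(['a', 'b', 'c', 'd'])
    print(gen.get_finished(2))
    print(gen.get_finished(2))
# --- end sketch ---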
int_list policy_outputs = [] for _, req_output in enumerate(req_outputs): @@ -295,6 +310,9 @@ def pad_list_with_pad_token(int_list, max_length, pad_token_id): output_token_ids = [ item for item in req_output.outputs[0].token_ids ] + # output_token_ids = pad_list_with_pad_token_right(output_token_ids, 1024, + # self.tokenizer.pad_token_id) + output_ids = input_ids + output_token_ids # concat output['input_ids'] = torch.Tensor(input_ids).to( torch.long).unsqueeze(0) @@ -326,6 +344,7 @@ def pad_list_with_pad_token(int_list, max_length, pad_token_id): clean_up_tokenization_spaces=False, ) output['output_str'] = [output_str] + output['req_ids'] = [int(req_output.request_id)%self.batch_size] output.to('cpu') policy_outputs.append(output) From 67eff37e71d4d2ca4d6132025338b412a39f3872 Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Fri, 26 Jul 2024 09:57:54 +0800 Subject: [PATCH 22/37] fix reward model input data --- xtuner/rlhf/repeaters/kl_gae.py | 66 ++++++++++++++++----------------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/xtuner/rlhf/repeaters/kl_gae.py b/xtuner/rlhf/repeaters/kl_gae.py index a97b9e6c5..30fe992c9 100644 --- a/xtuner/rlhf/repeaters/kl_gae.py +++ b/xtuner/rlhf/repeaters/kl_gae.py @@ -139,6 +139,8 @@ def _get_kl_rewards(self, prompt_datas, trajectories: PolicyOutput): norm_reward_score = (clipped_rewards - self.running_states.mean) / ( self.running_states.var.sqrt() + 1e-8) + else: + norm_reward_score = clipped_rewards action_mask = trajectories.action_mask num_actions = action_mask.size(1) @@ -237,45 +239,39 @@ def get_advantages_and_returns( # default get_reward() is blocking. # get_reward_async() needs to call get_reward_collect() def get_reward_async(self, prompt_datas, policyout): - # rm_input_messages = [] + rm_input_messages = [] # for i in range(len(prompt_datas)): - # if prompt_datas[i].mes_type != 'prompt': - # continue - # if (prompt_datas[i].rm_prompt != - # 'default') or (prompt_datas[i].sys_prompt != 'default'): - # # Conditional Reward Model - # # for queries from different domains, use appropriate conditional system prompts # noqa: E501 - # # From Alignment section of the InternLM2 Technical Report: - # # https://arxiv.org/pdf/2403.17297 - # if prompt_datas[i].rm_prompt != 'default': - # prompt = prompt_datas[i].rm_prompt - # else: - # prompt = prompt_datas[i].sys_prompt - # cur_rm_data = [ - # dict(role='system', content=SYSTEM_PROMPT[prompt]) - # ] + prompt_datas[i].message + [ - # dict( - # role='assistant', content=policyout.output_ans_str[i]) - # ] - # else: - # cur_rm_data = prompt_datas[i].message + [ - # dict( - # role='assistant', content=policyout.output_ans_str[i]) - # ] - # rm_input_messages.append(cur_rm_data) - - # logger.info(f'[For Reward]: {rm_input_messages[0]}') - - # with Timer('reward_model.infer_async'): - # reward_output_ref = self.reward_model.infer_async( - # rm_input_messages, - # output_logprobs=False, - # micro_batch_size=self.reward_micro_bs) + for i, req_id in enumerate(policyout.req_ids): + if prompt_datas[req_id].mes_type != 'prompt': + continue + if (prompt_datas[req_id].rm_prompt != + 'default') or (prompt_datas[req_id].sys_prompt != 'default'): + # Conditional Reward Model + # for queries from different domains, use appropriate conditional system prompts # noqa: E501 + # From Alignment section of the InternLM2 Technical Report: + # https://arxiv.org/pdf/2403.17297 + if prompt_datas[req_id].rm_prompt != 'default': + prompt = prompt_datas[req_id].rm_prompt + else: + prompt = prompt_datas[req_id].sys_prompt 
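# --- Illustrative sketch (not part of the patch): the hunk above re-keys the
# reward-model inputs by policyout.req_ids because background generation can return
# samples in a different order than the prompt batch. A toy re-association, with a
# hypothetical SYSTEM_PROMPT entry and plain dicts standing in for prompt_datas:
SYSTEM_PROMPT = {'math': 'You are a careful math assistant.'}  # hypothetical entry


def build_rm_messages(prompt_datas, req_ids, answers):
    rm_messages = []
    for i, req_id in enumerate(req_ids):
        data = prompt_datas[req_id]  # metadata of the prompt this answer belongs to
        message = list(data['message'])
        if data.get('rm_prompt', 'default') != 'default':
            # Conditional Reward Model: prepend a domain-specific system prompt.
            message = [dict(role='system',
                            content=SYSTEM_PROMPT[data['rm_prompt']])] + message
        message.append(dict(role='assistant', content=answers[i]))
        rm_messages.append(message)
    return rm_messages


if __name__ == '__main__':
    datas = [dict(message=[dict(role='user', content='1+1?')], rm_prompt='math'),
             dict(message=[dict(role='user', content='hi')], rm_prompt='default')]
    # answers arrive in generation-finish order; req_ids map them back to prompts
    print(build_rm_messages(datas, req_ids=[1, 0], answers=['hello!', '2']))
# --- end sketch ---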
+ cur_rm_data = [ + dict(role='system', content=SYSTEM_PROMPT[prompt]) + ] + prompt_datas[req_id].message + [ + dict( + role='assistant', content=policyout.output_ans_str[i]) + ] + else: + cur_rm_data = prompt_datas[req_id].message + [ + dict( + role='assistant', content=policyout.output_ans_str[i]) + ] + rm_input_messages.append(cur_rm_data) + + logger.info(f'[For Reward]: {rm_input_messages[0]}') with Timer('reward_model.infer_async'): reward_output_ref = self.reward_model.infer_async( - inputs=policyout.output_ids, - attention_mask=policyout.attention_mask, + rm_input_messages, output_logprobs=False, micro_batch_size=self.reward_micro_bs) return reward_output_ref From 671256a3d0cfdda8e728231afa16ad9a7c65b8ed Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Mon, 29 Jul 2024 14:20:20 +0800 Subject: [PATCH 23/37] add debug log to align 1.8B --- examples/rlhf/internlm2_1_8b_test_8gpu.py | 25 +++++++++++-------- xtuner/rlhf/model_backend/hf_model_runner.py | 20 ++++++++++++--- .../rlhf/model_backend/vllm_model_runner.py | 14 +++++------ .../rlhf/model_server/reward_model_server.py | 18 +++++++++++++ 4 files changed, 55 insertions(+), 22 deletions(-) diff --git a/examples/rlhf/internlm2_1_8b_test_8gpu.py b/examples/rlhf/internlm2_1_8b_test_8gpu.py index dccd8af82..d35e8223e 100644 --- a/examples/rlhf/internlm2_1_8b_test_8gpu.py +++ b/examples/rlhf/internlm2_1_8b_test_8gpu.py @@ -14,8 +14,8 @@ TRAIN_MICRO_BATCH_SIZE = 2 ZERO_STAGE = 3 -POLICY_DP_SIZE = 2 -CRITIC_DP_SIZE = 2 +POLICY_DP_SIZE = 1 +CRITIC_DP_SIZE = 1 POLICY_GRADIENT_ACC_STEP = (PROMPT_BATCH_SIZE + PRETRAIN_BATCH_SIZE ) // POLICY_DP_SIZE // TRAIN_MICRO_BATCH_SIZE CRITIC_GRADIENT_ACC_STEP = PROMPT_BATCH_SIZE // CRITIC_DP_SIZE // TRAIN_MICRO_BATCH_SIZE # noqa: E501 @@ -32,6 +32,7 @@ assert (PROMPT_BATCH_SIZE) % (TRAIN_MICRO_BATCH_SIZE * CRITIC_DP_SIZE) == 0 MODEL_DTYPE = 'auto' +USE_FLASH_ATTN = False tokenizer_config = dict( pad_token_id=0, @@ -43,13 +44,13 @@ policy_micro_bs=GENERATE_MICRO_BATCH_SIZE, reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, max_new_tokens=MAX_ANSWER_LEN, - write_to_file=False, + write_to_file=True, resume_step=RESUME_STEP, generate_kwargs={ 'do_sample': True, 'temperature': 1.0, - 'top_k': 0, - 'top_p': 0.9, + 'top_k': 1, + # 'top_p': 0.9, 'min_new_tokens': 1, 'num_beams': 1, 'early_stopping': True, @@ -69,7 +70,8 @@ gae_lambda=0.99, clip_reward_min=-5, clip_reward_max=5, - norm_rewards=True, + # norm_rewards=True, + norm_rewards=False, ) train_config = dict( @@ -92,7 +94,7 @@ trainer_config=dict( torch_dtype=MODEL_DTYPE, trainer_type='huggingface', - use_flash_attn=True, + use_flash_attn=USE_FLASH_ATTN, gradient_checkpointing=False, train_kwargs=dict( micro_bsz=1, @@ -150,7 +152,7 @@ trainer_config=dict( torch_dtype=MODEL_DTYPE, trainer_type='huggingface', - use_flash_attn=True, + use_flash_attn=USE_FLASH_ATTN, gradient_checkpointing=False, train_kwargs=dict( micro_bsz=1, @@ -173,7 +175,8 @@ 'reduce_bucket_size': 'auto', 'zero_hpz_partition_size': 1, 'zero_quantized_weights': False, - 'zero_quantized_gradients': False + 'zero_quantized_gradients': False, + 'stage3_gather_16bit_weights_on_model_save': True, }, 'bf16': { 'enabled': True @@ -197,7 +200,7 @@ trainer_config=dict( torch_dtype=MODEL_DTYPE, trainer_type='huggingface', - use_flash_attn=True, + use_flash_attn=USE_FLASH_ATTN, parallel=dict( data=dict(size=1, mode='ddp'), tensor=dict(size=1, mode='1d'), @@ -213,7 +216,7 @@ trainer_config=dict( torch_dtype=MODEL_DTYPE, trainer_type='huggingface', - use_flash_attn=True, + use_flash_attn=USE_FLASH_ATTN, 
parallel=dict( data=dict(size=1, mode='ddp'), tensor=dict(size=1, mode='1d'), diff --git a/xtuner/rlhf/model_backend/hf_model_runner.py b/xtuner/rlhf/model_backend/hf_model_runner.py index 52dd14d8b..aaea7224e 100644 --- a/xtuner/rlhf/model_backend/hf_model_runner.py +++ b/xtuner/rlhf/model_backend/hf_model_runner.py @@ -45,6 +45,7 @@ class HfModelRunner: def __init__(self, model_config): self.model_config: dict = model_config + self.index = 0 def initialize(self): # 0. Environment @@ -247,7 +248,7 @@ def train( step_interval: int = 1, # None means using the entire input as one batch micro_batch_size: Optional[Union[list[int], int]] = None, - debug=False, + debug=True, update_param=True, **_ignored, ): @@ -311,6 +312,16 @@ def train( loss_entry.append(loss) if debug: set_seed(1234) + self.info_rank0(f"[{self.model_type}] Train mb_index: {mb_index}, loss: {loss.item()}") + + #debug + # if self.accelerator.is_main_process: + # mbs = micro_batch['input_ids'].shape[0] * len(micro_batches) + # micro_batch['loss'] = loss.item() + # torch.save(micro_batch, f'/mnt/afs_2/liangkaihuan/Codes/xtuner/data/{self.model_type}_mbs{mbs}_index{self.index}.pth') + # self.index += 1 + # breakpoint() + loss_list[index] = sum(loss_entry) / len(loss_entry) # self.parameter_update(step_interval) @@ -376,7 +387,7 @@ def infer( output_attentions=False, output_hidden_states=False, infer_kwargs: Optional[dict] = {}, - debug=False, + debug=True, **_ignored, ) -> PolicyOutput: self.info_rank0( @@ -423,6 +434,7 @@ def infer( policy_outputs.append(policy_output_mb) if debug: self.set_seed(1234) + self.info_rank0(f"[{self.model_type}] Infer mb_index: {index}") # Concatenate the policy outputs from each micro batch and return the result # noqa: E501 return concat_policy_outputs(policy_outputs) @@ -441,7 +453,7 @@ def infer_from_future( output_attentions=False, output_hidden_states=False, infer_kwargs: Optional[dict] = {}, - debug=False, + debug=True, **_ignored, ) -> PolicyOutput: self.info_rank0( @@ -594,7 +606,7 @@ def generate( output_hidden_states=False, chat_template=None, generate_kwargs: Optional[dict] = {}, - debug=False, + debug=True, **_ignored, ) -> PolicyOutput: self.info_rank0( diff --git a/xtuner/rlhf/model_backend/vllm_model_runner.py b/xtuner/rlhf/model_backend/vllm_model_runner.py index 4c589fd4a..f5ff19e65 100644 --- a/xtuner/rlhf/model_backend/vllm_model_runner.py +++ b/xtuner/rlhf/model_backend/vllm_model_runner.py @@ -293,13 +293,13 @@ def pad_list_with_pad_token(int_list, max_length, pad_token_id): else: return int_list - # def pad_list_with_pad_token_right(int_list, max_length, pad_token_id): - # if len(int_list) < max_length: - # num_pad_token_to_add = max_length - len(int_list) - # padded_list = int_list + [pad_token_id] * num_pad_token_to_add - # return padded_list - # else: - # return int_list + def pad_list_with_pad_token_right(int_list, max_length, pad_token_id): + if len(int_list) < max_length: + num_pad_token_to_add = max_length - len(int_list) + padded_list = int_list + [pad_token_id] * num_pad_token_to_add + return padded_list + else: + return int_list policy_outputs = [] for _, req_output in enumerate(req_outputs): diff --git a/xtuner/rlhf/model_server/reward_model_server.py b/xtuner/rlhf/model_server/reward_model_server.py index 84e5e42af..8540eb82c 100644 --- a/xtuner/rlhf/model_server/reward_model_server.py +++ b/xtuner/rlhf/model_server/reward_model_server.py @@ -30,6 +30,24 @@ def init_tokenizer_and_config(self, model_config): def infer_async(self, inputs, attention_mask=None, *args, 
**infer_kwargs): if not isinstance(inputs, torch.Tensor): input_ids, attention_mask = encode_inputs(inputs, self.tokenizer) + + # Debug + # if isinstance(inputs[0], list): + # inputs = [ + # self.tokenizer.apply_chat_template( + # input, + # tokenize=False, + # add_generation_prompt=False, + # return_tensors='pt', + # ) for input in inputs + # ] + # output = self.tokenizer( + # inputs, + # return_tensors='pt', + # padding='max_length', + # max_length=2048, + # add_special_tokens=False) + # input_ids, attention_mask = output.input_ids, output.attention_mask else: input_ids = inputs From 82e669dd3d32d0b01fd30645d817ea288ef9c43b Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Mon, 29 Jul 2024 14:21:44 +0800 Subject: [PATCH 24/37] refactor pipeline to fix stuck issue --- examples/rlhf/internlm2_20b_test_32gpu.py | 12 +- xtuner/rlhf/repeaters/kl_gae.py | 84 +++++- xtuner/rlhf/test_actor_background.py | 138 ++++++--- xtuner/rlhf/test_vllm.py | 352 ++++++++++++++++++++++ xtuner/rlhf/trainer/ppo.py | 69 ++++- 5 files changed, 611 insertions(+), 44 deletions(-) create mode 100644 xtuner/rlhf/test_vllm.py diff --git a/examples/rlhf/internlm2_20b_test_32gpu.py b/examples/rlhf/internlm2_20b_test_32gpu.py index 076b32575..53d02d61f 100644 --- a/examples/rlhf/internlm2_20b_test_32gpu.py +++ b/examples/rlhf/internlm2_20b_test_32gpu.py @@ -66,8 +66,8 @@ repeater_config = dict( policy_micro_bs=INFER_MICRO_BATCH_SIZE, critic_micro_bs=INFER_MICRO_BATCH_SIZE, - # ref_micro_bs=INFER_MICRO_BATCH_SIZE, - ref_micro_bs=8, ## Optimize + ref_micro_bs=INFER_MICRO_BATCH_SIZE, + #ref_micro_bs=8, ## Optimize kl_coeff=0.01, gamma=1.0, gae_lambda=0.99, @@ -198,8 +198,8 @@ ), deepspeed_config={ "zero_optimization": { - # "stage": 3, - "stage": 0, ## Optimize + "stage": 3, + #"stage": 0, ## Optimize "overlap_comm": True, "stage3_gather_16bit_weights_on_model_save": True }, @@ -231,8 +231,8 @@ ), deepspeed_config={ "zero_optimization": { - # "stage": 3, - "stage": 0, ## Optimize + "stage": 3, + #"stage": 0, ## Optimize "overlap_comm": True, "stage3_gather_16bit_weights_on_model_save": True }, diff --git a/xtuner/rlhf/repeaters/kl_gae.py b/xtuner/rlhf/repeaters/kl_gae.py index 30fe992c9..c449e7bcf 100644 --- a/xtuner/rlhf/repeaters/kl_gae.py +++ b/xtuner/rlhf/repeaters/kl_gae.py @@ -280,4 +280,86 @@ def get_reward_collect(self, reward_output_ref): with Timer('reward_model.infer_get'): rm_out = self.reward_model.infer_get(reward_output_ref) rewards = rm_out.logits.squeeze(-1) - return rewards \ No newline at end of file + return rewards + + def get_reward_refer(self, prompt_datas, trajectories): + reward_output_ref = self.get_reward_async(prompt_datas, + trajectories) + ref_output = self.ref_model.infer_async( + inputs=trajectories.output_ids, + micro_batch_size=self.ref_micro_bs, + attention_mask=trajectories.attention_mask, + output_logits=False, + output_logprobs=True) + + ref_output = self.ref_model.infer_get(ref_output) + rewards = self.get_reward_collect(reward_output_ref) + + return rewards, ref_output.logprobs + + def process_kl_gae(self, rewards, ref_logprobs, values, policy_logprobs, trajectories): + trajectories['rewards'] = rewards + clipped_rewards = torch.clamp( + rewards, min=self.clip_reward_min, max=self.clip_reward_max) + trajectories['clipped_rewards'] = clipped_rewards + + if self.norm_rewards: + self.running_states.update(clipped_rewards) + norm_reward_score = (clipped_rewards - + self.running_states.mean) / ( + self.running_states.var.sqrt() + 1e-8) + else: + norm_reward_score = clipped_rewards + + 
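# --- Illustrative sketch (not part of the patch) of the reward shaping computed here
# in process_kl_gae: the per-token reward is -kl_coeff * (policy_logprob -
# ref_logprob), and the clipped/normalized scalar reward score is added only at the
# last generated (eos) token. Toy tensors only:
import torch


def shape_rewards(policy_logprobs, ref_logprobs, action_mask, reward_score,
                  kl_coeff=0.01):
    log_ratio = policy_logprobs - ref_logprobs
    kl = log_ratio * action_mask
    kl_reward = -kl_coeff * kl
    # index of the last valid action token in each row
    eos_indices = action_mask.size(1) - 1 - action_mask.long().fliplr().argmax(
        dim=1, keepdim=True)
    last_reward = torch.zeros_like(kl).scatter_(
        dim=1, index=eos_indices, src=reward_score.unsqueeze(1).to(kl.dtype))
    return last_reward + kl_reward, kl


if __name__ == '__main__':
    mask = torch.tensor([[1., 1., 1., 0.]])
    policy_lp = torch.tensor([[-1.0, -1.2, -0.8, 0.0]])
    ref_lp = torch.tensor([[-1.1, -1.0, -0.9, 0.0]])
    rewards, kl = shape_rewards(policy_lp, ref_lp, mask,
                                reward_score=torch.tensor([0.5]))
    print(rewards)  # KL penalty per token, +0.5 placed at index 2 (last action)
# --- end sketch ---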
action_mask = trajectories.action_mask + num_actions = action_mask.size(1) + + policy_logprobs = policy_logprobs[:, -num_actions:] + ref_logprobs = ref_logprobs[:, -num_actions:] + + if self.kl_coeff <= 0.0: + self.kl_coeff = 0.0 + # compute_approx_kl + log_ratio = policy_logprobs - ref_logprobs + kl = log_ratio * action_mask + kl_reward = -self.kl_coeff * kl + + eos_indices = action_mask.size( + 1) - 1 - action_mask.long().fliplr().argmax( + dim=1, keepdim=True) + last_reward = torch.zeros_like(kl).scatter_( + dim=1, + index=eos_indices, + src=norm_reward_score.unsqueeze(1).to(kl.dtype)) + + reward = last_reward + kl_reward + + entropy = -(policy_logprobs * + action_mask).sum(axis=-1) / action_mask.sum(axis=-1) + + kl_rewards = reward + kl_distance = kl + # return reward, entropy, kl, policy_logprobs, ref_logprobs + + # (kl_rewards, entropy, kl_distance, policy_logprobs, + # ref_logprobs) = self._get_kl_rewards(prompt_datas, trajectories) + + trajectories['kl'] = (kl_distance * action_mask).sum( + axis=-1) / action_mask.sum(axis=-1) + trajectories['entropy'] = entropy + trajectories['kl_rewards'] = kl_rewards + trajectories['policy_logprobs'] = policy_logprobs + trajectories['ref_logprobs'] = ref_logprobs + + old_values = values[:, -num_actions:] + advantages, returns = self.get_advantages_and_returns( + old_values, kl_rewards, action_mask) + if self.norm_adv: + advantages = (advantages - advantages.mean()) / ( + advantages.std() + 1e-8) + trajectories['advantages'] = advantages + trajectories['returns'] = returns + trajectories['old_values'] = old_values + trajectories['orig_values'] = values + + return trajectories \ No newline at end of file diff --git a/xtuner/rlhf/test_actor_background.py b/xtuner/rlhf/test_actor_background.py index 1a618bc9d..267d4dc20 100644 --- a/xtuner/rlhf/test_actor_background.py +++ b/xtuner/rlhf/test_actor_background.py @@ -72,7 +72,6 @@ def flatten_list(nested_list): os.makedirs(work_dir, exist_ok=True) # save original config shutil.copy2(args.config, f'{work_dir}/{os.path.basename(args.config)}') - import pdb;pdb.set_trace() logger.add( f'{work_dir}/train_rlhf.log', filter=lambda record: record['extra'].get('name') == 'train') @@ -127,9 +126,19 @@ def flatten_list(nested_list): env=txt_env, **repeater_config, ) + klgae_repeater = ray.remote(KLGAERepeater).remote( + ref_model=ref_model, + policy_model=policy_model, + critic_model=critic_model, + reward_model=reward_model, + env=txt_env, + **repeater_config, + ) # init trainer train_config = config.get('train_config', {}) - ppo = ray.remote(PPOTrainer).options(max_concurrency=2).remote( + # ppo = ray.remote(PPOTrainer).options(max_concurrency=2).remote( + # policy_model=policy_model, critic_model=critic_model, **train_config) + ppo = ray.remote(PPOTrainer).remote( policy_model=policy_model, critic_model=critic_model, **train_config) critic_warmup_step = train_config['critic_warmup_step'] save_interval = train_config['save_interval'] @@ -139,6 +148,22 @@ def flatten_list(nested_list): critic_warmup_step - resume_step) async_learn = train_config.get('async_learn', False) + # init log file + json_f = open(f'{work_dir}/train_rlhf.log.jsonl', 'w') + + # prompt_datas = deepcopy(next(prompt_mes_iter)) + # prompt_input_messages = [] + # for data in prompt_datas: + # assert data.mes_type == 'prompt' + # if data.sys_prompt != 'default': + # message = deepcopy([ + # dict( + # role='system', content=SYSTEM_PROMPT[data.sys_prompt]) + # ] + data.message) + # else: + # message = deepcopy(data.message) + # 
prompt_input_messages.append(message) + step = max(0, resume_step) while step <= max_train_step: s_t = time.time() @@ -202,12 +227,13 @@ def flatten_list(nested_list): critic_loss_refs, ppo_loss_refs, trajectories_refs = [], [], [] - micro_bs = 128 + micro_bs = 64 num_batches = len(prompt_input_messages) // micro_bs # Create placeholder lists to manage intermediate results trajectories_refs_stage1 = [None] * num_batches trajectories_refs_stage2 = [None] * num_batches + reward_refer_refs_stage2 = [None] * num_batches critic_loss_refs_stage3 = [None] * num_batches policy_loss_refs_stage3 = [None] * num_batches @@ -227,31 +253,47 @@ def flatten_list(nested_list): # Stage 2: Process trajectories for idx in range(num_batches): trajectories_ref = trajectories_refs_stage1[idx] - trajectories_ref = ppo_repeater.process.remote(prompt_datas[idx * micro_bs: (idx + 1) * micro_bs], - trajectories_ref) - trajectories_refs_stage2[idx] = trajectories_ref + # trajectories_ref = ppo_repeater.process.remote(prompt_datas, + # trajectories_ref) + # trajectories_refs_stage2[idx] = trajectories_ref + reward_ref, refer_logprobs_ref = ppo_repeater.get_reward_refer.options(num_returns=2).remote( + prompt_datas, trajectories_ref) + reward_refer_refs_stage2[idx] = (reward_ref, refer_logprobs_ref) # Stage 3: Critic & Policy learn for idx in range(num_batches): + # update_param = idx == num_batches - 1 + # trajectories_ref = trajectories_refs_stage2[idx] + # critic_loss_ref, critic_time_ref = ppo.critic_learn.options(num_returns=2).remote(trajectories_ref, update_param) + # critic_loss_refs_stage3[idx] = critic_loss_ref + # critic_time_refs_stage3[idx] = critic_time_ref + + # if critic_warmup_step <= 0: + # policy_loss_ref, policy_time_ref = ppo.policy_learn.options(num_returns=2).remote(trajectories_ref, update_param) + # policy_loss_refs_stage3[idx] = policy_loss_ref + # policy_time_refs_stage3[idx] = policy_time_ref + + trajectories_ref = trajectories_refs_stage1[idx] + values_ref, policy_logprobs_ref = ppo.infer.options(num_returns=2).remote(trajectories_ref) + + reward_ref, refer_logprobs_ref = reward_refer_refs_stage2[idx] + trajectories_ref = klgae_repeater.process_kl_gae.remote( + reward_ref, refer_logprobs_ref, values_ref, policy_logprobs_ref, trajectories_ref) + update_param = idx == num_batches - 1 - trajectories_ref = trajectories_refs_stage2[idx] - critic_loss_ref, critic_time_ref = ppo.critic_learn.options(num_returns=2).remote(trajectories_ref, update_param) + policy_loss_ref, critic_loss_ref = ppo.learn.options(num_returns=2).remote(trajectories_ref, update_param) + trajectories_refs_stage2[idx] = trajectories_ref critic_loss_refs_stage3[idx] = critic_loss_ref - critic_time_refs_stage3[idx] = critic_time_ref - - if critic_warmup_step <= 0: - policy_loss_ref, policy_time_ref = ppo.policy_learn.options(num_returns=2).remote(trajectories_ref, update_param) - policy_loss_refs_stage3[idx] = policy_loss_ref - policy_time_refs_stage3[idx] = policy_time_ref + policy_loss_refs_stage3[idx] = policy_loss_ref # Collect results policy_losses = flatten_list(ray.get([ref for ref in policy_loss_refs_stage3 if ref is not None])) critic_losses = flatten_list(ray.get(critic_loss_refs_stage3)) trajectories = ray.get(trajectories_refs_stage2) ray.get(ppo.sync_model.remote()) - critic_times = ray.get(critic_time_refs_stage3) - policy_times = ray.get(policy_time_refs_stage3) - + # critic_times = ray.get(critic_time_refs_stage3) + # policy_times = ray.get(policy_time_refs_stage3) + # breakpoint() total_time = time.time() - s_t 
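# --- Illustrative sketch (not part of the patch): the loop above chains Ray object
# refs per micro-batch so that generation, reward/reference inference and PPO updates
# of different micro-batches overlap instead of running strictly in series. Toy
# remote functions stand in for the real env/repeater/trainer actors:
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
def rollout(batch):
    return [f'traj<{x}>' for x in batch]


@ray.remote
def add_rewards(trajs):
    return [(t, 1.0) for t in trajs]


@ray.remote
def learn(trajs_with_rewards):
    return sum(r for _, r in trajs_with_rewards)  # pretend loss


if __name__ == '__main__':
    batches = [[0, 1], [2, 3], [4, 5]]
    # Stage 1: launch all rollouts without blocking.
    stage1 = [rollout.remote(b) for b in batches]
    # Stage 2/3: pass refs onward; Ray starts each task once its inputs are ready.
    stage2 = [add_rewards.remote(ref) for ref in stage1]
    print(ray.get([learn.remote(ref) for ref in stage2]))
# --- end sketch ---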
# logger_train.info( @@ -259,27 +301,51 @@ def flatten_list(nested_list): # logger_train.info(f'rewards: {trajectories.rewards.mean()}') # critic_warmup_step -= 1 - # if config['rollout_config'].get('write_to_file', True): - # if not os.path.exists(f'{work_dir}/rollouts'): - # os.makedirs(f'{work_dir}/rollouts') - # with open(f'{work_dir}/rollouts/step{step}_rollout.log', - # 'a') as file: - # for output_s, r in zip(trajectories.output_str, - # trajectories.rewards): - # file.write(output_s + '\n' + 'Reward: ' + str(r.item()) + - # '\n' + '=' * 30 + '\n') + if config['rollout_config'].get('write_to_file', True): + if not os.path.exists(f'{work_dir}/rollouts'): + os.makedirs(f'{work_dir}/rollouts') + with open(f'{work_dir}/rollouts/step{step}_rollout.log', + 'w') as file: + for traj in trajectories: + for output_s, r, req_id in zip(traj.output_str, + traj.rewards, + traj.req_ids): + # breakpoint() + file.write(output_s + '\n' + + 'Reward: ' + str(r.item()) + '\n' + + 'Req_id: ' + str(req_id) + '\n' + + '=' * 30 + '\n') + + # import torch + # input_ids = [traj.input_ids for traj in trajectories] + # torch.save(torch.concat(input_ids, dim=0), f'{work_dir}/rollouts/step{step}_input_ids.pth') + + # output_ids = [traj.output_ids for traj in trajectories] + # torch.save(torch.concat(output_ids, dim=0), f'{work_dir}/rollouts/step{step}_output_ids.pth') + + # attention_mask = [traj.attention_mask for traj in trajectories] + # torch.save(torch.concat(attention_mask, dim=0), f'{work_dir}/rollouts/step{step}_attention_mask.pth') + + # kl_rewards = [traj.kl_rewards for traj in trajectories] + # torch.save(torch.concat(kl_rewards, dim=0), f'{work_dir}/rollouts/step{step}_kl_rewards.pth') + + # old_values = [traj.old_values for traj in trajectories] + # torch.save(torch.concat(old_values, dim=0), f'{work_dir}/rollouts/step{step}_old_values.pth') + + # orig_values = [traj.orig_values for traj in trajectories] + # torch.save(torch.concat(orig_values, dim=0), f'{work_dir}/rollouts/step{step}_orig_values.pth') query_tokens = [] resp_tokens = [] rewards = [] stage1_times = [] - stage2_times = [] + # stage2_times = [] for traj in trajectories: query_tokens.append(traj.question_mask.sum(-1).float().mean().item()) resp_tokens.append(traj.answer_mask.sum(-1).float().mean().item()) rewards.append(traj.rewards.mean().item()) stage1_times.append(traj['stage1_time']) - stage2_times.append(traj['stage2_time']) + # stage2_times.append(traj['stage2_time']) query_tokens_mean = sum(query_tokens) / len(query_tokens) resp_tokens_mean = sum(resp_tokens) / len(resp_tokens) @@ -309,21 +375,23 @@ def flatten_list(nested_list): # generate_time=gen_time, # forward_time=fwd_time, # training_time=train_time, - stage1_time=stage1_times, - stage2_time=stage2_times, - critic_time=critic_times, - policy_time=policy_times, + # stage1_time=stage1_times, + # stage2_time=stage2_times, + # critic_time=critic_times, + # policy_time=policy_times, total_time=total_time, reward_mean=reward_mean, query_tokens=query_tokens_mean, resp_tokens=resp_tokens_mean, ) - with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: - f.write(json.dumps(summaries) + '\n') - f.flush() + # with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: + json_f.write(json.dumps(summaries) + '\n') + json_f.flush() logger_train.info(f'[end to end] duration: {time.time() - s_t} s') step += 1 if (step % save_interval == 0) or (step == max_train_step): policy_model.save(f'{work_dir}/ckpt/policy_model/{step}') critic_model.save(f'{work_dir}/ckpt/critic_model/{step}') + + 
json_f.close() \ No newline at end of file diff --git a/xtuner/rlhf/test_vllm.py b/xtuner/rlhf/test_vllm.py new file mode 100644 index 000000000..cdbb2116b --- /dev/null +++ b/xtuner/rlhf/test_vllm.py @@ -0,0 +1,352 @@ +import argparse +import json +import os +import shutil +import time + +from loguru import logger + +from xtuner.rlhf.config.config import Config +from xtuner.rlhf.coordinator import Coordinator +from xtuner.rlhf.dataset import MessageIter +from xtuner.rlhf.envs import TxtEnv +from xtuner.rlhf.repeaters import KLGAERepeater +from xtuner.rlhf.timer import Timer +from xtuner.rlhf.trainer import PPOTrainer + +import ray +from copy import deepcopy +from xtuner.rlhf.envs.utils import SYSTEM_PROMPT +from policy_output import (PolicyOutput, concat_policy_outputs, + logprobs_from_logits) + +def parse_args(): + parser = argparse.ArgumentParser(description='Train LLM') + parser.add_argument( + '-c', + '--config', + help='config file name or path.', + type=str, + default='examples/rlhf/four_model_vllm_8gpu.py') + parser.add_argument( + '-w', + '--work_dir', + help='the dir to save logs and models', + type=str, + default=None) + parser.add_argument( + '-a', '--address', help='ray head address', type=str, default='auto') + args = parser.parse_args() + return args + + +def validate_config(config: Config): + assert config['model_configs'] is not None + assert config['model_configs']['policy'] is not None + assert config['model_configs']['policy']['model_path'] is not None + assert config['dataset_config'] is not None + assert config['rollout_config'] is not None + assert config['rollout_config']['generate_kwargs'] is not None + assert config['rollout_config']['max_new_tokens'] is not None + + +def flatten_list(nested_list): + flattened = [] + for item in nested_list: + if isinstance(item, list): + flattened.extend(flatten_list(item)) + else: + flattened.append(item) + return flattened + + +if __name__ == '__main__': + args = parse_args() + assert args.config is not None, 'config should not be None' + work_dir = args.work_dir + if work_dir is None: + work_dir = os.getcwd() + '/rlhf_trainlog_' + time.strftime( + '%Y-%m-%d-%H:%M:%S') + work_dir = os.path.abspath(work_dir) + logger.info(f'using work_dir: {work_dir}') + os.makedirs(work_dir, exist_ok=True) + # save original config + shutil.copy2(args.config, f'{work_dir}/{os.path.basename(args.config)}') + logger.add( + f'{work_dir}/train_rlhf.log', + filter=lambda record: record['extra'].get('name') == 'train') + logger_train = logger.bind(name='train') + + config = Config.from_file(args.config) + logger.info('#################### CONFIG BGN ####################') + for k, v in config.items(): + logger.info(f'{k}: {v}') + logger.info('#################### CONFIG END ####################') + + # init model + cluster_address = args.address + if cluster_address != 'auto': + cluster_address = f'ray://{cluster_address}:10001' + logger.info(f'cluster_address={cluster_address}') + coordinator = Coordinator(cluster_address, config) + model_dict = coordinator.create_models() + ref_model = model_dict['reference'] + policy_model = model_dict['policy'] + reward_model = model_dict['reward'] + critic_model = model_dict['critic'] + + # init prompt & pretrain dataset + prompt_dataset_config = config['prompt_dataset_config'] + prompt_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, **prompt_dataset_config) + pretrain_dataset_config = config.get('pretrain_dataset_config', {}) + pretrain_mes_iter = MessageIter( + tokenizer=ref_model.tokenizer, 
**pretrain_dataset_config) + + prompt_mes_iter = iter(prompt_mes_iter) + pretrain_mes_iter = iter( + pretrain_mes_iter) if pretrain_mes_iter.message_datasets else None + + # init txt env + rollout_config = config.get('rollout_config', {}) + txt_env = ray.remote(TxtEnv).remote( + policy_model=policy_model, + reward_model=reward_model, + # prompt_mes_iter=prompt_mes_iter, + # pretrain_mes_iter=pretrain_mes_iter, # None + **rollout_config, + ) + # init repeater + repeater_config = config.get('repeater_config', {}) + ppo_repeater = ray.remote(KLGAERepeater).remote( + ref_model=ref_model, + policy_model=policy_model, + critic_model=critic_model, + reward_model=reward_model, + env=txt_env, + **repeater_config, + ) + # init trainer + train_config = config.get('train_config', {}) + ppo = ray.remote(PPOTrainer).options(max_concurrency=2).remote( + policy_model=policy_model, critic_model=critic_model, **train_config) + critic_warmup_step = train_config['critic_warmup_step'] + save_interval = train_config['save_interval'] + max_train_step = train_config.get('max_train_step', float('inf')) + resume_step = train_config.get('resume_step', -1) + critic_warmup_step = min(critic_warmup_step, + critic_warmup_step - resume_step) + async_learn = train_config.get('async_learn', False) + + # init log file + json_f = open(f'{work_dir}/train_rlhf.log.jsonl', 'w') + + prompt_datas = deepcopy(next(prompt_mes_iter)) + prompt_input_messages = [] + for data in prompt_datas: + assert data.mes_type == 'prompt' + if data.sys_prompt != 'default': + message = deepcopy([ + dict( + role='system', content=SYSTEM_PROMPT[data.sys_prompt]) + ] + data.message) + else: + message = deepcopy(data.message) + prompt_input_messages.append(message) + + step = max(0, resume_step) + while step <= max_train_step: + s_t = time.time() + with Timer(f'step {step}: end_to_end'): + + # prompt_datas = deepcopy(next(prompt_mes_iter)) + # prompt_input_messages = [] + # for data in prompt_datas: + # assert data.mes_type == 'prompt' + # if data.sys_prompt != 'default': + # message = deepcopy([ + # dict( + # role='system', content=SYSTEM_PROMPT[data.sys_prompt]) + # ] + data.message) + # else: + # message = deepcopy(data.message) + # prompt_input_messages.append(message) + + # critic_loss_refs, ppo_loss_refs, trajectories_refs = [], [], [] + # micro_bs = 32 + # # for start in range(0, len(prompt_input_messages), micro_bs): + # for idx in range(0, len(prompt_input_messages)//micro_bs): + # # breakpoint() + # # update param in last micro batch + # if idx == len(prompt_input_messages) // micro_bs - 1: + # update_param = True + # else: + # update_param = False + + # # generate trajectories + # trajectories_ref = txt_env.rollout.remote(prompt_datas[idx*micro_bs : (idx+1)*micro_bs], + # prompt_input_messages[idx*micro_bs : (idx+1)*micro_bs], + # display=True) + # # trajectories = ray.get(trajectories_ref) + + # # deal with trajectories + # trajectories_ref = ppo_repeater.process.remote(trajectories_ref) + # # trajectories = ray.get(trajectories_ref) + + # # critic & policy learn + # critic_loss_ref = ppo.critic_learn.remote(trajectories_ref, update_param) + + # # ppo_loss, pt_loss = None, None + # if critic_warmup_step <= 0: + # # ppo_loss, pt_loss = ppo.policy_learn.remote(trajectories) + # ppo_loss_ref = ppo.policy_learn.remote(trajectories_ref, update_param) + # pt_loss = None + + # # logger_train.info( + # # f'[Policy Train] Step: {step}, ' + # # f'ppo loss: {ppo_loss}, pretrain loss: {pt_loss}') + # critic_loss_refs.append(critic_loss_ref) + # 
ppo_loss_refs.append(ppo_loss_ref) + # trajectories_refs.append(trajectories_ref) + + # ppo_losses = flatten_list(ray.get(ppo_loss_refs)) + # critic_losses = flatten_list(ray.get(critic_loss_refs)) + # trajectories = ray.get(trajectories_refs) + # # trajectories = concat_policy_outputs(trajectories) + # ray.get(ppo.sync_model.remote()) + + + critic_loss_refs, ppo_loss_refs, trajectories_refs = [], [], [] + micro_bs = 128 + num_batches = len(prompt_input_messages) // micro_bs + + # Create placeholder lists to manage intermediate results + trajectories_refs_stage1 = [None] * num_batches + trajectories_refs_stage2 = [None] * num_batches + critic_loss_refs_stage3 = [None] * num_batches + policy_loss_refs_stage3 = [None] * num_batches + + critic_time_refs_stage3 = [None] * num_batches + policy_time_refs_stage3 = [None] * num_batches + + # trajectories = ray.get(txt_env.rollout.remote(None, prompt_input_messages)) + + ref = txt_env.rollout_background.remote(prompt_input_messages) + ray.get(ref) + + # Stage 1: Generate trajectories + for idx in range(num_batches): + trajectories_ref = txt_env.get_generate_finish.remote(micro_bs) + trajectories_refs_stage1[idx] = trajectories_ref + trajectories = ray.get(trajectories_refs_stage1) + + # # Stage 2: Process trajectories + # for idx in range(num_batches): + # trajectories_ref = trajectories_refs_stage1[idx] + # trajectories_ref = ppo_repeater.process.remote(prompt_datas, + # trajectories_ref) + # trajectories_refs_stage2[idx] = trajectories_ref + + # # Stage 3: Critic & Policy learn + # for idx in range(num_batches): + # update_param = idx == num_batches - 1 + # trajectories_ref = trajectories_refs_stage2[idx] + # critic_loss_ref, critic_time_ref = ppo.critic_learn.options(num_returns=2).remote(trajectories_ref, update_param) + # critic_loss_refs_stage3[idx] = critic_loss_ref + # critic_time_refs_stage3[idx] = critic_time_ref + + # if critic_warmup_step <= 0: + # policy_loss_ref, policy_time_ref = ppo.policy_learn.options(num_returns=2).remote(trajectories_ref, update_param) + # policy_loss_refs_stage3[idx] = policy_loss_ref + # policy_time_refs_stage3[idx] = policy_time_ref + + # # Collect results + # policy_losses = flatten_list(ray.get([ref for ref in policy_loss_refs_stage3 if ref is not None])) + # critic_losses = flatten_list(ray.get(critic_loss_refs_stage3)) + # trajectories = ray.get(trajectories_refs_stage2) + # # ray.get(ppo.sync_model.remote()) + # critic_times = ray.get(critic_time_refs_stage3) + # policy_times = ray.get(policy_time_refs_stage3) + + total_time = time.time() - s_t + + # logger_train.info( + # f'[Critic Train] step: {step}, critic loss: {critic_loss}') + # logger_train.info(f'rewards: {trajectories.rewards.mean()}') + # critic_warmup_step -= 1 + + # if config['rollout_config'].get('write_to_file', True): + # if not os.path.exists(f'{work_dir}/rollouts'): + # os.makedirs(f'{work_dir}/rollouts') + # with open(f'{work_dir}/rollouts/step{step}_rollout.log', + # 'w') as file: + # for traj in trajectories: + # for output_s, req_id in zip(traj.output_str, + # traj.req_ids): + # file.write(output_s + '\n' + + # # 'Reward: ' + str(r.item()) + '\n' + + # 'Req_id: ' + str(req_id) + '\n' + + # '=' * 30 + '\n') + + # for output_s in trajectories.output_str: + # file.write(output_s + '\n' + + # '=' * 30 + '\n') + + query_tokens = [] + resp_tokens = [] + for traj in trajectories: + query_tokens.append(traj.question_mask.sum(-1).float().mean().item()) + resp_tokens.append(traj.answer_mask.sum(-1).float().mean().item()) + # 
rewards.append(traj.rewards.mean().item()) + # stage1_times.append(traj['stage1_time']) + # stage2_times.append(traj['stage2_time']) + + query_tokens_mean = sum(query_tokens) / len(query_tokens) + resp_tokens_mean = sum(resp_tokens) / len(resp_tokens) + # reward_mean = sum(rewards) / len(rewards) + + # policy_loss_mean = sum(policy_losses) / len(policy_losses) + # critic_loss_mean = sum(critic_losses) / len(critic_losses) + + summaries = dict( + # reward_mean=trajectories.rewards.mean().item(), + # reward_std=trajectories.rewards.std().item(), + # new_tokens_mean=trajectories.action_mask.sum( + # -1).float().mean().item(), + # new_tokens_std=trajectories.action_mask.sum( + # -1).float().std().item(), + # kl=trajectories.kl.mean().item(), + # entropy=trajectories.entropy.mean().item(), + step=step, + # policy_loss=policy_loss_mean, + # pretrain_loss=pt_loss, + # critic_loss=critic_loss_mean, + + # query_tokens_mean=trajectories.question_mask.sum( + # -1).float().mean().item(), + # resp_tokens_mean=trajectories.answer_mask.sum( + # -1).float().mean().item(), + + # generate_time=gen_time, + # forward_time=fwd_time, + # training_time=train_time, + # stage1_time=stage1_times, + # stage2_time=stage2_times, + # critic_time=critic_times, + # policy_time=policy_times, + total_time=total_time, + # reward_mean=reward_mean, + query_tokens=query_tokens_mean, + resp_tokens=resp_tokens_mean, + ) + # with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: + json_f.write(json.dumps(summaries) + '\n') + json_f.flush() + logger_train.info(f'[end to end] duration: {time.time() - s_t} s') + + step += 1 + if (step % save_interval == 0) or (step == max_train_step): + policy_model.save(f'{work_dir}/ckpt/policy_model/{step}') + critic_model.save(f'{work_dir}/ckpt/critic_model/{step}') + + json_f.close() \ No newline at end of file diff --git a/xtuner/rlhf/trainer/ppo.py b/xtuner/rlhf/trainer/ppo.py index 545230ed9..8d8d91487 100644 --- a/xtuner/rlhf/trainer/ppo.py +++ b/xtuner/rlhf/trainer/ppo.py @@ -3,6 +3,7 @@ from ..loss import CriticLoss, PPOPolicyLoss, PretrainLoss from ..model_server.base_model_server import BaseModelServer from ..timer import Timer +from ..policy_output import PolicyOutput import time @@ -121,9 +122,9 @@ def policy_learn(self, trajectories, update_param=True): # with Timer('policy_model.sync_model'): # self.policy_model.sync_model() - # return ppo_loss, pretrain_loss end = time.time() - return ppo_loss, end-start + return ppo_loss, pretrain_loss + # return ppo_loss, end-start def sync_model(self): with Timer('policy_model.sync_model'): @@ -208,3 +209,67 @@ def critic_learn_get(self, critic_loss_ref): self.critic_model.train_get(ref).item() for ref in critic_loss_ref ] + + def learn(self, trajectories, update_param): + critic_loss_ref = self.critic_learn_async(trajectories, update_param) + + ppo_loss, pt_loss = self.policy_learn(trajectories, update_param) + + critic_loss = self.critic_learn_get(critic_loss_ref) + return ppo_loss, critic_loss + + def infer(self, trajectories: PolicyOutput): + with Timer('critic_model.infer_async'): + critic_output_ref = self.critic_model.infer_async( + inputs=trajectories.output_ids, + attention_mask=trajectories.attention_mask, + output_logits=True, + # micro_batch_size=self.critic_micro_bs, + micro_batch_size=2, + ) + + with Timer('policy_model.infer_async'): + policy_output = self.policy_model.infer_async( + inputs=trajectories.output_ids, + # micro_batch_size=self.policy_micro_bs, + micro_batch_size=2, + attention_mask=trajectories.attention_mask, + 
output_logits=False, + output_logprobs=True) + + with Timer('policy_model.infer_get'): + policy_output = self.policy_model.infer_get(policy_output) + + with Timer('critic_model.infer_get'): + critic_output = self.critic_model.infer_get(critic_output_ref) + values = critic_output.logits.squeeze(-1) + + # action_mask = trajectories['action_mask'] + # num_actions = action_mask.size(1) + # old_values = values[:, -num_actions:] + # trajectories['old_values'] = old_values + + # policy_logprobs = policy_output.logprobs[:, -num_actions:] + # trajectories['policy_logprobs'] = policy_logprobs + + return values, policy_output.logprobs + + + # (kl_rewards, entropy, kl_distance, policy_logprobs, + # ref_logprobs) = self._get_kl_rewards(prompt_datas, trajectories) + # trajectories['kl'] = (kl_distance * action_mask).sum( + # axis=-1) / action_mask.sum(axis=-1) + # trajectories['entropy'] = entropy + # trajectories['kl_rewards'] = kl_rewards + # trajectories['ref_logprobs'] = ref_logprobs + + # advantages, returns = self.get_advantages_and_returns( + # old_values, kl_rewards, action_mask) + # if self.norm_adv: + # advantages = (advantages - advantages.mean()) / ( + # advantages.std() + 1e-8) + # trajectories['advantages'] = advantages + # trajectories['returns'] = returns + # end = time.time() + # trajectories['stage2_time'] = end - start + # return trajectories \ No newline at end of file From 0f65854931e2d6eb964a06904aef7cd9bbdff42e Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Mon, 29 Jul 2024 14:22:33 +0800 Subject: [PATCH 25/37] update scritpts and tools --- scripts/train_1node.sh | 3 +-- scripts/train_ray.sh | 2 +- tools/count_gpu.py | 5 ++++- tools/count_time.py | 12 +++++++++--- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/scripts/train_1node.sh b/scripts/train_1node.sh index 637b435cb..3ec4db042 100644 --- a/scripts/train_1node.sh +++ b/scripts/train_1node.sh @@ -26,5 +26,4 @@ mkdir -p $work_dirs #python xtuner/rlhf/main.py -c $config_file -w $work_dirs > $work_dirs/debug.log 2>&1 & # python xtuner/rlhf/main.py -c $config_file -w $work_dirs 2>&1 | tee $work_dirs/main-$start_time.log # python xtuner/rlhf/test_actor.py -c $config_file -w $work_dirs 2>&1 | tee $work_dirs/main-$start_time.log -python xtuner/rlhf/test_actor_background.py -c $config_file -w $work_dirs -2>&1 | tee $work_dirs/main-$start_time.log +python xtuner/rlhf/test_actor_background.py -c $config_file -w $work_dirs 2>&1 | tee $work_dirs/main-$start_time.log diff --git a/scripts/train_ray.sh b/scripts/train_ray.sh index 43429dc4a..f64f612b3 100644 --- a/scripts/train_ray.sh +++ b/scripts/train_ray.sh @@ -89,7 +89,7 @@ if [ "$node_role" == "master" ]; then done fi # python xtuner/rlhf/main.py -c $config_file -w $work_dirs 2>&1 | tee $work_dirs/main-$start_time.log - python -u xtuner/rlhf/main.py -c $config_file -w $work_dirs > $work_dirs/main-$start_time.log 2>&1 & + python -u xtuner/rlhf/test_actor_background.py -c $config_file -w $work_dirs > $work_dirs/main-$start_time.log 2>&1 & else sleep infinity 2>&1 & fi diff --git a/tools/count_gpu.py b/tools/count_gpu.py index 2253b010e..6f3ce73d2 100644 --- a/tools/count_gpu.py +++ b/tools/count_gpu.py @@ -5,7 +5,10 @@ def calculate_averages(file_path): utilization_values = [] with open(file_path, 'r') as file: - for line in file: + #for line in file: + for idx, line in enumerate(file): + if idx < 70: + continue parts = line.split('Utilization:')[1].split(';')[0].split(',') values = [int(val) for val in parts if val] utilization_values.append(values) diff --git 
a/tools/count_time.py b/tools/count_time.py index 1e75a630d..65cf7e367 100644 --- a/tools/count_time.py +++ b/tools/count_time.py @@ -1,9 +1,11 @@ +import sys import json def calculate_statistics(filename): query_tokens_mean_list = [] resp_tokens_mean_list = [] - time_fields = ["generate_time", "forward_time", "training_time"] + #time_fields = ["generate_time", "forward_time", "training_time"] + time_fields = ["total_time"] time_sums = {field: 0 for field in time_fields} time_counts = {field: 0 for field in time_fields} @@ -13,6 +15,9 @@ def calculate_statistics(filename): count += 1 data = json.loads(line) + #query_tokens_mean_list.append(data['query_tokens']) + #resp_tokens_mean_list.append(data['resp_tokens']) + query_tokens_mean_list.append(data['query_tokens_mean']) resp_tokens_mean_list.append(data['resp_tokens_mean']) @@ -21,7 +26,7 @@ def calculate_statistics(filename): time_sums[field] += data[field] time_counts[field] += 1 - if count > 18: + if count > 15: break query_tokens_mean_avg = sum(query_tokens_mean_list) / len(query_tokens_mean_list) if query_tokens_mean_list else 0 @@ -42,7 +47,8 @@ def calculate_statistics(filename): } # 使用示例 -filename = 'logs/internlm2_20b_train_async/train_rlhf.log.jsonl' # 替换为你的jsonl文件路径 +#filename = 'logs/internlm2_20b_train_async/train_rlhf.log.jsonl' # 替换为你的jsonl文件路径 +filename = sys.argv[1] statistics = calculate_statistics(filename) print(statistics) From 1140f3a728e8cfdd4f665137250a616070e30bde Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Mon, 29 Jul 2024 14:53:30 +0800 Subject: [PATCH 26/37] unlock debug code to align 1.8B --- .../rlhf/model_backend/vllm_model_runner.py | 8 ++-- .../rlhf/model_server/reward_model_server.py | 40 +++++++++---------- xtuner/rlhf/test_actor_background.py | 2 +- xtuner/rlhf/trainer/ppo.py | 4 +- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/xtuner/rlhf/model_backend/vllm_model_runner.py b/xtuner/rlhf/model_backend/vllm_model_runner.py index f5ff19e65..a6369061d 100644 --- a/xtuner/rlhf/model_backend/vllm_model_runner.py +++ b/xtuner/rlhf/model_backend/vllm_model_runner.py @@ -259,8 +259,8 @@ def generate_background( else: raise ValueError(f'Unsupported inputs with type({type(inputs)})') - # self.max_inputs_length = 1024 - self.max_inputs_length = max_inputs_length + self.max_inputs_length = 1024 + # self.max_inputs_length = max_inputs_length self.output_str = output_str self.output_logits = output_logits self.output_attentions = output_attentions @@ -310,8 +310,8 @@ def pad_list_with_pad_token_right(int_list, max_length, pad_token_id): output_token_ids = [ item for item in req_output.outputs[0].token_ids ] - # output_token_ids = pad_list_with_pad_token_right(output_token_ids, 1024, - # self.tokenizer.pad_token_id) + output_token_ids = pad_list_with_pad_token_right(output_token_ids, 1024, + self.tokenizer.pad_token_id) output_ids = input_ids + output_token_ids # concat output['input_ids'] = torch.Tensor(input_ids).to( diff --git a/xtuner/rlhf/model_server/reward_model_server.py b/xtuner/rlhf/model_server/reward_model_server.py index 8540eb82c..f5c95a29c 100644 --- a/xtuner/rlhf/model_server/reward_model_server.py +++ b/xtuner/rlhf/model_server/reward_model_server.py @@ -29,32 +29,32 @@ def init_tokenizer_and_config(self, model_config): # Inference def infer_async(self, inputs, attention_mask=None, *args, **infer_kwargs): if not isinstance(inputs, torch.Tensor): - input_ids, attention_mask = encode_inputs(inputs, self.tokenizer) + # input_ids, attention_mask = encode_inputs(inputs, 
self.tokenizer) # Debug - # if isinstance(inputs[0], list): - # inputs = [ - # self.tokenizer.apply_chat_template( - # input, - # tokenize=False, - # add_generation_prompt=False, - # return_tensors='pt', - # ) for input in inputs - # ] - # output = self.tokenizer( - # inputs, - # return_tensors='pt', - # padding='max_length', - # max_length=2048, - # add_special_tokens=False) - # input_ids, attention_mask = output.input_ids, output.attention_mask + if isinstance(inputs[0], list): + inputs = [ + self.tokenizer.apply_chat_template( + input, + tokenize=False, + add_generation_prompt=False, + return_tensors='pt', + ) for input in inputs + ] + output = self.tokenizer( + inputs, + return_tensors='pt', + padding='max_length', + max_length=2048, + add_special_tokens=False) + input_ids, attention_mask = output.input_ids, output.attention_mask else: input_ids = inputs # Reward model specific - if self.reward_token_id is not None: - input_ids, attention_mask = expand_reward_token_id( - self.reward_token_id, input_ids, attention_mask) + # if self.reward_token_id is not None: + # input_ids, attention_mask = expand_reward_token_id( + # self.reward_token_id, input_ids, attention_mask) return self.trainer.infer_async( input_ids=input_ids, diff --git a/xtuner/rlhf/test_actor_background.py b/xtuner/rlhf/test_actor_background.py index 267d4dc20..e416d8f37 100644 --- a/xtuner/rlhf/test_actor_background.py +++ b/xtuner/rlhf/test_actor_background.py @@ -227,7 +227,7 @@ def flatten_list(nested_list): critic_loss_refs, ppo_loss_refs, trajectories_refs = [], [], [] - micro_bs = 64 + micro_bs = 128 num_batches = len(prompt_input_messages) // micro_bs # Create placeholder lists to manage intermediate results diff --git a/xtuner/rlhf/trainer/ppo.py b/xtuner/rlhf/trainer/ppo.py index 8d8d91487..beecf9411 100644 --- a/xtuner/rlhf/trainer/ppo.py +++ b/xtuner/rlhf/trainer/ppo.py @@ -225,14 +225,14 @@ def infer(self, trajectories: PolicyOutput): attention_mask=trajectories.attention_mask, output_logits=True, # micro_batch_size=self.critic_micro_bs, - micro_batch_size=2, + micro_batch_size=8, ) with Timer('policy_model.infer_async'): policy_output = self.policy_model.infer_async( inputs=trajectories.output_ids, # micro_batch_size=self.policy_micro_bs, - micro_batch_size=2, + micro_batch_size=8, attention_mask=trajectories.attention_mask, output_logits=False, output_logprobs=True) From 4e5b1c3865fdf1b0b21c6b8394f81e9eb6ad0c8c Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Mon, 29 Jul 2024 16:54:06 +0800 Subject: [PATCH 27/37] update get data and support pretrain input --- xtuner/rlhf/envs/txt_env.py | 115 ++++++++++++++------------- xtuner/rlhf/test_actor_background.py | 86 ++++++++++++-------- 2 files changed, 113 insertions(+), 88 deletions(-) diff --git a/xtuner/rlhf/envs/txt_env.py b/xtuner/rlhf/envs/txt_env.py index 5d17d6842..7e630ad36 100644 --- a/xtuner/rlhf/envs/txt_env.py +++ b/xtuner/rlhf/envs/txt_env.py @@ -42,28 +42,12 @@ def __init__( self.generate_kwargs: dict = generate_kwargs self.resume_step = resume_step - def rollout(self, prompt_datas, prompt_input_messages, display=True): - # while self.resume_step > 0: - # logger.info(f'[Resume] {self.resume_step} consuming data...') - # next(self.prompt_mes_iter) - # if self.pretrain_mes_iter is not None: - # next(self.pretrain_mes_iter) - # self.resume_step -= 1 - # prompt_datas = deepcopy(next(self.prompt_mes_iter)) - # prompt_input_messages = [] - # for data in prompt_datas: - # assert data.mes_type == 'prompt' - # if data.sys_prompt != 'default': - # message = 
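# --- Illustrative sketch (not part of the patch) of the pretrain-label construction
# used in the new rollout()/rollout_get() below: labels are input_ids shifted left by
# one position, with the final slot padded to -100 so the loss ignores it. Toy ids:
import torch


def make_pretrain_labels(input_ids):
    return torch.nn.functional.pad(
        input_ids[:, 1:], (0, 1), mode='constant', value=-100)


if __name__ == '__main__':
    ids = torch.tensor([[11, 12, 13, 14]])
    print(make_pretrain_labels(ids))  # tensor([[  12,   13,   14, -100]])
# --- end sketch ---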
deepcopy([ - # dict( - # role='system', content=SYSTEM_PROMPT[data.sys_prompt]) - # ] + data.message) - # else: - # message = deepcopy(data.message) - # prompt_input_messages.append(message) - - start = time.time() - + def rollout( + self, + prompt_input_messages, + pretrain_input_messages, + display=True + ): # prompt data if display: logger.info( @@ -77,44 +61,37 @@ def rollout(self, prompt_datas, prompt_input_messages, display=True): generate_kwargs=self.generate_kwargs) logger.info(f'[Generate] len: {len(prompt_input_messages)}') - end = time.time() - trajectories['stage1_time'] = end - start - - # if self.async_reward: - # reward_output_ref = self.get_reward_async(prompt_datas, - # trajectories) - # trajectories['reward_output_ref'] = reward_output_ref - # else: - # rewards = self.get_reward(prompt_datas, trajectories) - # trajectories['rewards'] = rewards - # pretrain data - # if self.pretrain_mes_iter is not None: - # pretrain_datas = deepcopy(next(self.pretrain_mes_iter)) - # pretrain_input_messages = [] - # for data in pretrain_datas: - # assert data.mes_type == 'pretrain' - # pretrain_input_messages.append(message) - - # from xtuner.rlhf.tokenizer import encode_inputs - # pt_input_ids, pt_attention_mask = encode_inputs( - # pretrain_input_messages, self.policy_model.tokenizer) - # pretrain_labels = torch.nn.functional.pad( - # pt_input_ids[:, 1:], (0, 1), mode='constant', value=-100) - - # trajectories.pretrain_data = { - # 'input_ids': pt_input_ids, - # 'labels': pretrain_labels, - # 'attention_mask': pt_attention_mask - # } - # logger.info(f'[TxtEnv] gets {pt_input_ids.shape} pretrain data.') - # else: - # trajectories.pretrain_data = None - trajectories.pretrain_data = None + if pretrain_input_messages is not None: + from xtuner.rlhf.tokenizer import encode_inputs + pt_input_ids, pt_attention_mask = encode_inputs( + pretrain_input_messages, self.policy_model.tokenizer) + pretrain_labels = torch.nn.functional.pad( + pt_input_ids[:, 1:], (0, 1), mode='constant', value=-100) + + trajectories.pretrain_data = { + 'input_ids': pt_input_ids, + 'labels': pretrain_labels, + 'attention_mask': pt_attention_mask + } + logger.info(f'[TxtEnv] gets {pt_input_ids.shape} pretrain data.') + else: + trajectories.pretrain_data = None return trajectories - def rollout_background(self, prompt_input_messages): + def rollout_background( + self, + prompt_input_messages, + pretrain_input_messages, + display=True + ): + self.pretrain_input_messages = pretrain_input_messages + self.pretrain_idx = 0 + + if display: + logger.info( + f'[TXT_ENV For Generate]: \n{prompt_input_messages[0]}') with Timer('policy_model.generate'): ref = self.policy_model.generate_background( inputs=prompt_input_messages, @@ -128,8 +105,32 @@ def rollout_background(self, prompt_input_messages): def get_generate_finish(self, num): start = time.time() + # prompt data trajectories = self.policy_model.get_generate_finish(num) - trajectories.pretrain_data = None + + # pretrain data + # TODO: Get pretrain data proportionally + if self.pretrain_input_messages is not None: + assert self.pretrain_idx + num < len(self.pretrain_input_messages) + pretrain_input_messages = self.pretrain_input_messages[ + self.pretrain_idx:self.pretrain_idx+num] + #update pretrain idx + self.pretrain_idx += num + + from xtuner.rlhf.tokenizer import encode_inputs + pt_input_ids, pt_attention_mask = encode_inputs( + pretrain_input_messages, self.policy_model.tokenizer) + pretrain_labels = torch.nn.functional.pad( + pt_input_ids[:, 1:], (0, 1), mode='constant', 
value=-100) + + trajectories.pretrain_data = { + 'input_ids': pt_input_ids, + 'labels': pretrain_labels, + 'attention_mask': pt_attention_mask + } + logger.info(f'[TxtEnv] gets {pt_input_ids.shape} pretrain data.') + else: + trajectories.pretrain_data = None end = time.time() trajectories['stage1_time'] = end - start diff --git a/xtuner/rlhf/test_actor_background.py b/xtuner/rlhf/test_actor_background.py index e416d8f37..323579248 100644 --- a/xtuner/rlhf/test_actor_background.py +++ b/xtuner/rlhf/test_actor_background.py @@ -60,6 +60,54 @@ def flatten_list(nested_list): return flattened +class DataGenerator: + def __init__( + self, + prompt_mes_iter, + pretrain_mes_iter = None, + resume_step=-1, + ): + self.prompt_mes_iter = iter(prompt_mes_iter) + self.pretrain_mes_iter = iter( + pretrain_mes_iter) if pretrain_mes_iter.message_datasets else None + self.resume_step = resume_step + + def get(self): + while self.resume_step > 0: + logger.info(f'[Resume] {self.resume_step} consuming data...') + next(self.prompt_mes_iter) + if self.pretrain_mes_iter is not None: + next(self.pretrain_mes_iter) + self.resume_step -= 1 + + # prompt data + prompt_datas = deepcopy(next(self.prompt_mes_iter)) + prompt_input_messages = [] + for data in prompt_datas: + assert data.mes_type == 'prompt' + if data.sys_prompt != 'default': + message = deepcopy([ + dict( + role='system', content=SYSTEM_PROMPT[data.sys_prompt]) + ] + data.message) + else: + message = deepcopy(data.message) + prompt_input_messages.append(message) + + # pretrain data + if self.pretrain_mes_iter is not None: + pretrain_input_messages = [] + pretrain_datas = deepcopy(next(self.pretrain_mes_iter)) + for data in pretrain_datas: + assert data.mes_type == 'pretrain' + pretrain_input_messages.append(message) + + if self.pretrain_mes_iter is not None: + return prompt_datas, prompt_input_messages, pretrain_input_messages + else: + return prompt_datas, prompt_input_messages, None + + if __name__ == '__main__': args = parse_args() assert args.config is not None, 'config should not be None' @@ -103,10 +151,6 @@ def flatten_list(nested_list): pretrain_mes_iter = MessageIter( tokenizer=ref_model.tokenizer, **pretrain_dataset_config) - prompt_mes_iter = iter(prompt_mes_iter) - pretrain_mes_iter = iter( - pretrain_mes_iter) if pretrain_mes_iter.message_datasets else None - # init txt env rollout_config = config.get('rollout_config', {}) txt_env = ray.remote(TxtEnv).remote( @@ -151,36 +195,17 @@ def flatten_list(nested_list): # init log file json_f = open(f'{work_dir}/train_rlhf.log.jsonl', 'w') - # prompt_datas = deepcopy(next(prompt_mes_iter)) - # prompt_input_messages = [] - # for data in prompt_datas: - # assert data.mes_type == 'prompt' - # if data.sys_prompt != 'default': - # message = deepcopy([ - # dict( - # role='system', content=SYSTEM_PROMPT[data.sys_prompt]) - # ] + data.message) - # else: - # message = deepcopy(data.message) - # prompt_input_messages.append(message) + data_generator = DataGenerator( + prompt_mes_iter=prompt_mes_iter, + pretrain_mes_iter=pretrain_mes_iter, # None + resume_step=resume_step, + ) step = max(0, resume_step) while step <= max_train_step: s_t = time.time() with Timer(f'step {step}: end_to_end'): - - prompt_datas = deepcopy(next(prompt_mes_iter)) - prompt_input_messages = [] - for data in prompt_datas: - assert data.mes_type == 'prompt' - if data.sys_prompt != 'default': - message = deepcopy([ - dict( - role='system', content=SYSTEM_PROMPT[data.sys_prompt]) - ] + data.message) - else: - message = deepcopy(data.message) 
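# --- Illustrative sketch (not part of the patch): the DataGenerator added in this
# commit couples the prompt and pretrain iterators and fast-forwards both by
# resume_step batches, replacing the inline prompt building removed here. A toy
# in-memory version of that behaviour:
from copy import deepcopy


class ToyDataGenerator:

    def __init__(self, prompt_batches, pretrain_batches=None, resume_step=-1):
        self.prompt_iter = iter(prompt_batches)
        self.pretrain_iter = iter(pretrain_batches) if pretrain_batches else None
        self.resume_step = resume_step

    def get(self):
        while self.resume_step > 0:
            next(self.prompt_iter)  # drop batches a previous run already consumed
            if self.pretrain_iter is not None:
                next(self.pretrain_iter)
            self.resume_step -= 1
        prompts = deepcopy(next(self.prompt_iter))
        pretrain = deepcopy(next(self.pretrain_iter)) if self.pretrain_iter else None
        return prompts, pretrain


if __name__ == '__main__':
    gen = ToyDataGenerator([[1], [2], [3]], [[10], [20], [30]], resume_step=1)
    print(gen.get())  # ([2], [20]) -- the first batch pair was skipped
# --- end sketch ---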
- prompt_input_messages.append(message) + prompt_datas, prompt_input_messages, pretrain_input_messages = data_generator.get() # critic_loss_refs, ppo_loss_refs, trajectories_refs = [], [], [] # micro_bs = 32 @@ -241,7 +266,7 @@ def flatten_list(nested_list): policy_time_refs_stage3 = [None] * num_batches # breakpoint() - ref = txt_env.rollout_background.remote(prompt_input_messages) + ref = txt_env.rollout_background.remote(prompt_input_messages, pretrain_input_messages) ray.get(ref) # Stage 1: Generate trajectories @@ -249,7 +274,6 @@ def flatten_list(nested_list): trajectories_ref = txt_env.get_generate_finish.remote(micro_bs) trajectories_refs_stage1[idx] = trajectories_ref - # breakpoint() # Stage 2: Process trajectories for idx in range(num_batches): trajectories_ref = trajectories_refs_stage1[idx] From 2fe77ea1258c511c8f65b8c431d4d60b40452603 Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Mon, 29 Jul 2024 17:18:46 +0800 Subject: [PATCH 28/37] clean txt_env class --- xtuner/rlhf/envs/txt_env.py | 103 +-------------------------- xtuner/rlhf/test_actor_background.py | 13 ++-- 2 files changed, 6 insertions(+), 110 deletions(-) diff --git a/xtuner/rlhf/envs/txt_env.py b/xtuner/rlhf/envs/txt_env.py index 7e630ad36..ce10b1c53 100644 --- a/xtuner/rlhf/envs/txt_env.py +++ b/xtuner/rlhf/envs/txt_env.py @@ -1,15 +1,10 @@ -from collections.abc import Iterable -from copy import deepcopy - import torch from loguru import logger from ..model_server.base_model_server import BaseModelServer from ..timer import Timer from .base import EnvBase -from .utils import SYSTEM_PROMPT -import time class TxtEnv(EnvBase): """A generic RL environment to generate textual sequences.""" @@ -17,30 +12,15 @@ class TxtEnv(EnvBase): def __init__( self, policy_model: BaseModelServer, - reward_model: BaseModelServer, - # prompt_mes_iter: Iterable, - # pretrain_mes_iter: Iterable = None, max_new_tokens: int = 1024, policy_micro_bs: int = 32, - reward_micro_bs: int = 32, - async_reward: bool = True, generate_kwargs: dict = None, - resume_step=-1, **_ignored, ): self.policy_model = policy_model - self.reward_model = reward_model - - # self.prompt_mes_iter = iter(prompt_mes_iter) - # self.pretrain_mes_iter = iter( - # pretrain_mes_iter) if pretrain_mes_iter.message_datasets else None - self.max_new_tokens = max_new_tokens self.policy_micro_bs = policy_micro_bs - self.reward_micro_bs = reward_micro_bs - self.async_reward = async_reward self.generate_kwargs: dict = generate_kwargs - self.resume_step = resume_step def rollout( self, @@ -93,18 +73,15 @@ def rollout_background( logger.info( f'[TXT_ENV For Generate]: \n{prompt_input_messages[0]}') with Timer('policy_model.generate'): - ref = self.policy_model.generate_background( + self.policy_model.generate_background( inputs=prompt_input_messages, micro_batch_size=self.policy_micro_bs, step=self.max_new_tokens, output_str=True, generate_kwargs=self.generate_kwargs) logger.info(f'[Generate] len: {len(prompt_input_messages)}') - return ref - def get_generate_finish(self, num): - start = time.time() - + def rollout_get(self, num): # prompt data trajectories = self.policy_model.get_generate_finish(num) @@ -132,80 +109,4 @@ def get_generate_finish(self, num): else: trajectories.pretrain_data = None - end = time.time() - trajectories['stage1_time'] = end - start return trajectories - - # default get_reward() is blocking. 
- # get_reward_async() needs to call get_reward_collect() - def get_reward_async(self, prompt_datas, policyout): - rm_input_messages = [] - for i in range(len(prompt_datas)): - if prompt_datas[i].mes_type != 'prompt': - continue - if (prompt_datas[i].rm_prompt != - 'default') or (prompt_datas[i].sys_prompt != 'default'): - # Conditional Reward Model - # for queries from different domains, use appropriate conditional system prompts # noqa: E501 - # From Alignment section of the InternLM2 Technical Report: - # https://arxiv.org/pdf/2403.17297 - if prompt_datas[i].rm_prompt != 'default': - prompt = prompt_datas[i].rm_prompt - else: - prompt = prompt_datas[i].sys_prompt - cur_rm_data = [ - dict(role='system', content=SYSTEM_PROMPT[prompt]) - ] + prompt_datas[i].message + [ - dict( - role='assistant', content=policyout.output_ans_str[i]) - ] - else: - cur_rm_data = prompt_datas[i].message + [ - dict( - role='assistant', content=policyout.output_ans_str[i]) - ] - rm_input_messages.append(cur_rm_data) - - logger.info(f'[For Reward]: {rm_input_messages[0]}') - with Timer('reward_model.infer_async'): - reward_output_ref = self.reward_model.infer_async( - rm_input_messages, - output_logprobs=False, - micro_batch_size=self.reward_micro_bs) - return reward_output_ref - - def get_reward_collect(self, reward_output_ref): - with Timer('reward_model.infer_get'): - rm_out = self.reward_model.infer_get(reward_output_ref) - rewards = rm_out.logits.squeeze(-1) - return rewards - - def get_reward(self, prompt_datas, policyout): - rm_input_messages = [] - for i in range(len(prompt_datas)): - if prompt_datas[i].mes_type != 'prompt': - continue - if prompt_datas[i].rm_prompt != 'default': - cur_rm_data = [ - dict( - role='system', - content=SYSTEM_PROMPT[prompt_datas[i].rm_prompt]) - ] + prompt_datas[i].message + [ - dict( - role='assistant', content=policyout.output_ans_str[i]) - ] - else: - cur_rm_data = prompt_datas[i].message + [ - dict( - role='assistant', content=policyout.output_ans_str[i]) - ] - rm_input_messages.append(cur_rm_data) - - logger.info(f'[For Reward]: {rm_input_messages[0]}') - with Timer('reward_model.infer'): - rm_out = self.reward_model.infer( - rm_input_messages, - output_logprobs=False, - micro_batch_size=self.reward_micro_bs) - rewards = rm_out.logits.squeeze(-1) - return rewards diff --git a/xtuner/rlhf/test_actor_background.py b/xtuner/rlhf/test_actor_background.py index 323579248..cba7b85f8 100644 --- a/xtuner/rlhf/test_actor_background.py +++ b/xtuner/rlhf/test_actor_background.py @@ -156,8 +156,6 @@ def get(self): txt_env = ray.remote(TxtEnv).remote( policy_model=policy_model, reward_model=reward_model, - # prompt_mes_iter=prompt_mes_iter, - # pretrain_mes_iter=pretrain_mes_iter, # None **rollout_config, ) # init repeater @@ -265,13 +263,10 @@ def get(self): critic_time_refs_stage3 = [None] * num_batches policy_time_refs_stage3 = [None] * num_batches - # breakpoint() - ref = txt_env.rollout_background.remote(prompt_input_messages, pretrain_input_messages) - ray.get(ref) - # Stage 1: Generate trajectories + txt_env.rollout_background.remote(prompt_input_messages, pretrain_input_messages) for idx in range(num_batches): - trajectories_ref = txt_env.get_generate_finish.remote(micro_bs) + trajectories_ref = txt_env.rollout_get.remote(micro_bs) trajectories_refs_stage1[idx] = trajectories_ref # Stage 2: Process trajectories @@ -362,13 +357,13 @@ def get(self): query_tokens = [] resp_tokens = [] rewards = [] - stage1_times = [] + # stage1_times = [] # stage2_times = [] for traj in 
trajectories: query_tokens.append(traj.question_mask.sum(-1).float().mean().item()) resp_tokens.append(traj.answer_mask.sum(-1).float().mean().item()) rewards.append(traj.rewards.mean().item()) - stage1_times.append(traj['stage1_time']) + # stage1_times.append(traj['stage1_time']) # stage2_times.append(traj['stage2_time']) query_tokens_mean = sum(query_tokens) / len(query_tokens) From 03d2f2ecb85539d684968bd55267ef7148855813 Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Mon, 29 Jul 2024 20:10:25 +0800 Subject: [PATCH 29/37] clean PPOTrainer class --- examples/rlhf/internlm2_1_8b_test_8gpu.py | 6 +- xtuner/rlhf/test_actor_background.py | 46 +++++++++----- xtuner/rlhf/trainer/ppo.py | 76 +++++++---------------- 3 files changed, 56 insertions(+), 72 deletions(-) diff --git a/examples/rlhf/internlm2_1_8b_test_8gpu.py b/examples/rlhf/internlm2_1_8b_test_8gpu.py index d35e8223e..de9a9d2b4 100644 --- a/examples/rlhf/internlm2_1_8b_test_8gpu.py +++ b/examples/rlhf/internlm2_1_8b_test_8gpu.py @@ -75,8 +75,10 @@ ) train_config = dict( - policy_micro_bs=TRAIN_MICRO_BATCH_SIZE, - critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, + policy_train_micro_bs=TRAIN_MICRO_BATCH_SIZE, + critic_train_micro_bs=TRAIN_MICRO_BATCH_SIZE, + policy_infer_micro_bs=INFER_MICRO_BATCH_SIZE, + critic_infer_micro_bs=INFER_MICRO_BATCH_SIZE, ppo_loss_weight=1.0, pretrain_loss_weight=0.5, critic_warmup_step=0, diff --git a/xtuner/rlhf/test_actor_background.py b/xtuner/rlhf/test_actor_background.py index cba7b85f8..37bea35f8 100644 --- a/xtuner/rlhf/test_actor_background.py +++ b/xtuner/rlhf/test_actor_background.py @@ -178,10 +178,11 @@ def get(self): ) # init trainer train_config = config.get('train_config', {}) - # ppo = ray.remote(PPOTrainer).options(max_concurrency=2).remote( - # policy_model=policy_model, critic_model=critic_model, **train_config) ppo = ray.remote(PPOTrainer).remote( - policy_model=policy_model, critic_model=critic_model, **train_config) + policy_model=policy_model, + critic_model=critic_model, + **train_config) + critic_warmup_step = train_config['critic_warmup_step'] save_interval = train_config['save_interval'] max_train_step = train_config.get('max_train_step', float('inf')) @@ -257,8 +258,10 @@ def get(self): trajectories_refs_stage1 = [None] * num_batches trajectories_refs_stage2 = [None] * num_batches reward_refer_refs_stage2 = [None] * num_batches - critic_loss_refs_stage3 = [None] * num_batches - policy_loss_refs_stage3 = [None] * num_batches + + critic_loss_refs = [None] * num_batches + policy_loss_refs = [None] * num_batches + pretrain_loss_refs = [None] * num_batches critic_time_refs_stage3 = [None] * num_batches policy_time_refs_stage3 = [None] * num_batches @@ -298,27 +301,37 @@ def get(self): reward_ref, refer_logprobs_ref = reward_refer_refs_stage2[idx] trajectories_ref = klgae_repeater.process_kl_gae.remote( reward_ref, refer_logprobs_ref, values_ref, policy_logprobs_ref, trajectories_ref) - - update_param = idx == num_batches - 1 - policy_loss_ref, critic_loss_ref = ppo.learn.options(num_returns=2).remote(trajectories_ref, update_param) trajectories_refs_stage2[idx] = trajectories_ref - critic_loss_refs_stage3[idx] = critic_loss_ref - policy_loss_refs_stage3[idx] = policy_loss_ref + + update_param = (idx == num_batches - 1) + # policy_loss_ref, pretrain_loss_ref, critic_loss_ref = ppo.train.options(num_returns=3).remote(trajectories_ref, update_param, critic_warmup_step) + policy_loss_refs[idx], pretrain_loss_refs[idx], critic_loss_refs[idx] = ( + ppo.train.options(num_returns=3).remote( + 
trajectories_ref, + update_param, + critic_warmup_step + ) + ) + + # critic_loss_refs[idx] = critic_loss_ref + # policy_loss_refs[idx] = policy_loss_ref + # pretrain_loss_refs[idx] = pretrain_loss_ref # Collect results - policy_losses = flatten_list(ray.get([ref for ref in policy_loss_refs_stage3 if ref is not None])) - critic_losses = flatten_list(ray.get(critic_loss_refs_stage3)) - trajectories = ray.get(trajectories_refs_stage2) + policy_losses = flatten_list(ray.get(policy_loss_refs)) + pretrain_losses = flatten_list(ray.get(pretrain_loss_refs)) + critic_losses = flatten_list(ray.get(critic_loss_refs)) ray.get(ppo.sync_model.remote()) + trajectories = ray.get(trajectories_refs_stage2) # critic_times = ray.get(critic_time_refs_stage3) # policy_times = ray.get(policy_time_refs_stage3) # breakpoint() total_time = time.time() - s_t + critic_warmup_step -= 1 # logger_train.info( # f'[Critic Train] step: {step}, critic loss: {critic_loss}') # logger_train.info(f'rewards: {trajectories.rewards.mean()}') - # critic_warmup_step -= 1 if config['rollout_config'].get('write_to_file', True): if not os.path.exists(f'{work_dir}/rollouts'): @@ -370,8 +383,9 @@ def get(self): resp_tokens_mean = sum(resp_tokens) / len(resp_tokens) reward_mean = sum(rewards) / len(rewards) - policy_loss_mean = sum(policy_losses) / len(policy_losses) - critic_loss_mean = sum(critic_losses) / len(critic_losses) + policy_loss_mean = sum(policy_losses) / len(policy_losses) if policy_losses else None + pretrain_loss_mean = sum(pretrain_losses) / len(pretrain_losses) if pretrain_losses else None + critic_loss_mean = sum(critic_losses) / len(critic_losses) summaries = dict( # reward_mean=trajectories.rewards.mean().item(), diff --git a/xtuner/rlhf/trainer/ppo.py b/xtuner/rlhf/trainer/ppo.py index beecf9411..671435c69 100644 --- a/xtuner/rlhf/trainer/ppo.py +++ b/xtuner/rlhf/trainer/ppo.py @@ -13,8 +13,10 @@ def __init__( self, policy_model: BaseModelServer, critic_model: BaseModelServer, - policy_micro_bs=2, - critic_micro_bs=2, + policy_train_micro_bs=2, + critic_train_micro_bs=2, + policy_infer_micro_bs=8, + critic_infer_micro_bs=8, policy_learn_time=1, critic_learn_time=1, policy_minibatch=None, @@ -31,7 +33,8 @@ def __init__( self.policy_model = policy_model self.policy_learn_time = policy_learn_time self.policy_minibatch = policy_minibatch - self.policy_micro_bs = policy_micro_bs + self.policy_train_micro_bs = policy_train_micro_bs + self.policy_infer_micro_bs = policy_infer_micro_bs self.ppo_loss_weight = ppo_loss_weight self.pretrain_loss_weight = pretrain_loss_weight @@ -42,13 +45,12 @@ def __init__( self.critic_model = critic_model self.critic_learn_time = critic_learn_time self.critic_minibatch = critic_minibatch - self.critic_micro_bs = critic_micro_bs + self.critic_train_micro_bs = critic_train_micro_bs + self.critic_infer_micro_bs = critic_infer_micro_bs self.critic_criterion = critic_criterion def policy_learn(self, trajectories, update_param=True): - start = time.time() - if self.policy_minibatch is None: self.policy_minibatch = len(trajectories.output_ids) assert len(trajectories.output_ids) % self.policy_minibatch == 0 @@ -71,7 +73,7 @@ def policy_learn(self, trajectories, update_param=True): ] train_criterion = [self.policy_criterion] loss_weights = [self.ppo_loss_weight] - micro_batch_size = [self.policy_micro_bs] + micro_batch_size = [self.policy_train_micro_bs] train_lables = [ dict( @@ -95,7 +97,7 @@ def policy_learn(self, trajectories, update_param=True): trajectories.pretrain_data['attention_mask']) 
train_criterion.append(self.pretrain_criterion) loss_weights.append(self.pretrain_loss_weight) - micro_batch_size.append(self.policy_micro_bs) + micro_batch_size.append(self.policy_train_micro_bs) with Timer('policy_model.train'): p_loss = self.policy_model.train( @@ -119,12 +121,7 @@ def policy_learn(self, trajectories, update_param=True): f'[Policy Train] prompt data: {train_input_ids[0].shape}, ppo loss: {p_loss.item()}' # noqa: E501 ) - # with Timer('policy_model.sync_model'): - # self.policy_model.sync_model() - - end = time.time() return ppo_loss, pretrain_loss - # return ppo_loss, end-start def sync_model(self): with Timer('policy_model.sync_model'): @@ -151,7 +148,7 @@ def critic_learn(self, trajectories, update_param=True): labels=labels, attention_mask=critic_batch_inputs['attention_mask'], criterion=self.critic_criterion, - micro_batch_size=self.critic_micro_bs, + micro_batch_size=self.critic_train_micro_bs, update_param=update_param, ) logger.info(f'[Critic train] {self.critic_minibatch} batch, ' @@ -196,7 +193,7 @@ def critic_learn_async(self, trajectories, update_param=True): labels=labels, attention_mask=critic_batch_inputs['attention_mask'], criterion=self.critic_criterion, - micro_batch_size=self.critic_micro_bs, + micro_batch_size=self.critic_train_micro_bs, update_param=update_param, ) logger.info(f'[critic train] {self.critic_minibatch} batch') @@ -210,13 +207,14 @@ def critic_learn_get(self, critic_loss_ref): for ref in critic_loss_ref ] - def learn(self, trajectories, update_param): - critic_loss_ref = self.critic_learn_async(trajectories, update_param) - - ppo_loss, pt_loss = self.policy_learn(trajectories, update_param) - - critic_loss = self.critic_learn_get(critic_loss_ref) - return ppo_loss, critic_loss + def train(self, trajectories, update_param, critic_warmup_step=-1): + ppo_loss, pt_loss = None, None + with Timer('policy_critic_learn'): + critic_loss_ref = self.critic_learn_async(trajectories, update_param) + if critic_warmup_step <= 0: + ppo_loss, pt_loss = self.policy_learn(trajectories, update_param) + critic_loss = self.critic_learn_get(critic_loss_ref) + return ppo_loss, pt_loss, critic_loss def infer(self, trajectories: PolicyOutput): with Timer('critic_model.infer_async'): @@ -224,15 +222,13 @@ def infer(self, trajectories: PolicyOutput): inputs=trajectories.output_ids, attention_mask=trajectories.attention_mask, output_logits=True, - # micro_batch_size=self.critic_micro_bs, - micro_batch_size=8, + micro_batch_size=self.critic_infer_micro_bs, ) with Timer('policy_model.infer_async'): policy_output = self.policy_model.infer_async( inputs=trajectories.output_ids, - # micro_batch_size=self.policy_micro_bs, - micro_batch_size=8, + micro_batch_size=self.policy_infer_micro_bs, attention_mask=trajectories.attention_mask, output_logits=False, output_logprobs=True) @@ -244,32 +240,4 @@ def infer(self, trajectories: PolicyOutput): critic_output = self.critic_model.infer_get(critic_output_ref) values = critic_output.logits.squeeze(-1) - # action_mask = trajectories['action_mask'] - # num_actions = action_mask.size(1) - # old_values = values[:, -num_actions:] - # trajectories['old_values'] = old_values - - # policy_logprobs = policy_output.logprobs[:, -num_actions:] - # trajectories['policy_logprobs'] = policy_logprobs - return values, policy_output.logprobs - - - # (kl_rewards, entropy, kl_distance, policy_logprobs, - # ref_logprobs) = self._get_kl_rewards(prompt_datas, trajectories) - # trajectories['kl'] = (kl_distance * action_mask).sum( - # axis=-1) / 
action_mask.sum(axis=-1) - # trajectories['entropy'] = entropy - # trajectories['kl_rewards'] = kl_rewards - # trajectories['ref_logprobs'] = ref_logprobs - - # advantages, returns = self.get_advantages_and_returns( - # old_values, kl_rewards, action_mask) - # if self.norm_adv: - # advantages = (advantages - advantages.mean()) / ( - # advantages.std() + 1e-8) - # trajectories['advantages'] = advantages - # trajectories['returns'] = returns - # end = time.time() - # trajectories['stage2_time'] = end - start - # return trajectories \ No newline at end of file From bc4d7127f889ac41b2933a9098eed89c19af670b Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Mon, 29 Jul 2024 20:56:05 +0800 Subject: [PATCH 30/37] clean KLGAERepeater class --- xtuner/rlhf/repeaters/kl_gae.py | 185 +++------------------------ xtuner/rlhf/test_actor_background.py | 8 +- 2 files changed, 18 insertions(+), 175 deletions(-) diff --git a/xtuner/rlhf/repeaters/kl_gae.py b/xtuner/rlhf/repeaters/kl_gae.py index c449e7bcf..7699ba0f9 100644 --- a/xtuner/rlhf/repeaters/kl_gae.py +++ b/xtuner/rlhf/repeaters/kl_gae.py @@ -8,19 +8,14 @@ from loguru import logger from xtuner.rlhf.envs.utils import SYSTEM_PROMPT -import time class KLGAERepeater(RepeaterBase): def __init__( self, ref_model: BaseModelServer, - policy_model: BaseModelServer, - critic_model: BaseModelServer, reward_model: BaseModelServer, - policy_micro_bs: int = 8, ref_micro_bs: int = 8, - critic_micro_bs: int = 32, reward_micro_bs: int = 8, kl_coeff=0.01, gamma=1.0, @@ -29,18 +24,13 @@ def __init__( clip_reward_max: int = 5, norm_rewards=True, norm_adv=False, - env=None, **_ignored, ): # models self.ref_model = ref_model - self.policy_model = policy_model - self.critic_model = critic_model self.reward_model = reward_model - self.policy_micro_bs = policy_micro_bs self.ref_micro_bs = ref_micro_bs - self.critic_micro_bs = critic_micro_bs self.reward_micro_bs = reward_micro_bs self.kl_coeff = kl_coeff self.gamma = gamma @@ -53,148 +43,6 @@ def __init__( self.running_states = RunningStates(epsilon=0) self.norm_adv = norm_adv - # only used for async reward model.infer_get() in _get_kl_rewards - self.env = env - - def process(self, prompt_datas, trajectories: PolicyOutput): - start = time.time() - critic_output_ref = self._get_values_async(trajectories) - action_mask = trajectories['action_mask'] - num_actions = action_mask.size(1) - (kl_rewards, entropy, kl_distance, policy_logprobs, - ref_logprobs) = self._get_kl_rewards(prompt_datas, trajectories) - trajectories['kl'] = (kl_distance * action_mask).sum( - axis=-1) / action_mask.sum(axis=-1) - trajectories['entropy'] = entropy - trajectories['kl_rewards'] = kl_rewards - trajectories['policy_logprobs'] = policy_logprobs - trajectories['ref_logprobs'] = ref_logprobs - - values = self._get_values_collect(critic_output_ref) - old_values = values[:, -num_actions:] - advantages, returns = self.get_advantages_and_returns( - old_values, kl_rewards, action_mask) - if self.norm_adv: - advantages = (advantages - advantages.mean()) / ( - advantages.std() + 1e-8) - trajectories['advantages'] = advantages - trajectories['returns'] = returns - trajectories['old_values'] = old_values - end = time.time() - trajectories['stage2_time'] = end - start - return trajectories - - # for _ in range(10): - # critic_output_ref = self._get_values_async(trajectories) - # action_mask = trajectories['action_mask'] - # num_actions = action_mask.size(1) - # values = self._get_values_collect(critic_output_ref) - # old_values = values[:, -num_actions:] - # 
trajectories['old_values'] = old_values - # return trajectories - - def _get_kl_rewards(self, prompt_datas, trajectories: PolicyOutput): - with Timer('policy_model.infer_async'): - policy_output = self.policy_model.infer_async( - inputs=trajectories.output_ids, - micro_batch_size=self.policy_micro_bs, - attention_mask=trajectories.attention_mask, - output_logits=False, - output_logprobs=True) - with Timer('ref_model.infer_async'): - ref_output = self.ref_model.infer_async( - inputs=trajectories.output_ids, - micro_batch_size=self.ref_micro_bs, - attention_mask=trajectories.attention_mask, - output_logits=False, - output_logprobs=True) - - reward_output_ref = self.get_reward_async(prompt_datas, - trajectories) - rewards = self.get_reward_collect(reward_output_ref) - trajectories['rewards'] = rewards - - with Timer('policy_model.infer_get'): - policy_output = self.policy_model.infer_get(policy_output) - with Timer('ref_model.infer_get'): - ref_output = self.ref_model.infer_get(ref_output) - - # # Experimental - # if self.env.async_reward: - # rewards = self.env.get_reward_collect( - # trajectories['reward_output_ref']) - # trajectories['reward_output_ref'] = None - # trajectories['rewards'] = ray.get(rewards) - # else: - # rewards = trajectories['rewards'] - # # Experimental - # rewards = trajectories['rewards'] - - clipped_rewards = torch.clamp( - rewards, min=self.clip_reward_min, max=self.clip_reward_max) - trajectories['clipped_rewards'] = clipped_rewards - - if self.norm_rewards: - self.running_states.update(clipped_rewards) - norm_reward_score = (clipped_rewards - - self.running_states.mean) / ( - self.running_states.var.sqrt() + 1e-8) - else: - norm_reward_score = clipped_rewards - action_mask = trajectories.action_mask - num_actions = action_mask.size(1) - - policy_logprobs = policy_output.logprobs[:, -num_actions:] - ref_logprobs = ref_output.logprobs[:, -num_actions:] - - if self.kl_coeff <= 0.0: - self.kl_coeff = 0.0 - # compute_approx_kl - log_ratio = policy_logprobs - ref_logprobs - kl = log_ratio * action_mask - kl_reward = -self.kl_coeff * kl - - eos_indices = action_mask.size( - 1) - 1 - action_mask.long().fliplr().argmax( - dim=1, keepdim=True) - last_reward = torch.zeros_like(kl).scatter_( - dim=1, - index=eos_indices, - src=norm_reward_score.unsqueeze(1).to(kl.dtype)) - - reward = last_reward + kl_reward - - entropy = -(policy_logprobs * - action_mask).sum(axis=-1) / action_mask.sum(axis=-1) - return reward, entropy, kl, policy_logprobs, ref_logprobs - - def _get_values(self, trajectories: PolicyOutput): - with Timer('critic_model.infer'): - critic_output = self.critic_model.infer( - inputs=trajectories.output_ids, - attention_mask=trajectories.attention_mask, - output_logits=True, - micro_batch_size=self.critic_micro_bs, - ) - raw_values = critic_output.logits.squeeze(-1) - return raw_values - - def _get_values_async(self, trajectories: PolicyOutput): - with Timer('critic_model.infer_async'): - critic_output_ref = self.critic_model.infer_async( - inputs=trajectories.output_ids, - attention_mask=trajectories.attention_mask, - output_logits=True, - micro_batch_size=self.critic_micro_bs, - ) - return critic_output_ref - - def _get_values_collect(self, critic_output_ref): - with Timer('critic_model.infer_get'): - critic_output = self.critic_model.infer_get(critic_output_ref) - raw_values = critic_output.logits.squeeze(-1) - return raw_values - def get_advantages_and_returns( self, values: torch.Tensor, @@ -236,7 +84,6 @@ def get_advantages_and_returns( returns = advantages + 
values return advantages.detach(), returns - # default get_reward() is blocking. # get_reward_async() needs to call get_reward_collect() def get_reward_async(self, prompt_datas, policyout): rm_input_messages = [] @@ -282,25 +129,27 @@ def get_reward_collect(self, reward_output_ref): rewards = rm_out.logits.squeeze(-1) return rewards - def get_reward_refer(self, prompt_datas, trajectories): - reward_output_ref = self.get_reward_async(prompt_datas, - trajectories) - ref_output = self.ref_model.infer_async( + def get_reward_and_reference(self, prompt_datas, trajectories): + reward_ref = self.get_reward_async( + prompt_datas, + trajectories) + + reference_ref = self.ref_model.infer_async( inputs=trajectories.output_ids, micro_batch_size=self.ref_micro_bs, attention_mask=trajectories.attention_mask, output_logits=False, output_logprobs=True) - ref_output = self.ref_model.infer_get(ref_output) - rewards = self.get_reward_collect(reward_output_ref) + reference_output = self.ref_model.infer_get(reference_ref) + rewards = self.get_reward_collect(reward_ref) - return rewards, ref_output.logprobs + return rewards, reference_output.logprobs def process_kl_gae(self, rewards, ref_logprobs, values, policy_logprobs, trajectories): - trajectories['rewards'] = rewards clipped_rewards = torch.clamp( rewards, min=self.clip_reward_min, max=self.clip_reward_max) + trajectories['rewards'] = rewards trajectories['clipped_rewards'] = clipped_rewards if self.norm_rewards: @@ -321,24 +170,24 @@ def process_kl_gae(self, rewards, ref_logprobs, values, policy_logprobs, traject self.kl_coeff = 0.0 # compute_approx_kl log_ratio = policy_logprobs - ref_logprobs - kl = log_ratio * action_mask - kl_reward = -self.kl_coeff * kl + kl_distance = log_ratio * action_mask + kl_penalty = -self.kl_coeff * kl_distance eos_indices = action_mask.size( 1) - 1 - action_mask.long().fliplr().argmax( dim=1, keepdim=True) - last_reward = torch.zeros_like(kl).scatter_( + last_reward = torch.zeros_like(kl_distance).scatter_( dim=1, index=eos_indices, - src=norm_reward_score.unsqueeze(1).to(kl.dtype)) + src=norm_reward_score.unsqueeze(1).to(kl_distance.dtype)) - reward = last_reward + kl_reward + kl_rewards = last_reward + kl_penalty entropy = -(policy_logprobs * action_mask).sum(axis=-1) / action_mask.sum(axis=-1) - kl_rewards = reward - kl_distance = kl + # kl_rewards = reward + # kl_distance = kl # return reward, entropy, kl, policy_logprobs, ref_logprobs # (kl_rewards, entropy, kl_distance, policy_logprobs, diff --git a/xtuner/rlhf/test_actor_background.py b/xtuner/rlhf/test_actor_background.py index 37bea35f8..a86271741 100644 --- a/xtuner/rlhf/test_actor_background.py +++ b/xtuner/rlhf/test_actor_background.py @@ -162,18 +162,12 @@ def get(self): repeater_config = config.get('repeater_config', {}) ppo_repeater = ray.remote(KLGAERepeater).remote( ref_model=ref_model, - policy_model=policy_model, - critic_model=critic_model, reward_model=reward_model, - env=txt_env, **repeater_config, ) klgae_repeater = ray.remote(KLGAERepeater).remote( ref_model=ref_model, - policy_model=policy_model, - critic_model=critic_model, reward_model=reward_model, - env=txt_env, **repeater_config, ) # init trainer @@ -278,7 +272,7 @@ def get(self): # trajectories_ref = ppo_repeater.process.remote(prompt_datas, # trajectories_ref) # trajectories_refs_stage2[idx] = trajectories_ref - reward_ref, refer_logprobs_ref = ppo_repeater.get_reward_refer.options(num_returns=2).remote( + reward_ref, refer_logprobs_ref = 
ppo_repeater.get_reward_and_reference.options(num_returns=2).remote( prompt_datas, trajectories_ref) reward_refer_refs_stage2[idx] = (reward_ref, refer_logprobs_ref) From f8738c327729cd64397e37c5015c758c4d87f41f Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Tue, 30 Jul 2024 14:58:48 +0800 Subject: [PATCH 31/37] clean code and update summary log --- examples/rlhf/internlm2_1_8b_test_8gpu.py | 5 + xtuner/rlhf/policy_output.py | 2 + xtuner/rlhf/test_actor_background.py | 254 +++++++--------------- 3 files changed, 83 insertions(+), 178 deletions(-) diff --git a/examples/rlhf/internlm2_1_8b_test_8gpu.py b/examples/rlhf/internlm2_1_8b_test_8gpu.py index de9a9d2b4..9d163ce92 100644 --- a/examples/rlhf/internlm2_1_8b_test_8gpu.py +++ b/examples/rlhf/internlm2_1_8b_test_8gpu.py @@ -9,6 +9,10 @@ PROMPT_BATCH_SIZE = 128 PRETRAIN_BATCH_SIZE = 0 # 32 +PIPE_MICRO_BATCH_NUM = 1 +assert PROMPT_BATCH_SIZE % PIPE_MICRO_BATCH_NUM == 0 +PIPE_MICRO_BATCH_SIZE = PROMPT_BATCH_SIZE // PIPE_MICRO_BATCH_NUM #32 + GENERATE_MICRO_BATCH_SIZE = 16 INFER_MICRO_BATCH_SIZE = 8 TRAIN_MICRO_BATCH_SIZE = 2 @@ -75,6 +79,7 @@ ) train_config = dict( + pipe_micro_bs=PIPE_MICRO_BATCH_SIZE, policy_train_micro_bs=TRAIN_MICRO_BATCH_SIZE, critic_train_micro_bs=TRAIN_MICRO_BATCH_SIZE, policy_infer_micro_bs=INFER_MICRO_BATCH_SIZE, diff --git a/xtuner/rlhf/policy_output.py b/xtuner/rlhf/policy_output.py index af7129410..24075e511 100644 --- a/xtuner/rlhf/policy_output.py +++ b/xtuner/rlhf/policy_output.py @@ -112,6 +112,8 @@ def padding_policy_outputs(policy_outputs: list[PolicyOutput], padding_id=0): tensor_keys = union_tensor_keys_from_policy_outputs(policy_outputs) for key in tensor_keys: + if len(policy_outputs[0][key].shape) < 2: + continue padding_id = padding_token_map.get(key, padding_id) max_seq_len = find_max_seq_len(policy_outputs, key) for policy_output in policy_outputs: diff --git a/xtuner/rlhf/test_actor_background.py b/xtuner/rlhf/test_actor_background.py index a86271741..264b22c40 100644 --- a/xtuner/rlhf/test_actor_background.py +++ b/xtuner/rlhf/test_actor_background.py @@ -183,7 +183,7 @@ def get(self): resume_step = train_config.get('resume_step', -1) critic_warmup_step = min(critic_warmup_step, critic_warmup_step - resume_step) - async_learn = train_config.get('async_learn', False) + pipe_micro_bs = train_config['pipe_micro_bs'] # init log file json_f = open(f'{work_dir}/train_rlhf.log.jsonl', 'w') @@ -198,220 +198,118 @@ def get(self): while step <= max_train_step: s_t = time.time() with Timer(f'step {step}: end_to_end'): + # Get Data prompt_datas, prompt_input_messages, pretrain_input_messages = data_generator.get() - # critic_loss_refs, ppo_loss_refs, trajectories_refs = [], [], [] - # micro_bs = 32 - # # for start in range(0, len(prompt_input_messages), micro_bs): - # for idx in range(0, len(prompt_input_messages)//micro_bs): - # # breakpoint() - # # update param in last micro batch - # if idx == len(prompt_input_messages) // micro_bs - 1: - # update_param = True - # else: - # update_param = False - - # # generate trajectories - # trajectories_ref = txt_env.rollout.remote(prompt_datas[idx*micro_bs : (idx+1)*micro_bs], - # prompt_input_messages[idx*micro_bs : (idx+1)*micro_bs], - # display=True) - # # trajectories = ray.get(trajectories_ref) - - # # deal with trajectories - # trajectories_ref = ppo_repeater.process.remote(trajectories_ref) - # # trajectories = ray.get(trajectories_ref) - - # # critic & policy learn - # critic_loss_ref = ppo.critic_learn.remote(trajectories_ref, update_param) - - # # 
ppo_loss, pt_loss = None, None - # if critic_warmup_step <= 0: - # # ppo_loss, pt_loss = ppo.policy_learn.remote(trajectories) - # ppo_loss_ref = ppo.policy_learn.remote(trajectories_ref, update_param) - # pt_loss = None - - # # logger_train.info( - # # f'[Policy Train] Step: {step}, ' - # # f'ppo loss: {ppo_loss}, pretrain loss: {pt_loss}') - # critic_loss_refs.append(critic_loss_ref) - # ppo_loss_refs.append(ppo_loss_ref) - # trajectories_refs.append(trajectories_ref) - - # ppo_losses = flatten_list(ray.get(ppo_loss_refs)) - # critic_losses = flatten_list(ray.get(critic_loss_refs)) - # trajectories = ray.get(trajectories_refs) - # # trajectories = concat_policy_outputs(trajectories) - # ray.get(ppo.sync_model.remote()) - - - critic_loss_refs, ppo_loss_refs, trajectories_refs = [], [], [] - micro_bs = 128 - num_batches = len(prompt_input_messages) // micro_bs - # Create placeholder lists to manage intermediate results - trajectories_refs_stage1 = [None] * num_batches - trajectories_refs_stage2 = [None] * num_batches - reward_refer_refs_stage2 = [None] * num_batches + num_batch = len(prompt_input_messages) // pipe_micro_bs + logger.info(f'prompt_bs={len(prompt_input_messages)}, ' + f'pipe_micro_bs={pipe_micro_bs}, ' + f'num_batch={num_batch}') - critic_loss_refs = [None] * num_batches - policy_loss_refs = [None] * num_batches - pretrain_loss_refs = [None] * num_batches + traj_refs_stage1 = [None] * num_batch + traj_refs_stage2 = [None] * num_batch + reward_reference_stage2 = [None] * num_batch - critic_time_refs_stage3 = [None] * num_batches - policy_time_refs_stage3 = [None] * num_batches + critic_loss_refs = [None] * num_batch + policy_loss_refs = [None] * num_batch + pretrain_loss_refs = [None] * num_batch # Stage 1: Generate trajectories txt_env.rollout_background.remote(prompt_input_messages, pretrain_input_messages) - for idx in range(num_batches): - trajectories_ref = txt_env.rollout_get.remote(micro_bs) - trajectories_refs_stage1[idx] = trajectories_ref - - # Stage 2: Process trajectories - for idx in range(num_batches): - trajectories_ref = trajectories_refs_stage1[idx] - # trajectories_ref = ppo_repeater.process.remote(prompt_datas, - # trajectories_ref) - # trajectories_refs_stage2[idx] = trajectories_ref - reward_ref, refer_logprobs_ref = ppo_repeater.get_reward_and_reference.options(num_returns=2).remote( - prompt_datas, trajectories_ref) - reward_refer_refs_stage2[idx] = (reward_ref, refer_logprobs_ref) - - # Stage 3: Critic & Policy learn - for idx in range(num_batches): - # update_param = idx == num_batches - 1 - # trajectories_ref = trajectories_refs_stage2[idx] - # critic_loss_ref, critic_time_ref = ppo.critic_learn.options(num_returns=2).remote(trajectories_ref, update_param) - # critic_loss_refs_stage3[idx] = critic_loss_ref - # critic_time_refs_stage3[idx] = critic_time_ref - - # if critic_warmup_step <= 0: - # policy_loss_ref, policy_time_ref = ppo.policy_learn.options(num_returns=2).remote(trajectories_ref, update_param) - # policy_loss_refs_stage3[idx] = policy_loss_ref - # policy_time_refs_stage3[idx] = policy_time_ref - - trajectories_ref = trajectories_refs_stage1[idx] - values_ref, policy_logprobs_ref = ppo.infer.options(num_returns=2).remote(trajectories_ref) - - reward_ref, refer_logprobs_ref = reward_refer_refs_stage2[idx] - trajectories_ref = klgae_repeater.process_kl_gae.remote( - reward_ref, refer_logprobs_ref, values_ref, policy_logprobs_ref, trajectories_ref) - trajectories_refs_stage2[idx] = trajectories_ref - - update_param = (idx == num_batches - 1) - 
# policy_loss_ref, pretrain_loss_ref, critic_loss_ref = ppo.train.options(num_returns=3).remote(trajectories_ref, update_param, critic_warmup_step) - policy_loss_refs[idx], pretrain_loss_refs[idx], critic_loss_refs[idx] = ( + for idx in range(num_batch): + traj_ref = txt_env.rollout_get.remote(pipe_micro_bs) + traj_refs_stage1[idx] = traj_ref + + # Stage 2: Reward & Reference Model infer + for idx in range(num_batch): + traj_ref = traj_refs_stage1[idx] + reward_ref, reference_logprobs_ref = ( + ppo_repeater.get_reward_and_reference.options(num_returns=2).remote( + prompt_datas, + traj_ref) + ) + reward_reference_stage2[idx] = (reward_ref, reference_logprobs_ref) + + # Stage 3: Critic & Policy infer and learn + for idx in range(num_batch): + # Infer + traj_ref = traj_refs_stage1[idx] + values_ref, policy_logprobs_ref = ppo.infer.options(num_returns=2).remote( + traj_ref) + + # Process KL, GAE + reward_ref, reference_logprobs_ref = reward_reference_stage2[idx] + traj_ref_2 = klgae_repeater.process_kl_gae.remote( + reward_ref, + reference_logprobs_ref, + values_ref, + policy_logprobs_ref, + traj_ref) + traj_refs_stage2[idx] = traj_ref_2 + + # Train + update_param = (idx == num_batch - 1) + policy_loss_ref, pretrain_loss_ref, critic_loss_ref = ( ppo.train.options(num_returns=3).remote( - trajectories_ref, + traj_ref_2, update_param, critic_warmup_step ) ) - - # critic_loss_refs[idx] = critic_loss_ref - # policy_loss_refs[idx] = policy_loss_ref - # pretrain_loss_refs[idx] = pretrain_loss_ref + critic_loss_refs[idx] = critic_loss_ref + policy_loss_refs[idx] = policy_loss_ref + pretrain_loss_refs[idx] = pretrain_loss_ref # Collect results policy_losses = flatten_list(ray.get(policy_loss_refs)) pretrain_losses = flatten_list(ray.get(pretrain_loss_refs)) critic_losses = flatten_list(ray.get(critic_loss_refs)) ray.get(ppo.sync_model.remote()) - trajectories = ray.get(trajectories_refs_stage2) - # critic_times = ray.get(critic_time_refs_stage3) - # policy_times = ray.get(policy_time_refs_stage3) - # breakpoint() - total_time = time.time() - s_t - critic_warmup_step -= 1 + trajectories = ray.get(traj_refs_stage2) + # Post process output + padding_token_map = {'output_ids': policy_model.tokenizer.pad_token_id} + trajectories = concat_policy_outputs(trajectories, + padding_token_map) - # logger_train.info( - # f'[Critic Train] step: {step}, critic loss: {critic_loss}') - # logger_train.info(f'rewards: {trajectories.rewards.mean()}') + critic_warmup_step -= 1 + total_time = time.time() - s_t + # write log if config['rollout_config'].get('write_to_file', True): if not os.path.exists(f'{work_dir}/rollouts'): os.makedirs(f'{work_dir}/rollouts') with open(f'{work_dir}/rollouts/step{step}_rollout.log', 'w') as file: - for traj in trajectories: - for output_s, r, req_id in zip(traj.output_str, - traj.rewards, - traj.req_ids): - # breakpoint() - file.write(output_s + '\n' + - 'Reward: ' + str(r.item()) + '\n' + - 'Req_id: ' + str(req_id) + '\n' + - '=' * 30 + '\n') - - # import torch - # input_ids = [traj.input_ids for traj in trajectories] - # torch.save(torch.concat(input_ids, dim=0), f'{work_dir}/rollouts/step{step}_input_ids.pth') - - # output_ids = [traj.output_ids for traj in trajectories] - # torch.save(torch.concat(output_ids, dim=0), f'{work_dir}/rollouts/step{step}_output_ids.pth') - - # attention_mask = [traj.attention_mask for traj in trajectories] - # torch.save(torch.concat(attention_mask, dim=0), f'{work_dir}/rollouts/step{step}_attention_mask.pth') - - # kl_rewards = [traj.kl_rewards for traj in 
trajectories] - # torch.save(torch.concat(kl_rewards, dim=0), f'{work_dir}/rollouts/step{step}_kl_rewards.pth') - - # old_values = [traj.old_values for traj in trajectories] - # torch.save(torch.concat(old_values, dim=0), f'{work_dir}/rollouts/step{step}_old_values.pth') - - # orig_values = [traj.orig_values for traj in trajectories] - # torch.save(torch.concat(orig_values, dim=0), f'{work_dir}/rollouts/step{step}_orig_values.pth') - - query_tokens = [] - resp_tokens = [] - rewards = [] - # stage1_times = [] - # stage2_times = [] - for traj in trajectories: - query_tokens.append(traj.question_mask.sum(-1).float().mean().item()) - resp_tokens.append(traj.answer_mask.sum(-1).float().mean().item()) - rewards.append(traj.rewards.mean().item()) - # stage1_times.append(traj['stage1_time']) - # stage2_times.append(traj['stage2_time']) - - query_tokens_mean = sum(query_tokens) / len(query_tokens) - resp_tokens_mean = sum(resp_tokens) / len(resp_tokens) - reward_mean = sum(rewards) / len(rewards) + for output_s, r, req_id in zip(trajectories.output_str, + trajectories.rewards, + trajectories.req_ids): + file.write(output_s + '\n' + + 'Reward: ' + str(r.item()) + '\n' + + 'Req_id: ' + str(req_id) + '\n' + + '=' * 30 + '\n') policy_loss_mean = sum(policy_losses) / len(policy_losses) if policy_losses else None pretrain_loss_mean = sum(pretrain_losses) / len(pretrain_losses) if pretrain_losses else None critic_loss_mean = sum(critic_losses) / len(critic_losses) summaries = dict( - # reward_mean=trajectories.rewards.mean().item(), - # reward_std=trajectories.rewards.std().item(), - # new_tokens_mean=trajectories.action_mask.sum( - # -1).float().mean().item(), - # new_tokens_std=trajectories.action_mask.sum( - # -1).float().std().item(), - # kl=trajectories.kl.mean().item(), - # entropy=trajectories.entropy.mean().item(), + reward_mean=trajectories.rewards.mean().item(), + reward_std=trajectories.rewards.std().item(), + new_tokens_mean=trajectories.action_mask.sum( + -1).float().mean().item(), + new_tokens_std=trajectories.action_mask.sum( + -1).float().std().item(), + resp_tokens_mean=trajectories.answer_mask.sum( + -1).float().mean().item(), + kl=trajectories.kl.mean().item(), + entropy=trajectories.entropy.mean().item(), step=step, policy_loss=policy_loss_mean, - # pretrain_loss=pt_loss, + pretrain_loss=pretrain_loss_mean, critic_loss=critic_loss_mean, - - # query_tokens_mean=trajectories.question_mask.sum( - # -1).float().mean().item(), - # resp_tokens_mean=trajectories.answer_mask.sum( - # -1).float().mean().item(), - # generate_time=gen_time, - # forward_time=fwd_time, - # training_time=train_time, - # stage1_time=stage1_times, - # stage2_time=stage2_times, - # critic_time=critic_times, - # policy_time=policy_times, total_time=total_time, - reward_mean=reward_mean, - query_tokens=query_tokens_mean, - resp_tokens=resp_tokens_mean, ) - # with open(f'{work_dir}/train_rlhf.log.jsonl', 'a') as f: json_f.write(json.dumps(summaries) + '\n') json_f.flush() logger_train.info(f'[end to end] duration: {time.time() - s_t} s') From ecf8df9791b298a71cba53cd5c7c63390af414de Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Tue, 30 Jul 2024 15:07:21 +0800 Subject: [PATCH 32/37] add Timer --- xtuner/rlhf/envs/txt_env.py | 5 +++-- xtuner/rlhf/repeaters/kl_gae.py | 26 +++++++++++--------------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/xtuner/rlhf/envs/txt_env.py b/xtuner/rlhf/envs/txt_env.py index ce10b1c53..eacf1c385 100644 --- a/xtuner/rlhf/envs/txt_env.py +++ b/xtuner/rlhf/envs/txt_env.py @@ 
-72,7 +72,7 @@ def rollout_background( if display: logger.info( f'[TXT_ENV For Generate]: \n{prompt_input_messages[0]}') - with Timer('policy_model.generate'): + with Timer('txt_env.generate_background'): self.policy_model.generate_background( inputs=prompt_input_messages, micro_batch_size=self.policy_micro_bs, @@ -83,7 +83,8 @@ def rollout_background( def rollout_get(self, num): # prompt data - trajectories = self.policy_model.get_generate_finish(num) + with Timer('txt_env.rollout_get'): + trajectories = self.policy_model.get_generate_finish(num) # pretrain data # TODO: Get pretrain data proportionally diff --git a/xtuner/rlhf/repeaters/kl_gae.py b/xtuner/rlhf/repeaters/kl_gae.py index 7699ba0f9..7884aa14b 100644 --- a/xtuner/rlhf/repeaters/kl_gae.py +++ b/xtuner/rlhf/repeaters/kl_gae.py @@ -133,15 +133,18 @@ def get_reward_and_reference(self, prompt_datas, trajectories): reward_ref = self.get_reward_async( prompt_datas, trajectories) + + with Timer('ref_model.infer_async'): + reference_ref = self.ref_model.infer_async( + inputs=trajectories.output_ids, + micro_batch_size=self.ref_micro_bs, + attention_mask=trajectories.attention_mask, + output_logits=False, + output_logprobs=True) + + with Timer('ref_model.infer_get'): + reference_output = self.ref_model.infer_get(reference_ref) - reference_ref = self.ref_model.infer_async( - inputs=trajectories.output_ids, - micro_batch_size=self.ref_micro_bs, - attention_mask=trajectories.attention_mask, - output_logits=False, - output_logprobs=True) - - reference_output = self.ref_model.infer_get(reference_ref) rewards = self.get_reward_collect(reward_ref) return rewards, reference_output.logprobs @@ -186,13 +189,6 @@ def process_kl_gae(self, rewards, ref_logprobs, values, policy_logprobs, traject entropy = -(policy_logprobs * action_mask).sum(axis=-1) / action_mask.sum(axis=-1) - # kl_rewards = reward - # kl_distance = kl - # return reward, entropy, kl, policy_logprobs, ref_logprobs - - # (kl_rewards, entropy, kl_distance, policy_logprobs, - # ref_logprobs) = self._get_kl_rewards(prompt_datas, trajectories) - trajectories['kl'] = (kl_distance * action_mask).sum( axis=-1) / action_mask.sum(axis=-1) trajectories['entropy'] = entropy From fc9f44c5c9611f87666fe57ebebd2fdbc800f3ab Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Tue, 30 Jul 2024 15:23:16 +0800 Subject: [PATCH 33/37] remove redunant code --- xtuner/rlhf/model_backend/hf_model_runner.py | 108 ------------------- 1 file changed, 108 deletions(-) diff --git a/xtuner/rlhf/model_backend/hf_model_runner.py b/xtuner/rlhf/model_backend/hf_model_runner.py index aaea7224e..55095cb30 100644 --- a/xtuner/rlhf/model_backend/hf_model_runner.py +++ b/xtuner/rlhf/model_backend/hf_model_runner.py @@ -45,7 +45,6 @@ class HfModelRunner: def __init__(self, model_config): self.model_config: dict = model_config - self.index = 0 def initialize(self): # 0. 
Environment @@ -312,15 +311,6 @@ def train( loss_entry.append(loss) if debug: set_seed(1234) - self.info_rank0(f"[{self.model_type}] Train mb_index: {mb_index}, loss: {loss.item()}") - - #debug - # if self.accelerator.is_main_process: - # mbs = micro_batch['input_ids'].shape[0] * len(micro_batches) - # micro_batch['loss'] = loss.item() - # torch.save(micro_batch, f'/mnt/afs_2/liangkaihuan/Codes/xtuner/data/{self.model_type}_mbs{mbs}_index{self.index}.pth') - # self.index += 1 - # breakpoint() loss_list[index] = sum(loss_entry) / len(loss_entry) @@ -434,87 +424,9 @@ def infer( policy_outputs.append(policy_output_mb) if debug: self.set_seed(1234) - self.info_rank0(f"[{self.model_type}] Infer mb_index: {index}") # Concatenate the policy outputs from each micro batch and return the result # noqa: E501 return concat_policy_outputs(policy_outputs) - - @torch.no_grad() - def infer_from_future( - self, - # input_ids: torch.Tensor, - # attention_mask=None, - object_refs, - micro_batch_size: Optional[ - int] = -1, # -1: use the entire input as one batch - tokenizer=None, # Only used for reward models - output_logprobs=False, - output_logits=True, - output_attentions=False, - output_hidden_states=False, - infer_kwargs: Optional[dict] = {}, - debug=True, - **_ignored, - ) -> PolicyOutput: - self.info_rank0( - f'[{self.model_type}] self.infer() kwargs: {infer_kwargs}') - - ########## - outputs = ray.get(object_refs, timeout=None) - padding_token_map = { - 'output_ids': 0, - } - trajectories = concat_policy_outputs(outputs, padding_token_map) - input_ids = trajectories.output_ids - attention_mask = trajectories.attention_mask - ########## - - input_ids = input_ids.to(self.device) - if attention_mask is not None: - attention_mask = attention_mask.to(self.device) - # returns entire-input-as-one-batch inference results - if micro_batch_size < 0: - self.info_rank0( - f'[{self.model_type}] infer() input_ids.shape: {input_ids.shape}' # noqa: E501 - ) - return self._infer( - input_ids, - attention_mask, - output_logprobs, - output_logits, - output_attentions, - output_hidden_states, - infer_kwargs, - ) - - # Otherwise, partition the input into micro batches and run inference on each micro batch separately # noqa: E501 - micro_batches = partition_by_micro_batch_size(input_ids, - micro_batch_size, - attention_mask) - policy_outputs = [] - for index, micro_batch in enumerate(micro_batches): - input_ids_mb = micro_batch['input_ids'] - attention_mask_mb = micro_batch['attention_mask'] - if index == 0: - self.info_rank0( - f'[{self.model_type}] will infer() input_ids_mb.shape: {input_ids_mb.shape} * {len(micro_batches)} times' # noqa: E501 - ) - policy_output_mb = self._infer( - input_ids_mb, - attention_mask_mb, - output_logprobs, - output_logits, - output_attentions, - output_hidden_states, - infer_kwargs, - ) - policy_outputs.append(policy_output_mb) - if debug: - self.set_seed(1234) - # Concatenate the policy outputs from each micro batch and return the result # noqa: E501 - return concat_policy_outputs(policy_outputs) - - # Generate @torch.no_grad() def _generate( @@ -921,26 +833,6 @@ def infer(self, *args, **kwargs): object_refs = self.infer_async(*args, **kwargs) return self.infer_get(object_refs) - # Inference - def infer_from_future(self, object_refs, *args, **kwargs): - # micro_batch_size = input_ids.shape[0] // self.dp_size + ( - # input_ids.shape[0] % self.dp_size > 0 - # ) # round up division, i.e., math.ceil(a / b) - # micro_batches = partition_by_micro_batch_size(input_ids, - # micro_batch_size, - # 
attention_mask) - - # assert len(micro_batches) == self.dp_size - return [ - self.ray_actors[0].infer_from_future.remote( - # input_ids=micro_batch['input_ids'], - # attention_mask=micro_batch['attention_mask'], - object_refs, - *args, - **kwargs, - ) - ] - # Generation def generate_async(self, input_ids, attention_mask, *args, **kwargs): micro_batch_size = input_ids.shape[0] // self.dp_size + ( From 4c47b7fdc197ba69b0248c149680f7ae1bbc3eeb Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Tue, 30 Jul 2024 15:47:23 +0800 Subject: [PATCH 34/37] rename test code --- scripts/train_1node.sh | 7 +------ xtuner/rlhf/{test_actor_background.py => pipeline.py} | 0 xtuner/rlhf/{ => tests}/test_actor.py | 0 xtuner/rlhf/{ => tests}/test_policy_ref_pipe.py | 0 xtuner/rlhf/{ => tests}/test_ref.py | 0 xtuner/rlhf/{ => tests}/test_vllm.py | 0 6 files changed, 1 insertion(+), 6 deletions(-) rename xtuner/rlhf/{test_actor_background.py => pipeline.py} (100%) rename xtuner/rlhf/{ => tests}/test_actor.py (100%) rename xtuner/rlhf/{ => tests}/test_policy_ref_pipe.py (100%) rename xtuner/rlhf/{ => tests}/test_ref.py (100%) rename xtuner/rlhf/{ => tests}/test_vllm.py (100%) diff --git a/scripts/train_1node.sh b/scripts/train_1node.sh index 3ec4db042..0afad0bd6 100644 --- a/scripts/train_1node.sh +++ b/scripts/train_1node.sh @@ -1,8 +1,6 @@ #!/bin/bash set -ex -export XTERM=linux - # export NCCL_DEBUG=INFO export NCCL_IB_TIMEOUT=22 export NCCL_IB_RETRY_CNT=13 @@ -23,7 +21,4 @@ if [ ! -f $config_file ]; then fi mkdir -p $work_dirs -#python xtuner/rlhf/main.py -c $config_file -w $work_dirs > $work_dirs/debug.log 2>&1 & -# python xtuner/rlhf/main.py -c $config_file -w $work_dirs 2>&1 | tee $work_dirs/main-$start_time.log -# python xtuner/rlhf/test_actor.py -c $config_file -w $work_dirs 2>&1 | tee $work_dirs/main-$start_time.log -python xtuner/rlhf/test_actor_background.py -c $config_file -w $work_dirs 2>&1 | tee $work_dirs/main-$start_time.log +python xtuner/rlhf/pipeline.py -c $config_file -w $work_dirs 2>&1 | tee $work_dirs/main-$start_time.log diff --git a/xtuner/rlhf/test_actor_background.py b/xtuner/rlhf/pipeline.py similarity index 100% rename from xtuner/rlhf/test_actor_background.py rename to xtuner/rlhf/pipeline.py diff --git a/xtuner/rlhf/test_actor.py b/xtuner/rlhf/tests/test_actor.py similarity index 100% rename from xtuner/rlhf/test_actor.py rename to xtuner/rlhf/tests/test_actor.py diff --git a/xtuner/rlhf/test_policy_ref_pipe.py b/xtuner/rlhf/tests/test_policy_ref_pipe.py similarity index 100% rename from xtuner/rlhf/test_policy_ref_pipe.py rename to xtuner/rlhf/tests/test_policy_ref_pipe.py diff --git a/xtuner/rlhf/test_ref.py b/xtuner/rlhf/tests/test_ref.py similarity index 100% rename from xtuner/rlhf/test_ref.py rename to xtuner/rlhf/tests/test_ref.py diff --git a/xtuner/rlhf/test_vllm.py b/xtuner/rlhf/tests/test_vllm.py similarity index 100% rename from xtuner/rlhf/test_vllm.py rename to xtuner/rlhf/tests/test_vllm.py From 2e45ef4e0697c81e5521d0418ed68278fcf10a9d Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Tue, 30 Jul 2024 16:17:25 +0800 Subject: [PATCH 35/37] remove align code --- xtuner/rlhf/model_backend/hf_model_runner.py | 6 +-- .../rlhf/model_backend/vllm_model_runner.py | 8 ++-- .../rlhf/model_server/reward_model_server.py | 40 +++++++++---------- 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/xtuner/rlhf/model_backend/hf_model_runner.py b/xtuner/rlhf/model_backend/hf_model_runner.py index 55095cb30..a53c624e5 100644 --- a/xtuner/rlhf/model_backend/hf_model_runner.py 
+++ b/xtuner/rlhf/model_backend/hf_model_runner.py @@ -247,7 +247,7 @@ def train( step_interval: int = 1, # None means using the entire input as one batch micro_batch_size: Optional[Union[list[int], int]] = None, - debug=True, + debug=False, update_param=True, **_ignored, ): @@ -377,7 +377,7 @@ def infer( output_attentions=False, output_hidden_states=False, infer_kwargs: Optional[dict] = {}, - debug=True, + debug=False, **_ignored, ) -> PolicyOutput: self.info_rank0( @@ -518,7 +518,7 @@ def generate( output_hidden_states=False, chat_template=None, generate_kwargs: Optional[dict] = {}, - debug=True, + debug=False, **_ignored, ) -> PolicyOutput: self.info_rank0( diff --git a/xtuner/rlhf/model_backend/vllm_model_runner.py b/xtuner/rlhf/model_backend/vllm_model_runner.py index a6369061d..f5ff19e65 100644 --- a/xtuner/rlhf/model_backend/vllm_model_runner.py +++ b/xtuner/rlhf/model_backend/vllm_model_runner.py @@ -259,8 +259,8 @@ def generate_background( else: raise ValueError(f'Unsupported inputs with type({type(inputs)})') - self.max_inputs_length = 1024 - # self.max_inputs_length = max_inputs_length + # self.max_inputs_length = 1024 + self.max_inputs_length = max_inputs_length self.output_str = output_str self.output_logits = output_logits self.output_attentions = output_attentions @@ -310,8 +310,8 @@ def pad_list_with_pad_token_right(int_list, max_length, pad_token_id): output_token_ids = [ item for item in req_output.outputs[0].token_ids ] - output_token_ids = pad_list_with_pad_token_right(output_token_ids, 1024, - self.tokenizer.pad_token_id) + # output_token_ids = pad_list_with_pad_token_right(output_token_ids, 1024, + # self.tokenizer.pad_token_id) output_ids = input_ids + output_token_ids # concat output['input_ids'] = torch.Tensor(input_ids).to( diff --git a/xtuner/rlhf/model_server/reward_model_server.py b/xtuner/rlhf/model_server/reward_model_server.py index f5c95a29c..8540eb82c 100644 --- a/xtuner/rlhf/model_server/reward_model_server.py +++ b/xtuner/rlhf/model_server/reward_model_server.py @@ -29,32 +29,32 @@ def init_tokenizer_and_config(self, model_config): # Inference def infer_async(self, inputs, attention_mask=None, *args, **infer_kwargs): if not isinstance(inputs, torch.Tensor): - # input_ids, attention_mask = encode_inputs(inputs, self.tokenizer) + input_ids, attention_mask = encode_inputs(inputs, self.tokenizer) # Debug - if isinstance(inputs[0], list): - inputs = [ - self.tokenizer.apply_chat_template( - input, - tokenize=False, - add_generation_prompt=False, - return_tensors='pt', - ) for input in inputs - ] - output = self.tokenizer( - inputs, - return_tensors='pt', - padding='max_length', - max_length=2048, - add_special_tokens=False) - input_ids, attention_mask = output.input_ids, output.attention_mask + # if isinstance(inputs[0], list): + # inputs = [ + # self.tokenizer.apply_chat_template( + # input, + # tokenize=False, + # add_generation_prompt=False, + # return_tensors='pt', + # ) for input in inputs + # ] + # output = self.tokenizer( + # inputs, + # return_tensors='pt', + # padding='max_length', + # max_length=2048, + # add_special_tokens=False) + # input_ids, attention_mask = output.input_ids, output.attention_mask else: input_ids = inputs # Reward model specific - # if self.reward_token_id is not None: - # input_ids, attention_mask = expand_reward_token_id( - # self.reward_token_id, input_ids, attention_mask) + if self.reward_token_id is not None: + input_ids, attention_mask = expand_reward_token_id( + self.reward_token_id, input_ids, attention_mask) return 
self.trainer.infer_async( input_ids=input_ids, From 23bb2f10e666fb4565be86b3c34db6943a6f66b1 Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Tue, 30 Jul 2024 17:43:53 +0800 Subject: [PATCH 36/37] update internlm2_20b config --- examples/rlhf/internlm2_20b_test_32gpu.py | 22 +++++++++++----------- scripts/train_ray.sh | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/rlhf/internlm2_20b_test_32gpu.py b/examples/rlhf/internlm2_20b_test_32gpu.py index 53d02d61f..9fd42f6ac 100644 --- a/examples/rlhf/internlm2_20b_test_32gpu.py +++ b/examples/rlhf/internlm2_20b_test_32gpu.py @@ -9,6 +9,10 @@ PROMPT_BATCH_SIZE = 128 PRETRAIN_BATCH_SIZE = 0 # 0 +PIPE_MICRO_BATCH_NUM = 4 +assert PROMPT_BATCH_SIZE % PIPE_MICRO_BATCH_NUM == 0 +PIPE_MICRO_BATCH_SIZE = PROMPT_BATCH_SIZE // PIPE_MICRO_BATCH_NUM #32 + GENERATE_MICRO_BATCH_SIZE = 8 INFER_MICRO_BATCH_SIZE = 2 TRAIN_MICRO_BATCH_SIZE = 1 @@ -44,10 +48,8 @@ rollout_config = dict( policy_micro_bs=GENERATE_MICRO_BATCH_SIZE, - reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, max_new_tokens=MAX_ANSWER_LEN, - # write_to_file=True, - write_to_file=False, ## Debug-Only + write_to_file=True, resume_step=RESUME_STEP, generate_kwargs={ 'do_sample': True, @@ -64,10 +66,8 @@ ) repeater_config = dict( - policy_micro_bs=INFER_MICRO_BATCH_SIZE, - critic_micro_bs=INFER_MICRO_BATCH_SIZE, ref_micro_bs=INFER_MICRO_BATCH_SIZE, - #ref_micro_bs=8, ## Optimize + reward_micro_bs=GENERATE_MICRO_BATCH_SIZE, kl_coeff=0.01, gamma=1.0, gae_lambda=0.99, @@ -77,8 +77,11 @@ ) train_config = dict( - policy_micro_bs=TRAIN_MICRO_BATCH_SIZE, - critic_micro_bs=TRAIN_MICRO_BATCH_SIZE, + pipe_micro_bs=PIPE_MICRO_BATCH_SIZE, + policy_train_micro_bs=TRAIN_MICRO_BATCH_SIZE, + critic_train_micro_bs=TRAIN_MICRO_BATCH_SIZE, + policy_infer_micro_bs=INFER_MICRO_BATCH_SIZE, + critic_infer_micro_bs=INFER_MICRO_BATCH_SIZE, ppo_loss_weight=1.0, pretrain_loss_weight=0.5, # critic_warmup_step=40, @@ -86,7 +89,6 @@ save_interval=200, max_train_step=800, resume_step=RESUME_STEP, - async_learn=True, ## Optimize ) model_configs = dict( @@ -199,7 +201,6 @@ deepspeed_config={ "zero_optimization": { "stage": 3, - #"stage": 0, ## Optimize "overlap_comm": True, "stage3_gather_16bit_weights_on_model_save": True }, @@ -232,7 +233,6 @@ deepspeed_config={ "zero_optimization": { "stage": 3, - #"stage": 0, ## Optimize "overlap_comm": True, "stage3_gather_16bit_weights_on_model_save": True }, diff --git a/scripts/train_ray.sh b/scripts/train_ray.sh index f64f612b3..61c5ab306 100644 --- a/scripts/train_ray.sh +++ b/scripts/train_ray.sh @@ -89,7 +89,7 @@ if [ "$node_role" == "master" ]; then done fi # python xtuner/rlhf/main.py -c $config_file -w $work_dirs 2>&1 | tee $work_dirs/main-$start_time.log - python -u xtuner/rlhf/test_actor_background.py -c $config_file -w $work_dirs > $work_dirs/main-$start_time.log 2>&1 & + python -u xtuner/rlhf/pipeline.py -c $config_file -w $work_dirs > $work_dirs/main-$start_time.log 2>&1 & else sleep infinity 2>&1 & fi From 784174cb351e65453579b0dad41e2dc8160aaa1e Mon Sep 17 00:00:00 2001 From: liangkaihuan Date: Tue, 30 Jul 2024 17:47:32 +0800 Subject: [PATCH 37/37] add pipeline readme --- README_pipeline.md | 100 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 README_pipeline.md diff --git a/README_pipeline.md b/README_pipeline.md new file mode 100644 index 000000000..955b763ba --- /dev/null +++ b/README_pipeline.md @@ -0,0 +1,100 @@ +## pipeline优化 +### 优化原理 + 
Each RLHF iteration can be divided into three stages: Generation, Forward, and Train. In the Generation stage, vLLM runs inference to generate responses; in the Forward stage, the four models (actor, critic, reference, and reward) run inference; in the Train stage, the actor and critic models are trained.
+
+While one stage is running, the GPUs of the other stages sit idle and wait, which wastes resources.
+
+To address this, the optimization borrows the idea of pipeline parallelism. The batch is split into several smaller micro-batches; as soon as a stage finishes one micro-batch, the data is passed on to the next stage instead of waiting for the whole batch to complete. This reduces the idle time of the GPUs in each stage and improves resource utilization.
+
+### How to run
+
+1) Add interfaces to vLLM
+- Get the vLLM installation path
+  ```shell
+  export vllm=$(pip show numpy | grep Location | awk '{print $2"/vllm"}')
+  ```
+
+- Edit $vllm/entrypoints/llm.py and add the following two interfaces to `class LLM`
+  ```python
+    def generate_to_queue(
+        self,
+        prompts: Optional[Union[str, List[str]]] = None,
+        sampling_params: Optional[SamplingParams] = None,
+        prompt_token_ids: Optional[List[List[int]]] = None,
+        prefix_pos: Optional[Union[int, List[int]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        queue=None,
+    ) -> List[RequestOutput]:
+        """Generates the completions for the input prompts and puts the
+        results into the queue.
+        """
+        if prompts is None and prompt_token_ids is None:
+            raise ValueError("Either prompts or prompt_token_ids must be "
+                             "provided.")
+        if isinstance(prompts, str):
+            # Convert a single prompt to a list.
+            prompts = [prompts]
+        if (prompts is not None and prompt_token_ids is not None
+                and len(prompts) != len(prompt_token_ids)):
+            raise ValueError("The lengths of prompts and prompt_token_ids "
+                             "must be the same.")
+        if sampling_params is None:
+            # Use default sampling params.
+            sampling_params = SamplingParams()
+
+        # Add requests to the engine.
+        num_requests = len(prompts) if prompts is not None else len(
+            prompt_token_ids)
+        for i in range(num_requests):
+            prompt = prompts[i] if prompts is not None else None
+            prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
+            token_ids = None if prompt_token_ids is None else prompt_token_ids[
+                i]
+            self._add_request(prompt,
+                              sampling_params,
+                              token_ids,
+                              lora_request=lora_request,
+                              prefix_pos=prefix_pos_i)
+        return self._run_engine_to_queue(use_tqdm, queue)
+
+
+    def _run_engine_to_queue(self, use_tqdm: bool, queue) -> List[RequestOutput]:
+        # Initialize tqdm.
+        if use_tqdm:
+            num_requests = self.llm_engine.get_num_unfinished_requests()
+            pbar = tqdm(total=num_requests, desc="Processed prompts")
+        # Run the engine.
+        outputs: List[RequestOutput] = []
+        while self.llm_engine.has_unfinished_requests():
+            step_outputs = self.llm_engine.step()
+            for output in step_outputs:
+                if output.finished:
+                    outputs.append(output)
+                    queue.put(output)
+                    if use_tqdm:
+                        pbar.update(1)
+        if use_tqdm:
+            pbar.close()
+        # Sort the outputs by request ID.
+        # This is necessary because some requests may be finished earlier than
+        # its previous requests.
+        outputs = sorted(outputs, key=lambda x: int(x.request_id))
+        return outputs
+  ```
+
+2) Start training
+```shell
+bash scripts/train_ray.sh ${config_file} ${work_dir} ${num_node}
+```
+
+### Configuration
+Refer to the config file examples/rlhf/internlm2_20b_test_32gpu.py
+```python
+...
+PIPE_MICRO_BATCH_NUM = 4  # adjust the number of micro-batches
+...
+```
+
+### Impact on accuracy
+When norm_rewards is enabled, results cannot be strictly aligned with the non-pipelined version, because norm_rewards normalizes the rewards. Before this optimization the normalization was computed over the whole batch; after it, the normalization is computed separately within each micro-batch.
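
To make the accuracy note above concrete, here is a minimal standalone sketch (not part of the patch) of how per-micro-batch reward normalization can diverge from whole-batch normalization. The actual `norm_rewards` path keeps running statistics in `RunningStates`, so this is only an illustration under simplified assumptions, and the reward values below are made up.

```python
import torch


def normalize(rewards: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Same kind of operation as norm_rewards: (x - mean) / (std + eps).
    return (rewards - rewards.mean()) / (rewards.std() + eps)


rewards = torch.tensor([0.5, 1.0, -0.5, 2.0, 0.0, -1.0, 1.5, 0.5])

# Before the pipeline optimization: statistics come from the whole batch.
whole_batch = normalize(rewards)

# After the optimization: each pipeline micro-batch is normalized with its
# own statistics (here split into 2 micro-batches for illustration).
per_micro = torch.cat([normalize(mb) for mb in rewards.chunk(2)])

print(torch.allclose(whole_batch, per_micro))  # False -> no strict alignment
```

Because each micro-batch sees different statistics (and, with running statistics, a different update order), bit-exact alignment with the unpipelined run is not expected.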