
SLIT-AI/WRPO

 
 


Weighted-Reward Preference Optimization for Implicit Model Fusion


Overview

In this work, we introduce Weighted-Reward Preference Optimization (WRPO) for the implicit model fusion of heterogeneous open-source LLMs with diverse architectures and sizes, aiming to create a more capable and robust target LLM. As shown in the figure below, the WRPO objective introduces a fusion coefficient $\alpha$ that dynamically balances, during training, the internal reward of the preferred response $y_{w_s}$ from the source LLMs against that of the preferred response $y_{w_t}$ from the target LLM. This allows the target LLM to shift smoothly from its own distribution toward that of the source LLMs.
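A minimal Python sketch of how such a weighted reward could enter a DPO-style loss, assuming the implicit reward $r(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}$ and a sigmoid (Bradley-Terry) loss on the margin $\alpha\, r(x, y_{w_s}) + (1 - \alpha)\, r(x, y_{w_t}) - r(x, y_l)$, where $y_l$ is a dispreferred response. The exact objective and the schedule for $\alpha$ are defined in the paper; the function names, the per-example scalar log-probabilities, and the linear warm-up `linear_alpha` are illustrative assumptions, not the code in `scripts/run_wrpo.py`.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def implicit_reward(logp_policy: float, logp_ref: float, beta: float) -> float:
    """Assumed DPO-style reward: beta * log(pi_theta(y|x) / pi_ref(y|x)) from sequence log-probs."""
    return beta * (logp_policy - logp_ref)


def wrpo_loss(logps: dict, alpha: float, beta: float) -> float:
    """Weighted-reward preference loss for one quadruple (x, y_ws, y_wt, y_l).

    `logps` maps "ws", "wt" and "l" to (policy log-prob, reference log-prob)
    of the full response given the prompt x.
    """
    r_ws = implicit_reward(*logps["ws"], beta)
    r_wt = implicit_reward(*logps["wt"], beta)
    r_l = implicit_reward(*logps["l"], beta)
    # alpha moves the preferred-side reward from the target's y_wt toward the source's y_ws
    margin = alpha * r_ws + (1.0 - alpha) * r_wt - r_l
    return -log_sigmoid(margin)


def linear_alpha(step: int, total_steps: int, alpha_max: float) -> float:
    """Hypothetical schedule: alpha grows linearly from 0 to alpha_max, then stays there."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    return alpha_max * min(step / total_steps, 1.0)
```

With $\alpha = 0$ the loss reduces to ordinary preference optimization on the target's own pair $(y_{w_t}, y_l)$; raising $\alpha$ gradually hands the preferred side over to the source response.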


Overall Results

In our experiments, we use LLaMA3-8B-Instruct as the target LLM and ten advanced open-source models of varying architectures and sizes as the source LLMs. We evaluate our models on three representative instruction-following benchmarks: MT-Bench, AlpacaEval-2, and Arena-Hard. Extensive experiments show that WRPO consistently outperforms existing knowledge fusion methods and various fine-tuning baselines.


Requirements

This repository includes a requirements file specifying the Python package versions used in our experiments. We used Python 3.10 and CUDA 12.2.
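A small sketch for checking a local environment against the versions above. Note that `torch.version.cuda` reports the CUDA version PyTorch was built against, which is not necessarily the system toolkit or driver version; the helper names are hypothetical.

```python
import sys


def python_matches(expected=(3, 10)) -> bool:
    """True if the interpreter's major.minor version equals `expected` (3.10 was used in the experiments)."""
    return tuple(sys.version_info[:2]) == tuple(expected)


def torch_cuda_version():
    """CUDA version PyTorch was built with (e.g. "12.2"), or None if torch is missing or CPU-only."""
    try:
        import torch
    except ImportError:
        return None
    return torch.version.cuda


if __name__ == "__main__":
    print("Python 3.10:", python_matches())
    print("PyTorch CUDA build:", torch_cuda_version())
```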

Training Data Construction

We construct our training dataset from the UltraFeedback subset provided in princeton-nlp/llama3-ultrafeedback-armorm.

Each training instance is a quadruple $(x, y_{w_s}, y_{w_t}, y_l)$, where $x$ is the prompt, $y_{w_s}$ is a preferred response from the source LLMs, and $y_{w_t}$ and $y_l$ are a preferred and a dispreferred response from the target LLM.
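One way to hold such a quadruple in code is a small dataclass that can also be flattened into the chat-style layout where "chosen" is $[y_{w_s}, y_{w_t}]$ and "rejected" is $[y_l]$. The class, field, and helper names are hypothetical, and `prompt_id` is omitted for brevity.

```python
from dataclasses import dataclass


def _assistant_turn(text: str) -> dict:
    """Wrap a response as an assistant chat message."""
    return {"content": text, "role": "assistant"}


@dataclass(frozen=True)
class WRPOExample:
    """One training quadruple (x, y_ws, y_wt, y_l)."""

    prompt: str         # x
    source_chosen: str  # y_ws: best response sampled from the source LLMs
    target_chosen: str  # y_wt: best response sampled from the target LLM (after SFT)
    rejected: str       # y_l: worst response sampled from the target LLM (after SFT)

    def to_record(self) -> dict:
        """Order matters: chosen[0] is y_ws and chosen[1] is y_wt."""
        return {
            "prompt": {"content": self.prompt, "role": "user"},
            "chosen": [_assistant_turn(self.source_chosen), _assistant_turn(self.target_chosen)],
            "rejected": [_assistant_turn(self.rejected)],
        }
```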

Target & Source LLMs

The target and source LLMs, along with their corresponding Hugging Face IDs, are listed below:

| Model | Hugging Face ID |
|---|---|
| Target (LLaMA-3-8B-Instruct) | meta-llama/Meta-Llama-3-8B-Instruct |
| Mistral-Large-Instruct-2407 | mistralai/Mistral-Large-Instruct-2407 |
| Gemma2-27B-IT | google/gemma-2-27b-it |
| Qwen2-72B-Instruct | Qwen/Qwen2-72B-Instruct |
| LLaMA3-70B-Instruct | meta-llama/Meta-Llama-3-70B-Instruct |
| Gemma2-9B-IT | google/gemma-2-9b-it |
| Internlm2.5-20B-Chat | internlm/internlm2_5-20b-chat |
| DeepSeek-V2-Chat | deepseek-ai/DeepSeek-V2-Chat-0628 |
| DeepSeek-Coder-V2-Instruct | deepseek-ai/DeepSeek-Coder-V2-Instruct-0724 |
| Yi-1.5-34B-Chat | 01-ai/Yi-1.5-34B-Chat |
| Phi-3-medium-4k-instruct | microsoft/Phi-3-medium-4k-instruct |

Construction of $y_{w_s}$

  1. For each prompt in the UltraFeedback dataset, we sample five responses from each source LLM using top-p sampling (p = 0.95) with a temperature of 0.8, following the pipeline in the SimPO repository.

  2. We use ArmoRM-LLaMA3-8B-v0.1 as the reward model to score and rank these responses. The highest-scoring response across all source models is selected as one of the preferred responses, $y_{w_s}$.
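Both steps can be sketched in plain Python. `nucleus_probs` is a generic illustration of temperature scaling followed by top-p (nucleus) filtering for a single next-token distribution, not the exact implementation used in the SimPO pipeline or any inference library; `select_y_ws` assumes the reward-model scores have already been computed and is a hypothetical helper.

```python
import math

# Sampling settings stated above: five samples per prompt and source model, top-p 0.95, temperature 0.8.
NUM_SAMPLES, TOP_P, TEMPERATURE = 5, 0.95, 0.8


def nucleus_probs(logits, top_p=TOP_P, temperature=TEMPERATURE):
    """Temperature-scaled next-token distribution restricted to the top-p nucleus, renormalized."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]  # subtract max for numerical stability
    z = sum(exps)
    probs = [e / z for e in exps]
    # keep the smallest set of most-probable tokens whose total mass reaches top_p
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


def select_y_ws(scored):
    """Pick y_ws: the highest-scoring response across all source models.

    `scored` maps a source model name to a list of (response, reward_score) pairs.
    Returns (model_name, response, score).
    """
    return max(
        ((model, resp, score) for model, pairs in scored.items() for resp, score in pairs),
        key=lambda t: t[2],
    )
```

Selecting across all source models, rather than per model, means a single $y_{w_s}$ per prompt, drawn from whichever source produced the best-scored sample.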

Construction of $y_{w_t}$ and $y_l$

  1. The dataset is split into two parts: one third of the instances is randomly sampled for supervised fine-tuning (SFT), and the remaining two thirds are used for preference optimization, as detailed in our paper.

  2. We apply SFT to the target LLM on the $y_{w_s}$ responses of the first part.

  3. We then sample five responses from the SFT model for each prompt in the remaining part. The highest-scoring response is labeled $y_{w_t}$, and the lowest-scoring response is labeled $y_l$.
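The split and the labeling of $y_{w_t}$ and $y_l$ can be sketched as follows. The seed, the rounding of the split point, and the function names are hypothetical; scores are assumed to be precomputed by a reward model.

```python
import random


def split_for_sft(examples, sft_fraction=1 / 3, seed=42):
    """Randomly split examples into an SFT part (about one third) and a preference part (the rest)."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)  # hypothetical fixed seed for reproducibility
    cut = round(len(shuffled) * sft_fraction)
    return shuffled[:cut], shuffled[cut:]


def pick_target_pair(scored_samples):
    """From (response, score) samples of the SFT model, return (y_wt, y_l).

    y_wt is the highest-scoring sample and y_l the lowest-scoring one.
    """
    ranked = sorted(scored_samples, key=lambda pair: pair[1])
    return ranked[-1][0], ranked[0][0]
```

Taking both responses from the SFT model keeps $y_{w_t}$ and $y_l$ close to the target's own distribution, while $y_{w_s}$ carries the source signal.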

Below is an example instance of our dataset, where "chosen" is the list [$y_{w_s}$, $y_{w_t}$] and "rejected" is the list [$y_l$].

{
    "prompt_id": "3ebac2832721f4ef8e9ead1bb19e251e5d21d60dbd9f89cae931fe4aac900058",
    "prompt": {
        "content": "Given the task definition and input, reply with output. The provided file includes inquiries about restaurants in Finnish, and we ask you to translate those to English language. Please bear in mind the following guidelines while doing the translation: 1) We are looking for the most naturally written and formal form of each sentence in your language. We are *NOT* looking for colloquial forms of the sentence. We are looking for formal form which is how you would type your queries in a text-based virtual assistant. 2) The words between quotation marks *SHOULD NOT* be translated. We expect you to keep those values intact and include the quotation marks around them as well. 3) The fully capitalized words like DATE_0, or DURATION_0 *SHOULD NOT* be translated. Please keep them as they are in the translations. 4) Please do not localize measurement units like miles to kilometers during your translation. miles should be translated to its equivalent in your language. 6) Note the input is all lowercased except for fully capitalized special placeholders (e.g. NUMBER, DATE, TIME). Please do the same in your translations.\n\nkerro minulle \" dave 's seafood shack \" -ravintolan sijainti.\n",
        "role": "user"
    },
    "chosen": [
        {
            "content": "Tell me the location of the \"dave's seafood shack\" restaurant.",
            "role": "assistant"
        },
        {
            "content": "Here is the translated output:\n\nTell me the location of \"Dave's Seafood Shack\".",
            "role": "assistant"
        }
    ],
    "rejected": [
        {
            "content": "Here's the translation:\n\nTell me the location of \"Dave's Seafood Shack\".",
            "role": "assistant"
        }
    ]
}
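A record in this format can be unpacked back into the quadruple $(x, y_{w_s}, y_{w_t}, y_l)$ as sketched below; the function name is hypothetical, and the sketch relies only on the field layout shown in the example.

```python
import json


def unpack_record(record: dict):
    """Return (x, y_ws, y_wt, y_l) from a record where chosen = [y_ws, y_wt] and rejected = [y_l]."""
    chosen, rejected = record["chosen"], record["rejected"]
    if len(chosen) != 2 or len(rejected) != 1:
        raise ValueError("expected 2 chosen responses and 1 rejected response")
    y_ws, y_wt = (turn["content"] for turn in chosen)  # order: source first, then target
    return record["prompt"]["content"], y_ws, y_wt, rejected[0]["content"]


if __name__ == "__main__":
    line = (
        '{"prompt": {"content": "Q", "role": "user"}, '
        '"chosen": [{"content": "A1", "role": "assistant"}, {"content": "A2", "role": "assistant"}], '
        '"rejected": [{"content": "B", "role": "assistant"}]}'
    )
    print(unpack_record(json.loads(line)))  # prints ('Q', 'A1', 'A2', 'B')
```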

Training configurations

We provide configuration files for both training stages, designed for a machine with eight 80 GB A800 GPUs. You may need to adjust num_processes and per_device_train_batch_size to suit your hardware.
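When changing num_processes or per_device_train_batch_size, a common practice is to keep the global batch size fixed by compensating with gradient accumulation. The sketch below assumes the training configs expose the standard Hugging Face `gradient_accumulation_steps` setting; the numbers in the test are hypothetical and not taken from the provided configs.

```python
def global_batch_size(num_processes: int, per_device_train_batch_size: int,
                      gradient_accumulation_steps: int = 1) -> int:
    """Number of training examples per optimizer step across all GPUs."""
    return num_processes * per_device_train_batch_size * gradient_accumulation_steps


def accumulation_for(target_global: int, num_processes: int,
                     per_device_train_batch_size: int) -> int:
    """Gradient-accumulation steps needed to reach `target_global` on a different GPU setup."""
    per_step = num_processes * per_device_train_batch_size
    if target_global % per_step:
        raise ValueError(f"{target_global} is not divisible by {per_step}")
    return target_global // per_step
```

For example, a hypothetical setup of 8 GPUs × 2 examples × 8 accumulation steps gives 128 examples per step; moving to 4 GPUs with the same per-device batch size needs 16 accumulation steps to match.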

Commands

  • To run Target-SFT:
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file training_configs/deepspeed_zero3.yaml scripts/run_sft.py training_configs/llama-3-8b-instruct-sft.yaml
  • To run Target-SFT-WRPO:
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file training_configs/deepspeed_zero3.yaml scripts/run_wrpo.py training_configs/llama-3-8b-instruct-sft-wrpo.yaml

Citation

@article{yang2024wrpo,
  title={Weighted-Reward Preference Optimization for Implicit Model Fusion},
  author={Ziyi Yang and Fanqi Wan and Longguang Zhong and Tianyuan Shi and Xiaojun Quan},
  journal={arXiv preprint arXiv:2412.03187},
  year={2024}
}
