In this work, we introduce Weighted-Reward Preference Optimization (WRPO) for the implicit model fusion of heterogeneous open-source LLMs with diverse architectures and sizes, aiming to create a more capable and robust target LLM.
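As shown in the figure below, the WRPO objective introduces a fusion coefficient that balances the rewards of the preferred responses from the source LLMs (yws) and from the target LLM (ywt).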
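To make the weighted-reward idea concrete, here is a minimal sketch that assumes a DPO-style formulation. Each response's implicit reward is β·(log πθ(y|x) − log πref(y|x)), and the fusion coefficient α mixes the rewards of the two preferred responses (yws and ywt) before they are compared against the dispreferred response yl. The function names, the default β, and the linear α schedule are illustrative assumptions, not the repository's implementation. The paper gives the exact objective.

```python
import math


def implicit_reward(policy_logp: float, ref_logp: float, beta: float) -> float:
    """DPO-style implicit reward: beta * (log pi_theta(y|x) - log pi_ref(y|x))."""
    return beta * (policy_logp - ref_logp)


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def wrpo_loss(policy_logps: dict, ref_logps: dict,
              alpha: float, beta: float = 0.1) -> float:
    """Sketch of a weighted-reward preference loss for one (x, yws, ywt, yl) quadruple.

    policy_logps / ref_logps map "ws", "wt", "l" to sequence log-probabilities
    under the policy and the frozen reference model (assumed inputs).
    alpha weights the source response yws; (1 - alpha) weights ywt.
    """
    r_ws = implicit_reward(policy_logps["ws"], ref_logps["ws"], beta)
    r_wt = implicit_reward(policy_logps["wt"], ref_logps["wt"], beta)
    r_l = implicit_reward(policy_logps["l"], ref_logps["l"], beta)
    # Weighted preferred reward minus dispreferred reward.
    margin = alpha * r_ws + (1.0 - alpha) * r_wt - r_l
    return -log_sigmoid(margin)


def linear_alpha(step: int, total_steps: int, alpha_max: float = 1.0) -> float:
    """Hypothetical schedule: grow alpha linearly from 0 to alpha_max."""
    return alpha_max * min(step / max(total_steps, 1), 1.0)
```

With α = 0 the sketch reduces to a standard DPO loss on the pair (ywt, yl). Increasing α shifts weight toward the source response yws. When all log-probabilities are equal, the margin is 0 and the loss is log 2.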
In our experiments, we use LLaMA3-8B-Instruct as the target LLM. As for the source LLMs, we include ten advanced open-source models of varying architectures and sizes. We assess the performance of our models on three representative instruction-following benchmarks: MT-Bench, AlpacaEval-2, and Arena-Hard. Extensive experiments demonstrate that WRPO consistently outperforms existing knowledge fusion methods and various fine-tuning baselines.
This repository includes a requirements file specifying the Python package versions used in our experiments. We used Python 3.10 and CUDA 12.2 for this work.
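The following optional sanity check can help confirm that your environment matches ours. It assumes you want to verify that the interpreter is Python 3.10 and, if PyTorch is installed, see which CUDA version it was built against. That build version need not match the system CUDA 12.2 toolkit. The helper names are hypothetical and are not part of the repository.

```python
import sys
from typing import Optional, Tuple

EXPECTED_PYTHON = (3, 10)
EXPECTED_CUDA = "12.2"


def python_matches(expected: Tuple[int, int] = EXPECTED_PYTHON,
                   actual: Optional[Tuple[int, ...]] = None) -> bool:
    """Compare the major.minor Python version with the expected one."""
    actual = tuple(sys.version_info[:2]) if actual is None else tuple(actual[:2])
    return actual == tuple(expected)


def torch_cuda_version() -> Optional[str]:
    """Return the CUDA version PyTorch was compiled with, or None if unavailable."""
    try:
        import torch
    except ImportError:
        return None
    return torch.version.cuda  # None for CPU-only builds


if __name__ == "__main__":
    print("Python OK:", python_matches())
    print("PyTorch CUDA build:", torch_cuda_version(), "(expected", EXPECTED_CUDA + ")")
```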
We use one of the UltraFeedback subsets, provided in princeton-nlp/llama3-ultrafeedback-armorm, to construct our training dataset.
Our training dataset contains quadruples of (x, yws, ywt, yl), where yws is a response from the source LLMs, and ywt and yl are responses from the target LLM.
The target and source LLMs, along with their corresponding Hugging Face IDs, are listed below:
| Model | Hugging Face ID |
|---|---|
| Target (LLaMA-3-8B-Instruct) | meta-llama/Meta-Llama-3-8B-Instruct |
| Mistral-Large-Instruct-2407 | mistralai/Mistral-Large-Instruct-2407 |
| Gemma2-27B-IT | google/gemma-2-27b-it |
| Qwen2-72B-Instruct | Qwen/Qwen2-72B-Instruct |
| LLaMA3-70B-Instruct | meta-llama/Meta-Llama-3-70B-Instruct |
| Gemma2-9B-IT | google/gemma-2-9b-it |
| Internlm2.5-20B-Chat | internlm/internlm2_5-20b-chat |
| DeepSeek-V2-Chat | deepseek-ai/DeepSeek-V2-Chat-0628 |
| DeepSeek-Coder-V2-Instruct | deepseek-ai/DeepSeek-Coder-V2-Instruct-0724 |
| Yi-1.5-34B-Chat | 01-ai/Yi-1.5-34B-Chat |
| Phi-3-medium-4k-instruct | microsoft/Phi-3-medium-4k-instruct |
- For each prompt in the UltraFeedback dataset, we sample five responses from each source LLM using top-p sampling (p=0.95) with a temperature of 0.8, following the pipeline in the SimPO repository.
- We employ ArmoRM-LLaMA3-8B-v0.1 as the reward model to score and rank these responses. The highest-scoring response across all source models is selected as one of the preferred responses, yws.
- The dataset is split into two parts: one-third of the data instances are randomly sampled for supervised fine-tuning (SFT), while the remaining instances are used for preference optimization, as detailed in our paper.
- We apply SFT to the target LLM on the yws responses from the first part of the dataset.
- We then sample five responses from the SFT model for each prompt in the remaining part of the dataset. The highest-scoring response is labeled ywt, and the lowest-scoring response is regarded as yl.
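The decoding setting in the first step, top-p (nucleus) sampling with p=0.95 and temperature 0.8, is easiest to see on a single next-token distribution. This pure-Python function illustrates the technique only. The actual responses are generated with the SimPO pipeline, not with this code, and the function name is hypothetical.

```python
import math
import random
from typing import List, Optional


def top_p_sample(logits: List[float], top_p: float = 0.95,
                 temperature: float = 0.8,
                 rng: Optional[random.Random] = None) -> int:
    """Sample one token index with temperature scaling and nucleus (top-p) filtering."""
    rng = rng or random.Random()
    # 1) Temperature-scale the logits and apply a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # 2) Keep the smallest set of most-probable tokens whose mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    # 3) random.choices normalises the weights, i.e. renormalises over the nucleus.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
```

A full response repeats this draw token by token until an end-of-sequence token appears. Drawing five responses per prompt repeats the whole generation five times.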
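The random one-third/two-thirds split between the SFT and preference-optimization stages can be sketched as follows. This sketch assumes the data instances are held in a Python list. The seed, the rounding rule, and the function name are illustrative assumptions, not the exact settings used in our experiments.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_sft_pref(items: Sequence[T], sft_fraction: float = 1 / 3,
                   seed: int = 42) -> Tuple[List[T], List[T]]:
    """Randomly assign about `sft_fraction` of the items to SFT and the rest to preference optimization."""
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)  # fixed seed makes the split reproducible
    n_sft = round(len(items) * sft_fraction)
    sft = [items[i] for i in indices[:n_sft]]
    pref = [items[i] for i in indices[n_sft:]]
    return sft, pref
```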
Below is an example instance of our dataset, where "chosen" is a list containing [yws, ywt], and "rejected" is a list containing [yl].
{
"prompt_id": "3ebac2832721f4ef8e9ead1bb19e251e5d21d60dbd9f89cae931fe4aac900058",
"prompt": {
"content": "Given the task definition and input, reply with output. The provided file includes inquiries about restaurants in Finnish, and we ask you to translate those to English language. Please bear in mind the following guidelines while doing the translation: 1) We are looking for the most naturally written and formal form of each sentence in your language. We are *NOT* looking for colloquial forms of the sentence. We are looking for formal form which is how you would type your queries in a text-based virtual assistant. 2) The words between quotation marks *SHOULD NOT* be translated. We expect you to keep those values intact and include the quotation marks around them as well. 3) The fully capitalized words like DATE_0, or DURATION_0 *SHOULD NOT* be translated. Please keep them as they are in the translations. 4) Please do not localize measurement units like miles to kilometers during your translation. miles should be translated to its equivalent in your language. 6) Note the input is all lowercased except for fully capitalized special placeholders (e.g. NUMBER, DATE, TIME). Please do the same in your translations.\n\nkerro minulle \" dave 's seafood shack \" -ravintolan sijainti.\n",
"role": "user"
},
"chosen": [
{
"content": "Tell me the location of the \"dave's seafood shack\" restaurant.",
"role": "assistant"
},
{
"content": "Here is the translated output:\n\nTell me the location of \"Dave's Seafood Shack\".",
"role": "assistant"
}
],
"rejected": [
{
"content": "Here's the translation:\n\nTell me the location of \"Dave's Seafood Shack\".",
"role": "assistant"
}
]
}
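The selection and formatting steps can be combined into a single function that produces a record in this layout. This minimal sketch assumes each candidate response already carries its reward-model score. The names `Scored` and `build_record` are hypothetical. The output follows the example instance, with `chosen` ordered as [yws, ywt] and `rejected` holding [yl].

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Scored:
    text: str
    score: float  # reward-model score (assumed precomputed)


def build_record(prompt_id: str, prompt: str,
                 source_samples: List[Scored],
                 sft_samples: List[Scored]) -> Dict:
    """Turn scored samples into one (x, yws, ywt, yl) training record."""
    if not source_samples or len(sft_samples) < 2:
        raise ValueError("need source samples and at least two SFT samples")
    # yws: best response pooled across all source LLMs' samples.
    y_ws = max(source_samples, key=lambda s: s.score)
    # ywt / yl: highest- and lowest-scoring samples from the SFT model.
    ranked = sorted(sft_samples, key=lambda s: s.score)
    y_l, y_wt = ranked[0], ranked[-1]

    def msg(text: str, role: str) -> Dict[str, str]:
        return {"content": text, "role": role}

    return {
        "prompt_id": prompt_id,
        "prompt": msg(prompt, "user"),
        "chosen": [msg(y_ws.text, "assistant"), msg(y_wt.text, "assistant")],
        "rejected": [msg(y_l.text, "assistant")],
    }
```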
We provide configuration files for both training stages, designed for an environment with 8 A800 GPUs (80GB each). You may need to adjust `num_processes` and `per_device_train_batch_size` based on your specific computational setup.
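When adjusting these values, it usually helps to keep the effective (global) batch size fixed, because it interacts with the learning rate. This minimal sketch assumes the training YAML also sets `gradient_accumulation_steps`, a standard Hugging Face `TrainingArguments` field, and that the global batch size equals processes × per-device batch × accumulation steps. The numbers in the usage line are illustrative, not our configuration.

```python
def grad_accum_steps(target_global_batch: int, num_processes: int,
                     per_device_train_batch_size: int) -> int:
    """Accumulation steps needed to keep the global batch size unchanged."""
    per_step = num_processes * per_device_train_batch_size
    if target_global_batch % per_step != 0:
        raise ValueError(
            f"global batch {target_global_batch} is not divisible by "
            f"{num_processes} x {per_device_train_batch_size} = {per_step}")
    return target_global_batch // per_step


# Illustrative numbers only: preserve a global batch of 128 on 4 GPUs.
print(grad_accum_steps(128, num_processes=4, per_device_train_batch_size=2))  # 16
```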
- To run Target-SFT:
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file training_configs/deepspeed_zero3.yaml scripts/run_sft.py training_configs/llama-3-8b-instruct-sft.yaml
- To run Target-SFT-WRPO:
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file training_configs/deepspeed_zero3.yaml scripts/run_wrpo.py training_configs/llama-3-8b-instruct-sft-wrpo.yaml
@article{yang2024wrpo,
title={Weighted-Reward Preference Optimization for Implicit Model Fusion},
author={Ziyi Yang and Fanqi Wan and Longguang Zhong and Tianyuan Shi and Xiaojun Quan},
journal={arXiv preprint arXiv:2412.03187},
year={2024}
}