Welcome to the official repository for the work Multi-Agent Sampling: Scaling Inference Compute for Data Synthesis with Tree Search-Based Agentic Collaboration.
In this work, we tackle the challenge of synthesizing alignment data from multiple distinct language models, such as Llama3, Qwen2, Mistral, and others—an approach known as multi-agent sampling. To address this problem, we introduce TOA, a novel method leveraging Tree Search-based Orchestrated Agents.
Our approach integrates Monte Carlo Tree Search (MCTS) with a Reward Model to optimize collaboration among diverse language models, ensuring high-quality alignment data synthesis.
- [2024/12/22] The TOA paper is now available on arXiv.
- Key Features
- Supported Methods
- Supported LLMs and Reward Models
- Synthesized Alignment Data
- Quick Start
- Results
- Citation
This repository introduces TOA, a framework for multi-agent sampling to synthesize high-quality alignment data from diverse language models:
- Alignment Data Synthesis: Generates high-quality responses from multiple language models.
- Agent Collaboration: Coordinates diverse models for scalable and robust data synthesis.
- Monte Carlo Tree Search (MCTS): Optimizes response generation using MCTS with a reward model.
- 😊 Universal Model Compatibility: Fully compatible with any model that offers an OpenAI-like API (see the example after this list).
- 🎯 Reward Model Integration: Allows custom reward models to guide generation.
- 💰 Compute Efficient: Uses MCTS for efficient computation and response generation.
- 📣 Flexible Sampling Methods: Supports both single-agent and multi-agent sampling.
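As an illustration of the OpenAI-like API compatibility mentioned above, here is a minimal example of querying a locally hosted model. The base URL, API key, and model name are placeholders that mirror the example server configuration shown later in the Quick Start section; adjust them to your own setup.

```python
# Minimal example of querying any OpenAI-compatible endpoint (e.g., a local vLLM server).
# The base_url, api_key, and model name below are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="abc123")

response = client.chat.completions.create(
    model="llama-3.1-8b-instruct",
    messages=[{"role": "user", "content": "Give one tip for writing helpful responses."}],
    temperature=0.7,
    max_tokens=256,
)
print(response.choices[0].message.content)
```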
The table below summarizes the key methods supported in this repository, along with references to their respective papers and example code:
Method | Paper | Example Code |
---|---|---|
Random Sampling | Link | bash |
PRS | Link | bash |
Parallel Ensemble | Link | bash |
Sequential Refine | Link | bash |
MoA | Link | bash |
TOA (Ours) | Link | bash |
- Random Sampling and PRS are single-agent-based methods.
- TOA represents our novel approach, integrating Monte Carlo Tree Search (MCTS) to optimize multi-agent collaboration.
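For intuition, the sketch below illustrates the general idea of reward-guided tree search over multiple model agents: agents expand a shared tree of candidate responses, a reward model scores new nodes, and the scores are propagated back to decide where to expand next. This is only a minimal, illustrative sketch, not the exact TOA algorithm from the paper; every name in it (`Node`, `uct`, `tree_search`, the dummy agents and reward function) is a hypothetical placeholder.

```python
import math
import random

# Illustrative sketch of MCTS-style, reward-guided search over multiple agents.
# Not the actual TOA implementation; all components here are toy stand-ins.

class Node:
    def __init__(self, text, parent=None):
        self.text = text          # candidate response held at this node
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0.0          # accumulated reward

def uct(node, c=1.4):
    # Upper-confidence score: exploit high-reward nodes, explore rarely visited ones.
    if node.visits == 0:
        return float("inf")
    return node.value / node.visits + c * math.sqrt(math.log(node.parent.visits) / node.visits)

def select(root):
    node = root
    while node.children:
        node = max(node.children, key=uct)
    return node

def backpropagate(node, reward):
    while node is not None:
        node.visits += 1
        node.value += reward
        node = node.parent

def tree_search(prompt, agents, reward_fn, iterations=16):
    root = Node(text="")
    root.visits = 1
    for _ in range(iterations):
        node = select(root)                           # pick a promising node
        agent = random.choice(agents)                 # a real system would orchestrate agents more cleverly
        response = agent(prompt, context=node.text)   # agent proposes a (refined) response
        child = Node(text=response, parent=node)
        node.children.append(child)
        backpropagate(child, reward_fn(prompt, response))
    # Return the highest average-reward response found in the tree.
    best, stack = None, list(root.children)
    while stack:
        n = stack.pop()
        stack.extend(n.children)
        if best is None or n.value / n.visits > best.value / best.visits:
            best = n
    return best.text if best else ""

if __name__ == "__main__":
    # Dummy agents and reward model, just so the sketch runs end to end.
    agents = [lambda p, context="": f"answer-{random.randint(0, 9)}" for _ in range(3)]
    reward_fn = lambda prompt, response: random.random()
    print(tree_search("What is MCTS?", agents, reward_fn))
```

In the actual framework, the agents are the OpenAI-compatible LLM endpoints described in this README, and the rewards come from a reward model rather than a random stub.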
We have tested the following open-source language models (LLMs) and reward models in our framework:
The LLMs are grouped below for clarity and ease of reference:
The following reward models have been evaluated in our experiments:
Reward Models |
---|
Skywork-Reward-Llama-3.1-8B-v0.2 |
Skywork-Reward-Gemma-2-27B-v0.2 |
ArmoRM-Llama3-8B-v0.1 |
We synthesized alignment data using the following four language models and a reward model. The input prompts were sourced from Ultrafeedback. For each prompt:
- 160 responses were sampled.
- The best response with the highest reward was selected for SFT data.
- For DPO training, the 30th-ranked response was used as the rejected sample, and the best response was used as the chosen sample.
Models | Reward Model | SFT Data | DPO Data |
---|---|---|---|
Llama-3.1-8B-Instruct | ArmoRM-Llama3-8B-v0.1 | Rand-Qwen2-7B-Inst | TOA |
Qwen2-7B-Instruct | | Rand-Lla3.1-8B-Inst | |
Mistral-7B-Instruct-v0.2 | | PRS-Qwen2-7B-Inst | |
Yi-1.5-9B-Chat-16K | | PRS-Lla3.1-8B-Inst | |
| | | Par. Ensemble | |
| | | Seq. Refine | |
| | | MoA | |
| | | TOA | |
- SFT Data: The best of the 160 sampled responses is used for supervised fine-tuning.
- DPO Data: The 30th-ranked response serves as the rejected sample, while the top response is used as the chosen sample (see the sketch below).
- Explore the provided links for detailed datasets and models.
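As a concrete illustration of how such SFT and DPO examples can be built from rewarded samples, here is a small hedged sketch. It assumes each prompt comes with a list of (response, reward) samples; the field names are illustrative placeholders, not the repository's actual data schema.

```python
# Hedged sketch: turn N rewarded samples per prompt into SFT and DPO training examples.
# The dictionary fields ("prompt", "samples", ...) are illustrative placeholders.

def build_sft_and_dpo(example, rejected_rank=30):
    # Sort the sampled responses by reward, highest first.
    ranked = sorted(example["samples"], key=lambda s: s["reward"], reverse=True)

    chosen = ranked[0]                                       # best-of-N response -> SFT target / DPO chosen
    rejected = ranked[min(rejected_rank, len(ranked)) - 1]   # e.g., the 30th-ranked response -> DPO rejected

    sft_example = {"prompt": example["prompt"], "response": chosen["text"]}
    dpo_example = {
        "prompt": example["prompt"],
        "chosen": chosen["text"],
        "rejected": rejected["text"],
    }
    return sft_example, dpo_example

if __name__ == "__main__":
    # Toy example with 3 samples instead of 160.
    example = {
        "prompt": "Explain gravity to a child.",
        "samples": [
            {"text": "Gravity pulls things down.", "reward": 0.71},
            {"text": "Stuff falls.", "reward": 0.20},
            {"text": "Gravity is the gentle pull that keeps us on the ground.", "reward": 0.93},
        ],
    }
    sft, dpo = build_sft_and_dpo(example, rejected_rank=30)
    print(sft)
    print(dpo)
```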
Accordingly, we fine-tuned Llama-3.1-8B-Instruct on the synthesized data with either the SFT or the DPO training objective:
Model Name | Link |
---|---|
Rand-SFT | Rand-SFT |
PRS-SFT | PRS-SFT |
Par. Ensemble-SFT | Par. Ensemble-SFT |
Seq. Refine-SFT | Seq. Refine-SFT |
MoA-SFT | MoA-SFT |
TOA-SFT | TOA-SFT |
TOA-DPO | TOA-DPO |
Our code relies on vLLM for efficient and fast model inference, so please make sure vLLM is installed on your machine. Alternatively, other toolkits such as SGLang can also be used.
If you want to host language models locally, you can use the provided code to start local servers.
Navigate to the directory and run the following command:
cd bash/launch_large_models
python start_server.vllm.py path_to_config root_to_save GPU port gpu_utilize
- path_to_config: Path to the model configuration file (in JSON format). Example:
{
"policy_model": {
"llama-3.1-8b-instruct": {
"path_to_model": "",
"path_to_chat_template": "../chat_templates/llama-3.1-instruct.jinja",
"stop_tokens": "['<|eot_id|>']"
}
}
}
- root_to_save: Path to save the server configuration (in JSON format). Example:
{
"model_name": "llama-3.1-8b-instruct",
"config": {
"path_to_model": "",
"path_to_chat_template": "../chat_templates/llama-3.1-instruct.jinja",
"stop_tokens": "['<|eot_id|>']",
"api_key": "abc123",
"port": 8000,
"host": "localhost",
"GPU": "0",
"gpu_utilize": 0.9
}
}
- GPU: GPU IDs to use, e.g., "0", "0,1", or "0,1,2,3".
- port: Port number for the server, e.g., 8000, 8001, etc.
- gpu_utilize: Percentage of GPU memory to use, e.g., 0.9 for 90%.
- You can start servers for different models using this script.
- Ensure that all server configurations are saved in the same directory (specified by root_to_save).
By following these steps, you can run multiple local servers for hosting language models seamlessly.
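As an optional sanity check, a small helper like the hedged sketch below can read the server configurations saved under root_to_save and send one test request to each hosted model. The helper itself is hypothetical (not part of this repository); it only relies on the configuration fields shown in the example above.

```python
# Hypothetical helper: load every server config saved under root_to_save and
# send one test request per model via its OpenAI-compatible endpoint.
import glob
import json
import os

from openai import OpenAI

def load_server_configs(root_to_save):
    configs = []
    for path in glob.glob(os.path.join(root_to_save, "*.json")):
        with open(path) as f:
            configs.append(json.load(f))
    return configs

def ping_servers(root_to_save, prompt="Say hello in one sentence."):
    for cfg in load_server_configs(root_to_save):
        server = cfg["config"]
        client = OpenAI(
            base_url=f"http://{server['host']}:{server['port']}/v1",
            api_key=server["api_key"],
        )
        reply = client.chat.completions.create(
            model=cfg["model_name"],
            messages=[{"role": "user", "content": prompt}],
            max_tokens=64,
        )
        print(cfg["model_name"], "->", reply.choices[0].message.content)

if __name__ == "__main__":
    ping_servers("path/to/root_to_save")  # replace with your root_to_save directory
```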
A reward model is required to generate real-time rewards for the generated responses.
- Navigate to the model_configs directory:
cd model_configs
- Provide the configuration file in JSON format. An example configuration looks like this:
{
"reward_model": {
"name": "ArmoRM",
"path": "",
"GPU": "0"
}
}
If you need to use a personalized reward model, update the code in the following file:
- code/reward.py
Within this file, you must specify how the reward model will be used for reward calculation. Ensure the implementation aligns with your specific model’s requirements.
By following these steps, you can easily integrate and customize the reward model for your needs.
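For reference, a customized scoring function might look like the hedged sketch below. It assumes a sequence-classification-style reward model such as the Skywork reward models listed above; the get_reward name and interface are illustrative and must be adapted to what code/reward.py actually expects.

```python
# Hedged sketch of a reward-scoring function for a sequence-classification-style
# reward model. The get_reward interface is illustrative; adapt it to the hooks
# expected in code/reward.py.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_PATH = "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2"  # or a local path from your config

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto", num_labels=1
)
model.eval()

@torch.no_grad()
def get_reward(prompt: str, response: str) -> float:
    # Score a (prompt, response) pair: higher means better aligned.
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response},
    ]
    inputs = tokenizer.apply_chat_template(
        messages, tokenize=True, return_tensors="pt"
    ).to(model.device)
    return model(inputs).logits[0][0].item()
```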
You are now ready to start generating data! Follow the steps below:
- Navigate to the experiment directory:
cd bash
cd exp_alpaca_eval
- Run the provided script to start the synthesis process:
bash run_generate.api.mcts.pre_load.sh
What Happens Next
- After the generation is complete, you will obtain multiple responses for each input prompt.
- Each response is associated with a reward, which can be used for:
- Rejection Sampling: Filter out lower-quality responses.
- Best-of-N Sampling: Select the highest-quality response from the generated samples.
By following these steps, you can efficiently generate alignment data tailored to your requirements.
Figure panels: (a) results on AlpacaEval 2.0; (b) results on WMT'22; (c) scaling results on AlpacaEval; (d) effectiveness of synthetic alignment data.
We utilize a combination of five advanced language models to perform best-of-160 sampling:
- Llama-3.1-70B-Instruct
- Mistral-Large-Instruct-2407
- Qwen2-72B-Instruct
- Mixtral-8x22B-Instruct-v0.1
- WizardLM-2-8x22B
For the reward model, we use ArmoRM-Llama3-8B-v0.1.
The results are illustrated in Fig. (a).
cd bash/exp_alpaca_eval
TOA: bash run_generate.api.mcts.pre_load.sh
MoA: bash run_generate.api.moa.pre_load.sh
Seq. Refine: bash run_generate.api.ensemble_seq.pre_load.sh
Ensemble: bash run_generate.api.ensemble.pre_load.sh
PRS: bash run_generate.api.prs.pre_load.sh
We provide the generated responses here.
We also employ the aforementioned large language models for best-of-160 sampling, this time with KIWI as the reward model.
The results are presented in Fig. (b), where the evaluation metric is KIWI-XXL.
cd bash/exp_nmt
We provide the generated responses here.
We present the results of scaling inference compute in Fig. (c), demonstrating that our TOA approach is the most compute-efficient among the baselines.
- Left: Results are obtained using ArmoRM-Llama3-8B-v0.1 as the reward model for both generation and evaluation.
- Right: Results include an additional round of evaluation using GPT-4 to assess the best response with the highest Armo Reward.
We compare synthetic data generation across various baselines and fine-tune Llama-3.1-8B-Instruct. The comparison uses outputs generated by the four small models listed above (Llama-3.1-8B-Instruct, Qwen2-7B-Instruct, Mistral-7B-Instruct-v0.2, and Yi-1.5-9B-Chat-16K).
Outputs are generated using Ultrafeedback prompts, sampling 160 responses per prompt. The best response is retained using ArmoRM-Llama3-8B-v0.1.
As shown in Fig. (d), synthetic data generated by our method achieves superior results on AlpacaEval and Arena-Hard benchmarks. Post-DPO training, our approach establishes a new state-of-the-art (SOTA), outperforming both DPO and SimPO.
If you find this work useful, please cite it as:
@misc{ye2024multiagentsamplingscalinginference,
title={Multi-Agent Sampling: Scaling Inference Compute for Data Synthesis with Tree Search-Based Agentic Collaboration},
author={Hai Ye and Mingbao Lin and Hwee Tou Ng and Shuicheng Yan},
year={2024},
eprint={2412.17061},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2412.17061},
}