This repo draws from the excellently written HALOs and DPO repos, and we have preserved many of their original design choices.
This repo provides a general framework for aligning large language models (LLMs) with the Transformers and Datasets libraries from Huggingface. Unlike Huggingface's TRL framework, it incorporates the following features:
- Support for modifying the weights of training samples.
- Support for generating responses from the LLM policy.
- Support for getting feedback from an online reward model or language model (see the sketch below).
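For instance, online feedback can come from a publicly available reward model on the Huggingface hub. The snippet below is only a minimal sketch: the reward model name is one public example, and the prompt/response strings are placeholders rather than anything produced by this repo.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Any scalar-output reward model from the hub can be used; this one is just an example.
reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
tokenizer = AutoTokenizer.from_pretrained(reward_name)
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_name)

prompt = "Explain what a reward model is."
response = "A reward model assigns a score to a response for a given prompt."

# Score the (prompt, response) pair; a higher logit means a better response.
inputs = tokenizer(prompt, response, return_tensors="pt")
with torch.no_grad():
    reward = reward_model(**inputs).logits[0].item()
print(reward)
```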
A high-level diagram of the data flow is shown below:
In the following, we introduce the major components of the framework, though not in the order of the data flow.
The `BatchFactory` is a class that generates batches of data for training. It takes a `DatasLoader` object as input and yields train/test batches to the `Trainer` in the following format:
```python
{
    'prompt': str,                         # prompt for the generated texts
    'prompt_token_ids': tensor,            # token ids of the prompt
    'prompt_attention_mask': tensor,       # attention mask of the prompt
    'generation1': str,                    # text of prompt + 1st generation
    'generation1_response_only': str,      # text of the 1st generation only
    'generation1_token_ids': tensor,       # token ids of the 1st generation
    'generation1_attention_mask': tensor,  # attention mask of the 1st generation
    'generation1_reward': float,           # reward of the 1st generation
    'generation1_weight': float,           # weight of the 1st generation
    'generation2': str,                    # text of prompt + 2nd generation
    'generation2_response_only': str,      # text of the 2nd generation only
    'generation2_token_ids': tensor,       # token ids of the 2nd generation
    'generation2_attention_mask': tensor,  # attention mask of the 2nd generation
    'generation2_reward': float,           # reward of the 2nd generation
    'generation2_weight': float,           # weight of the 2nd generation
}
```
Note that the above items are not necessarily all included in the batch.
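To make the batch format concrete, the sketch below shows how a trainer might consume one such batch for weighted supervised fine-tuning. The `policy` model and the exact loss are placeholders; the actual `Trainer` in this repo may compute things differently.

```python
import torch
import torch.nn.functional as F

def weighted_sft_loss(policy, batch):
    # Forward the full prompt + generation sequence through the policy.
    logits = policy(
        input_ids=batch["generation1_token_ids"],
        attention_mask=batch["generation1_attention_mask"],
    ).logits
    # Next-token prediction loss over the sequence (in practice the prompt
    # tokens would typically be masked out of the labels as well).
    labels = batch["generation1_token_ids"][:, 1:]
    logits = logits[:, :-1]
    token_loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), labels.reshape(-1), reduction="none"
    ).view(labels.size())
    mask = batch["generation1_attention_mask"][:, 1:].float()
    per_example = (token_loss * mask).sum(dim=-1) / mask.sum(dim=-1)
    # Each sample's loss is scaled by its weight from the batch.
    weight = torch.as_tensor(batch["generation1_weight"])
    return (per_example * weight).mean()
```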
Below is a diagram of the data flow inside the `BatchFactory`. Note that the final output batches contain different items depending on the learning task.
Below we list the learning tasks, the batch items each one uses, and where those items come from (a compact code summary follows the list):
- Supervised fine-tuning: only `prompt` and `generation1` are in the batch. Moreover, `generation1_reward` is `None` and `generation1_weight` is always 1.0.
- Reward modelling in RLHF: `prompt`, `generation1`, and `generation2` are all included. However, `generation1_reward` and `generation2_reward` are both `None`. `generation1_weight` and `generation2_weight` are always 1.0.
- Reinforcement learning: only `prompt` and `generation1` are included, and `generation1_reward` is obtained from the online Reward Model.
- Offline pointwise preference learning: only `prompt` and `generation1` are included. `generation1_reward` is 1.0 if `generation1` is a desired response, otherwise 0.0 to indicate that `generation1` is undesired. `generation1_weight` is always 1.0. (Check out the HALOs repo for details on training models with pointwise desirable/undesirable feedback.) Both the generations and the annotations come from the pre-collected and fixed `DatasLoader`, so this is an OFFLINE learning setup.
- Online pointwise preference learning: same as the offline pointwise setup above, except that `generation1` is sampled from the LLM policy being trained and `generation1_reward` is obtained from the online Annotator. The generations come from the LLM policy being trained and the feedback comes from the online Annotator, so this is an ONLINE learning setup.
- Offline pairwise preference learning: `prompt`, `generation1`, and `generation2` are all included. `generation1_reward` is 1.0 and `generation2_reward` is 0.0 to indicate that `generation1` is preferred over `generation2`. `generation1_weight` and `generation2_weight` are always 1.0. Like the offline pointwise setup, the generations and annotations come from the pre-collected and fixed `DatasLoader`, so this is an OFFLINE learning setup.
- Online pairwise preference learning: same as the offline pairwise setup above, except that `generation1` and `generation2` are sampled from the LLM policy being trained. `generation1_reward` is 1.0 if `generation1` is preferred over `generation2` by the online Annotator, otherwise 0.0. The generations come from the LLM policy being trained and the feedback comes from the online Annotator, so this is an ONLINE learning setup.
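As a compact summary of the list above, the mapping below restates which items each task fills in and how rewards and weights are set. The task names here are illustrative and not necessarily the identifiers used in the code.

```python
# None means "not included in the batch" or "no reward available".
TASK_BATCH_ITEMS = {
    "sft":               {"generations": 1, "reward": None,                                        "weight": 1.0},
    "reward_modeling":   {"generations": 2, "reward": None,                                        "weight": 1.0},
    "rl":                {"generations": 1, "reward": "online reward model",                       "weight": 1.0},
    "offline_pointwise": {"generations": 1, "reward": "1.0 if desired else 0.0",                   "weight": 1.0},
    "online_pointwise":  {"generations": 1, "reward": "online annotator",                          "weight": 1.0},
    "offline_pairwise":  {"generations": 2, "reward": "1.0 for preferred, 0.0 for dispreferred",   "weight": 1.0},
    "online_pairwise":   {"generations": 2, "reward": "online annotator preference",               "weight": 1.0},
}
```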
The `DatasLoader` is a class that loads the original data from the Huggingface hub.
Note that there might be a DISTRIBUTION SHIFT problem between the responses in the `DatasLoader` and the responses generated by the LLM policy being trained.
To be more specific, suppose that the responses in the `DatasLoader` were generated by a language model different from the policy being trained; feedback collected on those responses may then not reflect how the policy's own generations would be judged.
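For instance, a preference dataset can be pulled from the hub with the `datasets` library. The dataset name and field names below are just one example and will differ per dataset.

```python
from datasets import load_dataset

# Anthropic's HH-RLHF is one commonly used preference dataset on the hub.
dataset = load_dataset("Anthropic/hh-rlhf", split="train")
example = dataset[0]
# Each example holds a preferred ("chosen") and a dispreferred ("rejected") dialogue.
print(example["chosen"][:200])
print(example["rejected"][:200])
```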
Difference between `DatasLoader` and `BatchFactory`: in short, the `DatasLoader` is a component of the `BatchFactory`.
The `DatasLoader` yields pre-collected and pre-annotated responses. The `BatchFactory` can either keep those responses and preferences as they are, or replace them with fresh generations from the LLM policy and feedback from the `Annotator`.
In the latter case, only the prompts from the `DatasLoader` are kept by the `BatchFactory`.
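As an illustration of the online path, the sketch below keeps only the prompts, samples a response from the policy, and asks an annotator for a reward. The `score_fn` callable and the decoding settings are placeholders for whatever Annotator and sampling configuration the `BatchFactory` actually uses.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def build_online_batch(prompts, policy_name, score_fn, max_new_tokens=128):
    """Sample one response per prompt from the policy and attach online rewards."""
    tokenizer = AutoTokenizer.from_pretrained(policy_name)
    policy = AutoModelForCausalLM.from_pretrained(policy_name)
    batch = []
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            output_ids = policy.generate(
                **inputs, do_sample=True, max_new_tokens=max_new_tokens
            )
        generation = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # The decoded text starts with the prompt; strip it to keep the response only.
        response_only = generation[len(prompt):]
        batch.append({
            "prompt": prompt,
            "generation1": generation,
            "generation1_response_only": response_only,
            "generation1_reward": score_fn(prompt, response_only),  # online feedback
            "generation1_weight": 1.0,
        })
    return batch
```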