Hi, great work! The results and research in this area are truly amazing. I have a question regarding the concatenated_forward part. From my understanding, we just need the logps for both the chosen and rejected responses. Why can't we have a batch that consists of [prompt + chosen_response + rejected_response] instead of [prompt + chosen_response, prompt + rejected_response]? It should be possible to calculate logps for both the chosen and rejected responses without them attending to each other, using an appropriate attention mask. Correct me if I'm wrong, thanks!
```python
def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.
    """
    # Chosen and rejected examples are stacked along the batch dimension (dim 0).
    concatenated_batch = concatenated_inputs(batch)
    all_logits = model(concatenated_batch['concatenated_input_ids'], attention_mask=concatenated_batch['concatenated_attention_mask']).logits.to(torch.float32)
    all_logps = _get_batch_logps(all_logits, concatenated_batch['concatenated_labels'], average_log_prob=False)
    # The first half of the rows are the chosen examples; the second half are the rejected ones.
    chosen_logps = all_logps[:batch['chosen_input_ids'].shape[0]]
    rejected_logps = all_logps[batch['chosen_input_ids'].shape[0]:]
    return chosen_logps, rejected_logps
```
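For reference, the batch-dimension stacking itself happens inside concatenated_inputs. Below is a minimal sketch of that helper, assuming the batch carries chosen_*/rejected_* tensor keys and a pad_to_length utility; it is meant to illustrate the mechanism, not to be the verbatim source:

```python
import torch

def pad_to_length(tensor: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    # Right-pad the sequence dimension so chosen and rejected share one length.
    if tensor.shape[1] >= length:
        return tensor
    pad = tensor.new_full((tensor.shape[0], length - tensor.shape[1]), pad_value)
    return torch.cat([tensor, pad], dim=1)

def concatenated_inputs(batch):
    # Pad chosen and rejected tensors to a common sequence length, then stack
    # them along dim 0 (the batch dimension), never dim 1 (the sequence dimension).
    max_length = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1])
    concatenated = {}
    for k, v in batch.items():
        if k.startswith('chosen') and isinstance(v, torch.Tensor):
            pad_value = -100 if 'labels' in k else 0  # -100 is the usual ignore index for labels
            concatenated[k.replace('chosen', 'concatenated')] = pad_to_length(v, max_length, pad_value)
    for k, v in batch.items():
        if k.startswith('rejected') and isinstance(v, torch.Tensor):
            pad_value = -100 if 'labels' in k else 0
            key = k.replace('rejected', 'concatenated')
            concatenated[key] = torch.cat([concatenated[key], pad_to_length(v, max_length, pad_value)], dim=0)
    return concatenated
```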
Also, how do you ensure that during the model forward pass the prompt + rejected sequence does not attend to the chosen response? Where in the code is this check made?
I think you are misunderstanding the implementation. The chosen and rejected examples are concatenated along the batch dimension in order to get the logps for both in one forward pass instead of two. They are not concatenated along the sequence dimension, so they will not attend to each other.
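To make that concrete, here is a tiny runnable illustration with hypothetical toy shapes: concatenating along dim 0 just adds rows to the batch, and since self-attention operates within each row, a chosen sequence can never attend to a rejected one.

```python
import torch

# Hypothetical toy shapes: B prompt pairs, sequence length L, vocab size V.
B, L, V = 2, 8, 50257
chosen_input_ids = torch.randint(0, V, (B, L))
rejected_input_ids = torch.randint(0, V, (B, L))

# Batch-dimension concatenation (dim 0), as in concatenated_forward above:
concatenated = torch.cat([chosen_input_ids, rejected_input_ids], dim=0)
print(concatenated.shape)  # torch.Size([4, 8]) -- 2*B independent rows

# Each row is a separate sequence; attention never crosses rows, so the
# slicing in concatenated_forward cleanly recovers both halves.
assert torch.equal(concatenated[:B], chosen_input_ids)
assert torch.equal(concatenated[B:], rejected_input_ids)
```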