
Computing faster logps #72

Open
alexvishnevskiy opened this issue Mar 9, 2024 · 3 comments

@alexvishnevskiy

Hi, great work! The results and research in this area are truly amazing. I have a question about the concatenated_forward part. From my understanding, we just need the logps for both the chosen and rejected responses. Why can't we have a batch that consists of [prompt + chosen_response + rejected_response] instead of [prompt + chosen_response, prompt + rejected_response]? With an appropriate attention mask, it should be possible to calculate logps for both the chosen and rejected responses without them attending to each other. Correct me if I'm wrong, thanks!

def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.
    """
    concatenated_batch = concatenated_inputs(batch)
    # One forward pass over the batch-dimension concatenation of chosen and rejected inputs.
    all_logits = model(concatenated_batch['concatenated_input_ids'],
                       attention_mask=concatenated_batch['concatenated_attention_mask']).logits.to(torch.float32)
    all_logps = _get_batch_logps(all_logits, concatenated_batch['concatenated_labels'], average_log_prob=False)
    # Split along the batch dimension: the first B rows (B = number of chosen
    # sequences) are the chosen logps, the remaining rows are the rejected logps.
    chosen_logps = all_logps[:batch['chosen_input_ids'].shape[0]]
    rejected_logps = all_logps[batch['chosen_input_ids'].shape[0]:]
    return chosen_logps, rejected_logps
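
For context, here is a minimal sketch of what concatenated_inputs appears to do (the function name, padding values, and batch keys here are assumptions for illustration; the exact helper in the repo may differ):

import torch
import torch.nn.functional as F
from typing import Dict

def concatenated_inputs_sketch(batch: Dict[str, torch.LongTensor]) -> Dict[str, torch.LongTensor]:
    """Pad chosen and rejected tensors to a common sequence length, then
    stack them along the batch dimension (dim=0), not the sequence dimension."""
    max_len = max(batch['chosen_input_ids'].shape[1],
                  batch['rejected_input_ids'].shape[1])
    out = {}
    for key in ('input_ids', 'attention_mask', 'labels'):
        pad_value = -100 if key == 'labels' else 0  # -100 is ignored by the loss
        chosen = F.pad(batch[f'chosen_{key}'],
                       (0, max_len - batch[f'chosen_{key}'].shape[1]), value=pad_value)
        rejected = F.pad(batch[f'rejected_{key}'],
                         (0, max_len - batch[f'rejected_{key}'].shape[1]), value=pad_value)
        # dim=0: each sequence stays its own row; nothing is appended in time.
        out[f'concatenated_{key}'] = torch.cat([chosen, rejected], dim=0)
    return out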
@bhavyashahh

Also, how do you ensure that during the model forward step the prompt + rejected tokens do not attend to the chosen response? Where in the code is this check made?

@cthorrez

I think you are misunderstanding the implementation. They are concatenated in the batch dimension in order to get the logps for both in one forward pass instead of two. They are not concatenated in the sequence dimension so they will not attend to each other.
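
To make the difference concrete, here is a toy example of the two ways of concatenating (token ids are made up):

import torch

# Two (batch=1) sequences: [prompt + chosen] and [prompt + rejected].
chosen   = torch.tensor([[101, 7, 8, 9]])   # shape (1, 4)
rejected = torch.tensor([[101, 7, 5, 6]])   # shape (1, 4)

batch_concat = torch.cat([chosen, rejected], dim=0)
print(batch_concat.shape)  # torch.Size([2, 4]) -- two independent rows

seq_concat = torch.cat([chosen, rejected], dim=1)
print(seq_concat.shape)    # torch.Size([1, 8]) -- one row; tokens would attend across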

@bhavyashahh

Yes, I did not read carefully that the concatenation is on dim=0. Thank you for pointing that out.
