Allow loss masking for defined spans of characters #113
base: main
Conversation
Looks good so far, but can you please add a short description and/or point to an issue?
```python
for start, end in char_spans:
    if char_pos < start:
        curr_text = text[char_pos:start]
        tokenized_text = self._tokenizer.tokenize(curr_text, add_special_tokens=beginning_of_text)
```
This works only for tokenizers that have a BOS but not an EOS token. For those that come with both, can we control independently whether `tokenize` adds the BOS and EOS tokens? I'm worried that we are adding the EOS token at the end of the first segment and the BOS token at the beginning of the last segment.
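For reference, a minimal sketch of what independent control could look like with a Hugging Face tokenizer, assuming `is_first_segment` / `is_last_segment` flags exist in the surrounding loop (this is not the PR's implementation):

```python
# Sketch only: skip add_special_tokens and place BOS/EOS explicitly per segment.
ids = tokenizer.encode(curr_text, add_special_tokens=False)
if is_first_segment and tokenizer.bos_token_id is not None:
    ids = [tokenizer.bos_token_id] + ids
if is_last_segment and tokenizer.eos_token_id is not None:
    ids = ids + [tokenizer.eos_token_id]
```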
Good catch! I'll make it explicitly add the BOS token only for the first segment.

Btw, most tokenizers (Llama-3.1, Mistral-Nemo-Base-2407, OLMoE-1B-7B-0924) do not add the EOS token with `add_special_tokens=True`. Does this mean we've been training the models without the EOS token?
In the future I think we should make this config-driven. The default behaviour would be to add both the BOS and EOS tokens. It's important for pretraining with an attention mask, and especially for SFT.
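A possible shape for that config, purely as a sketch (the field names are hypothetical, not Fast-LLM's actual config):

```python
from dataclasses import dataclass


@dataclass
class SpecialTokenConfig:
    # Hypothetical fields: defaults add both tokens, as suggested above.
    add_bos_token: bool = True
    add_eos_token: bool = True
```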
> Does this mean we've been training the models without the EOS token?
Indeed, we decided that adding both BOS and EOS tokens in pretraining was unnecessary because they are redundant. Here, though, I think we need them: we have to teach the model to terminate a response with the EOS token so that generation stops at the right moment. Btw, I think HF does not add the EOS token by default because otherwise prompts would end with it.
```python
exp_logits1 = exp_logits.scatter(
    1, target, exp_logits.gather(1, target) - target_mask * sum_exp_logits.unsqueeze(dim=-1)
)
exp_logits2 = exp_logits1.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1))
if logits_scale_factor != 1.0:
    exp_logits2 *= logits_scale_factor

grad = exp_logits2.to(logits.dtype)
grad.index_put_((mask,), exp_logits2.to(logits.dtype))

predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1)
all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
```
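For context, the hunk above appears to compute the usual fused cross-entropy backward, grad = (softmax(logits) - one_hot(target)) * grad_output / batch_size. A small self-contained check of that formula against autograd (a sketch for illustration, not part of this PR):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 8, requires_grad=True)
target = torch.randint(0, 8, (4,))

# Reference gradient from autograd (mean reduction divides by the batch size).
F.cross_entropy(logits, target).backward()

# Manual gradient: softmax minus one-hot target, scaled by 1 / batch_size.
probs = torch.softmax(logits.detach(), dim=-1)
manual_grad = (probs - F.one_hot(target, 8).float()) / logits.size(0)
assert torch.allclose(logits.grad, manual_grad, atol=1e-6)
```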
Does the Triton implementation support masking?
I think it doesn't: https://github.com/ServiceNow/Fast-LLM/blob/soham/loss-masking-spans/fast_llm/functional/triton/cross_entropy.py
We need to add it. Since this is the same for all loss functions, it would make sense to implement it before dispatching to specialized cross-entropy implementations:
```python
def cross_entropy_forward_backward(
    logits,
    target,
    grad_output: float | None,
    group: ProcessGroup | None,
    implementation: CrossEntropyImpl = CrossEntropyImpl.fused,
    logits_scale_factor: float = 1.0,
    ignore_index: int = -100,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    ...
    mask = target != ignore_index
    target = target[mask]
    logits = logits[mask]
    ...
```
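Filtering before the dispatch should match `torch.nn.functional.cross_entropy`'s `ignore_index` semantics (mean over non-ignored positions). A quick sanity check, as a sketch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 10)
target = torch.randint(0, 10, (6,))
target[2] = target[5] = -100  # positions excluded from the loss

# Reference: built-in ignore_index handling.
reference = F.cross_entropy(logits, target, ignore_index=-100)

# Same result when the ignored rows are filtered out up front.
mask = target != -100
filtered = F.cross_entropy(logits[mask], target[mask])
assert torch.allclose(reference, filtered)
```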
✨ Description
Support loss masking for spans specified in the input data. This PR ensures that the loss is not computed on the specified spans. The main use case is instruction-tuning data, where we want to avoid training on the prompts.
Closes #109
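To illustrate the intent (the field names below are assumptions for illustration, not this PR's actual data schema): each sample carries character spans whose tokens are excluded from the loss, typically the prompt of an instruction-tuning pair.

```python
# Hypothetical sample layout: character spans mark text excluded from the loss.
sample = {
    "text": "User: What is 2 + 2?\nAssistant: 4",
    "loss_masking_spans": [(0, 21)],  # mask the prompt; train only on the response
}
```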
📝 Changes
List the key changes introduced in this PR:
🔍 Type of change