Allow loss masking for defined spans of characters #113

Status: Draft · wants to merge 8 commits into main

Conversation

@sohamparikh (Member) commented Jan 14, 2025

✨ Description

Support loss masking for spans specified in the input data. This PR ensures that the loss is not computed on the specified spans. The biggest use case is instruction-tuning data, where we want to avoid training on the prompts.

Closes #109

📝 Changes

List the key changes introduced in this PR:

  • Support character spans as inputs to the prepare command
  • Read the spans during training and apply masks to the cross-entropy loss (see the illustrative sketch below)
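
As a purely illustrative sketch (not the PR's actual input schema; the field name and span convention are hypothetical), a document with spans to mask could conceptually look like this:

sample = {
    "text": "User: What is 2+2?\nAssistant: 4",
    # Hypothetical field: character range covering the prompt, which should not contribute to the loss.
    "loss_masking_spans": [(0, 19)],
}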

🔍 Type of change

  • 🚀 New feature (non-breaking change that adds functionality)

@jlamypoirier (Collaborator) commented:

Looks good so far, but can you please add a short description and/or point to an issue?

@sohamparikh changed the title from "convert character spans to token spans" to "Allow loss masking for defined spans of characters" on Jan 24, 2025
for start, end in char_spans:
    if char_pos < start:
        curr_text = text[char_pos:start]
        tokenized_text = self._tokenizer.tokenize(curr_text, add_special_tokens=beginning_of_text)
Collaborator:

This works only for tokenizers that have a BOS token but no EOS token.
For those that come with both, can we control independently whether tokenize adds the BOS and EOS tokens? I'm worried that we are adding the EOS token at the end of the first segment and the BOS token at the beginning of the last segment.

Member Author:

Good catch! I'll make it explicitly add BOS only for the first segment.
Btw, most tokenizers (Llama-3.1, Mistral-Nemo-Base-2407, OLMoE-1B-7B-0924) do not add the EOS token with add_special_tokens=True. Does this mean we've been training the models without the EOS token?

In the future I think we should make this config-driven, with the default behaviour being to add both BOS and EOS tokens. That's important for pretraining with an attention mask, and especially for SFT.
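
One possible way to decouple this, as a minimal sketch (the helper below is hypothetical and not code from this PR; it assumes a Hugging Face tokenizer instance): add BOS only before the first segment and EOS only after the last one.

def tokenize_segments(tokenizer, segments: list[str]) -> list[list[int]]:
    # Tokenize each segment without special tokens, then add BOS only to the first
    # segment and EOS only to the last one (when the tokenizer defines them).
    token_spans = []
    for i, segment in enumerate(segments):
        ids = tokenizer.encode(segment, add_special_tokens=False)
        if i == 0 and tokenizer.bos_token_id is not None:
            ids = [tokenizer.bos_token_id] + ids
        if i == len(segments) - 1 and tokenizer.eos_token_id is not None:
            ids = ids + [tokenizer.eos_token_id]
        token_spans.append(ids)
    return token_spans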

Collaborator:

Does this mean we've been training the models without the EOS token?

Indeed, we decided that adding both BOS and EOS tokens in pretraining was unnecessary because they are redundant. Here, though, I think we need to add them: we have to teach the model to terminate a response with the EOS token so that generation stops at the right moment. Btw, I think HF does not add the EOS token by default because otherwise prompts would end with it.

exp_logits1 = exp_logits.scatter(
    1, target, exp_logits.gather(1, target) - target_mask * sum_exp_logits.unsqueeze(dim=-1)
)
exp_logits2 = exp_logits1.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1))
if logits_scale_factor != 1.0:
    exp_logits2 *= logits_scale_factor

grad = exp_logits2.to(logits.dtype)
grad.index_put_((mask,), exp_logits2.to(logits.dtype))

predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1)
all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
@tscholak (Collaborator) commented Jan 24, 2025:

Does the Triton implementation support masking?

Collaborator:

I think it doesn't: https://github.com/ServiceNow/Fast-LLM/blob/soham/loss-masking-spans/fast_llm/functional/triton/cross_entropy.py
We need to add it. Since this masking is the same for all loss functions, it would make sense to apply it before dispatching to the specialized cross-entropy implementations:

def cross_entropy_forward_backward(
    logits,
    target,
    grad_output: float | None,
    group: ProcessGroup | None,
    implementation: CrossEntropyImpl = CrossEntropyImpl.fused,
    logits_scale_factor: float = 1.0,
    ignore_index: int = -100,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    ...
    mask = target != ignore_index
    target = target[mask]
    logits = logits[mask]
    ...
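
For reference, a minimal self-contained sketch of this pre-filtering idea (plain PyTorch, independent of Fast-LLM's fused/Triton kernels; only ignore_index is taken from the suggestion above):

import torch

def masked_cross_entropy(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # Drop masked positions before the softmax so the downstream implementation
    # only sees tokens that should contribute to the loss.
    mask = target != ignore_index
    return torch.nn.functional.cross_entropy(logits[mask], target[mask])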

Successfully merging this pull request may close the following issue:

[feat] Implement Loss Masking to Exclude Predefined Token Spans from LM Loss