
feat(pooling): faster avg. with EmbeddingBags #146

Open

wants to merge 3 commits into main
Conversation

@fdschmidt93 commented Oct 2, 2024

This is a draft PR to speed up the implementation of averaging token embeddings. One thing to note is that EmbeddingBag has slight but negligible numerical differences compared to the current implementation. I'll finalize this once the most recent transformers version is supported.

Below is benchmark code for the various implementations. mean_embedding_bag2 corresponds to this PR. It may be slightly slower than mean_embedding_bag since indices are not precomputed; in a fair setup it will be (negligibly) faster, since padded tokens are skipped without overhead thanks to the flattened indices.

Results on a 4090 with a simulated batch of N=256, sequence lengths L sampled between 350 and 512, and a hidden dimension of 4,096.

<torch.utils.benchmark.utils.common.Measurement object at 0x75ece41b8f40>
mean_iter(hidden_states, attention_mask)
setup: from __main__ import mean_iter
  7.29 ms
  1 measurement, 1000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x75ece41b99f0>
mean(hidden_states, attention_mask)
setup: from __main__ import mean
  6.77 ms
  1 measurement, 1000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x75ece41b9900>
mean_embedding_bag(hidden_states, offsets, padding_offset)
setup: from __main__ import mean_embedding_bag
  2.09 ms
  1 measurement, 1000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x75ece41b8f40>
mean_embedding_bag2(hidden_states, attention_mask)
setup: from __main__ import mean_embedding_bag2
  2.09 ms
  1 measurement, 1000 runs , 1 thread
import torch
import torch.nn.functional as F
from typing import cast, Optional


def mean(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        attention_mask_ = attention_mask.clamp(min=0, max=1)
        return (hidden_states * attention_mask_[:, :, None]).sum(
            1
        ) / attention_mask_.sum(-1, keepdim=True)


def cls(hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    return hidden_states[:, 0, :]


def eos(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    padding_side: str = "right",
    *args,
    **kwargs,
) -> torch.Tensor:
    if padding_side == "right":
        N = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        eos_token_id = attention_mask.sum(1) - 1
        return hidden_states[N, eos_token_id, :]
    else:
        return hidden_states[:, -1, :]


def get_padding_offset(attention_mask: torch.Tensor) -> int:
    """
    If mask was flattened, give first offset of a padding token.
    If no padding token exists, return -1
    """
    try:
        return cast(int, (attention_mask.view(-1) == 0).nonzero()[0].item())
    except IndexError as _:
        return -1


def get_offsets(
    attention_mask: torch.Tensor, padding_offset: Optional[int] = None
) -> torch.Tensor:
    """
    [[1 1 1 0 0]
     [1 1 1 1 1]] becomes

    [[0 1 2 3 3]
     [5 6 7 8 9]]

    assuming padding_offset 3 was input.
    """
    N, L = attention_mask.shape
    offsets = torch.arange(N * L, device=attention_mask.device).view(N, L)
    if isinstance(padding_offset, int):
        offsets[~(attention_mask.bool())] = padding_offset
    return offsets


def mean_embedding_bag(
    hidden_states: torch.Tensor,
    offsets: torch.Tensor,
    padding_idx: int,
    *args,
    **kwargs,
):
    """
    Mean-pool via `F.embedding_bag` with precomputed per-row indices
    (see `get_offsets`). `F.embedding_bag` defaults to mode="mean", and
    entries equal to `padding_idx` are excluded from the reduction.
    """
    token_embeds = hidden_states.view(-1, hidden_states.shape[-1])
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        if padding_idx > -1:
            return F.embedding_bag(
                weight=token_embeds,
                input=offsets,
                padding_idx=padding_idx,
            )
        else:
            return F.embedding_bag(
                weight=token_embeds,
                input=offsets,
            )


def mean_embedding_bag2(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    *args,
    **kwargs,
):
    """
    Compute the mean of non-padded embeddings using `embedding_bag`,
    properly handling padding with offsets.
    """
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        # Flatten hidden_states to 2D: shape (batch_size * seq_len, embedding_dim)
        batch_size, seq_len, embed_dim = hidden_states.shape
        token_embeds = hidden_states.view(-1, embed_dim)

        # Find the indices of non-padded tokens in flattened hidden_states
        input_indices = attention_mask.view(-1).nonzero(as_tuple=False).squeeze()

        # Compute the offsets: for each sequence, where it starts in the flattened input
        non_padded_lengths = attention_mask.sum(
            dim=1
        )  # Count non-padded tokens per sequence
        offsets = torch.cat(
            [
                torch.tensor([0], device=hidden_states.device),
                non_padded_lengths.cumsum(dim=0)[:-1],
            ]
        )

        # Use embedding_bag with mode 'mean' and appropriate padding index
        return F.embedding_bag(
            input=input_indices,  # Indices of non-padded tokens in flattened form
            weight=token_embeds,  # The flattened hidden states as embedding matrix
            offsets=offsets,  # Offsets specifying start of each sequence
            mode="mean",  # Aggregation mode
        )


def mean_iter(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = []
        for hs, mask in zip(hidden_states, attention_mask):
            out.append(hs[: mask.sum(), :].mean(0))
        embeds_mean_iter = torch.vstack(out)
    return embeds_mean_iter


hidden_states = torch.randn(256, 512, 4096).to("cuda:0")
attention_mask = (
    torch.randint(350, 512, (256,))[:, None] >= torch.arange(512)[None]
).long()
attention_mask = attention_mask.to("cuda:0")
embeds_mean_vec = mean(hidden_states, attention_mask)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = []
    for hs, mask in zip(hidden_states, attention_mask):
        out.append(hs[: mask.sum(), :].mean(0))
    embeds_mean_iter = torch.vstack(out)
    padding_offset = get_padding_offset(attention_mask)
    offsets = get_offsets(attention_mask, padding_offset)
    embeds_mb1 = mean_embedding_bag(hidden_states, offsets, padding_offset)
    embeds_mb2 = mean_embedding_bag2(hidden_states, attention_mask)

print(torch.allclose(embeds_mean_vec, embeds_mean_iter))  # true
print(torch.allclose(embeds_mean_vec, embeds_mb1))  # false
print(torch.allclose(embeds_mean_vec, embeds_mb2))  # false
print(torch.allclose(embeds_mean_iter, embeds_mb1))  # false
print(torch.allclose(embeds_mean_iter, embeds_mb2))  # false
print(torch.allclose(embeds_mb1, embeds_mb2))  # true

from torch.utils.benchmark import Timer

# Example usage

t_iter = Timer(
    stmt="mean_iter(hidden_states, attention_mask)",
    setup="from __main__ import mean_iter",
    globals={"hidden_states": hidden_states, "attention_mask": attention_mask},
)
t_vec = Timer(
    stmt="mean(hidden_states, attention_mask)",
    setup="from __main__ import mean",
    globals={"hidden_states": hidden_states, "attention_mask": attention_mask},
)
t_emb2 = Timer(
    stmt="mean_embedding_bag2(hidden_states, attention_mask)",
    setup="from __main__ import mean_embedding_bag2",
    globals={"hidden_states": hidden_states, "attention_mask": attention_mask},
)
t_emb1 = Timer(
    stmt="mean_embedding_bag(hidden_states, offsets, padding_offset)",
    setup="from __main__ import mean_embedding_bag",
    globals={
        "hidden_states": hidden_states,
        "offsets": offsets,
        "padding_offset": padding_offset,
    },
)

print(t_iter.timeit(1000))
print(t_vec.timeit(1000))
# precomputes indices etc
print(t_emb1.timeit(1000))
# computes indices etc on the fly
print(t_emb2.timeit(1000))
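
For intuition, here is a minimal, self-contained toy example (not part of the PR or the benchmark above) showing that the F.embedding_bag trick used in mean_embedding_bag2 reproduces plain masked mean pooling while skipping padded positions; the tensor shapes are made up for illustration.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden = torch.randn(2, 4, 8)        # (batch, seq_len, dim), toy sizes
mask = torch.tensor([[1, 1, 1, 0],   # first sequence: 3 real tokens, 1 pad
                     [1, 1, 1, 1]])  # second sequence: 4 real tokens

flat = hidden.view(-1, hidden.shape[-1])               # (batch * seq_len, dim)
idx = mask.view(-1).nonzero(as_tuple=False).squeeze()  # flattened indices of real tokens
lengths = mask.sum(dim=1)                              # real tokens per sequence
offsets = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(0)[:-1]])

bagged = F.embedding_bag(input=idx, weight=flat, offsets=offsets, mode="mean")
masked = (hidden * mask[:, :, None]).sum(1) / mask.sum(1, keepdim=True)
print(torch.allclose(bagged, masked, atol=1e-6))  # True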

@vaibhavad (Collaborator)

Thanks a lot @fdschmidt93! This optimization will be incredibly useful.

I just pushed #147 to main, which supports the latest version of transformers. Can you merge/rebase main into this branch and make sure it works?

@fdschmidt93 (Author)

Note: This PR currently exposes EmbeddingBag usage as a separate pooling option, but as discussed offline, it will become the new, faster default.
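
For illustration, a rough sketch of how the pooling dispatch could look once embedding_bag becomes the default; the pool helper and its signature here are hypothetical and not the actual llm2vec API.

import torch
import torch.nn.functional as F

def pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor,
         mode: str = "embedding_bag") -> torch.Tensor:
    # Hypothetical dispatcher: "mean" is the current masked-average path,
    # "embedding_bag" is the faster F.embedding_bag path from this PR.
    if mode == "mean":
        mask = attention_mask.clamp(min=0, max=1)
        return (hidden_states * mask[:, :, None]).sum(1) / mask.sum(-1, keepdim=True)
    if mode == "embedding_bag":
        flat = hidden_states.view(-1, hidden_states.shape[-1])
        idx = attention_mask.view(-1).nonzero(as_tuple=False).squeeze()
        lengths = attention_mask.sum(dim=1)
        offsets = torch.cat(
            [torch.zeros(1, dtype=torch.long, device=hidden_states.device),
             lengths.cumsum(0)[:-1]]
        )
        return F.embedding_bag(input=idx, weight=flat, offsets=offsets, mode="mean")
    raise ValueError(f"Unknown pooling mode: {mode}")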

Verification below

import torch
import torch.nn.functional as F
from llm2vec import LLM2Vec

l2v = LLM2Vec.from_pretrained(
    "McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
    peft_model_name_or_path="McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-supervised",
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16,
)

texts = ["This is a test sentence.", "Another example for mean pooling."]

l2v.pooling_mode = "mean"
assert l2v.pooling_mode == "mean"
default_embedding = l2v.encode(texts)

l2v.pooling_mode = "embedding_bag"
assert l2v.pooling_mode == "embedding_bag"
bagged_embedding = l2v.encode(texts)

print(F.mse_loss(default_embedding, bagged_embedding)) # tensor(8.6357e-07)
print(F.cosine_similarity(default_embedding, bagged_embedding)) # tensor([1.0000, 1.0000])

There are only slight numerical differences, which don't matter as per the cosine similarity :) but a 4x speed-up 🚀

@fdschmidt93 (Author) commented Oct 3, 2024

Given the importance of the function, please briefly check out the first commit and run the function above to verify that you get the same output (barring minuscule GPU differences) @vaibhavad 😅 :)
