Simplify byte_pair_merge #17

Merged: 1 commit, Oct 17, 2024
src/corebpe.rs: 92 changes (36 additions, 56 deletions)
```diff
@@ -13,80 +13,60 @@ fn _byte_pair_merge(
     piece: &[u8],
 ) -> Vec<(usize, Rank)> {
     // This is a vector of (start, rank).
-    // The rank is of the byte pair starting at position start.
-    // The rank of the last item in the vector is not a valid value.
-    let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect();
+    // The rank is of the pair starting at position start.
+    let mut parts = Vec::with_capacity(piece.len() + 1);
+
+    // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
+    // the way we currently do, this is equivalent. An easy way to break this would be to decouple
+    // merge priority from token index or to prevent specific token merges.
+    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
+    for i in 0..piece.len() - 1 {
+        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
+        if rank < min_rank.0 {
+            min_rank = (rank, i);
+        }
+        parts.push((i, rank));
+    }
+    parts.push((piece.len() - 1, Rank::MAX));
+    parts.push((piece.len(), Rank::MAX));
 
     let get_rank = {
         #[inline(always)]
-        |parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| {
-            if (start_idx + skip + 2) < parts.len() {
-                ranks
-                    .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
-                    .copied()
+        |parts: &Vec<(usize, Rank)>, i: usize| {
+            if (i + 3) < parts.len() {
+                // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
+                // parts[i + 1], see comment in the main loop.
+                *ranks
+                    .get(&piece[parts[i].0..parts[i + 3].0])
+                    .unwrap_or(&Rank::MAX)
             } else {
-                None
+                Rank::MAX
             }
         }
     };
 
-    // We look up the ranks once in the beginning and iteratively update
-    // them during each merge, which reduces the number of rank lookups.
-    for i in 0..parts.len() - 2 {
-        match get_rank(&parts, i, 0) {
-            Some(rank) => {
-                // Rank::MAX is a sentinel value and cannot be a valid rank
-                debug_assert!(rank != Rank::MAX);
-                parts[i].1 = rank;
-            }
-            None => {
-                continue;
-            }
-        };
-    }
-
     // If you have n parts and m merges, this does O(mn) work.
     // We could do something with a heap and do O(m log n) work.
-    // It is important to consider that n is often small (<100), and as such
-    // the cache-locality benefits outweigh the algorithmic complexity downsides
-    // of the `parts` vector data structure above.
-
-    // Note that we hash bytes, not token pairs. As long as we train BPE the way we
-    // currently do, this is equivalent. An easy way to break this would be to decouple
-    // merge priority from token index or to prevent specific token merges.
-    loop {
-        if parts.len() == 1 {
-            break;
+    // n is often very small so considerations like cache-locality outweigh the algorithmic
+    // complexity downsides of the `parts` vector.
+    while min_rank.0 != Rank::MAX {
+        let i = min_rank.1;
+        // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
+        // `parts.remove(i + 1)` will thrash the cache.
+        if i > 0 {
+            parts[i - 1].1 = get_rank(&parts, i - 1);
         }
+        parts[i].1 = get_rank(&parts, i);
+        parts.remove(i + 1);
 
-        // Rank::MAX is a sentinel rank value allowing us to
-        // take the min more quickly
-        let mut min_rank: (Rank, usize) = (Rank::MAX, 0);
+        min_rank = (Rank::MAX, usize::MAX);
         for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
             if rank < min_rank.0 {
                 min_rank = (rank, i);
             }
         }
-
-        if min_rank.0 != Rank::MAX {
-            let i = min_rank.1;
-
-            // NOTE: We are about to remove parts[i + 1]. We do not do it
-            // yet because there are cache-locality benefits to updating
-            // parts[i] and parts[i-1] before removing, which could thrash
-            // the cache. Thus, we update the rank calculation by skipping over
-            // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
-            parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX);
-            if i > 0 {
-                parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX);
-            }
-
-            parts.remove(i + 1);
-        } else {
-            break;
-        }
     }
-
     parts
 }
```
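For readers who want to try the merged version in isolation, below is a minimal, self-contained sketch of the simplified algorithm with a toy rank table. The names `byte_pair_merge_sketch`, `toy_ranks`, and the `Rank = u32` alias are illustrative assumptions, not part of this PR, and the `ranks` parameter is assumed to be a `HashMap<Vec<u8>, Rank>` based on the byte-slice lookups in the diff.

```rust
use std::collections::HashMap;

// Assumption for illustration: the crate's Rank is an unsigned integer type.
type Rank = u32;

// Sketch of the simplified merge loop. Assumes piece.len() >= 2, since the
// initial loop indexes piece[i..i + 2].
fn byte_pair_merge_sketch(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
    // parts[j] = (start offset, rank of the pair starting at that offset).
    // The last two entries are sentinels so every real part has a following boundary.
    let mut parts: Vec<(usize, Rank)> = Vec::with_capacity(piece.len() + 1);
    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
    for i in 0..piece.len() - 1 {
        let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
        if rank < min_rank.0 {
            min_rank = (rank, i);
        }
        parts.push((i, rank));
    }
    parts.push((piece.len() - 1, Rank::MAX));
    parts.push((piece.len(), Rank::MAX));

    // Rank of the pair formed by parts[i] and the part *after* the one about to be
    // removed; the i + 3 accounts for parts[i + 1] still being present at call time.
    let get_rank = |parts: &Vec<(usize, Rank)>, i: usize| -> Rank {
        if i + 3 < parts.len() {
            *ranks
                .get(&piece[parts[i].0..parts[i + 3].0])
                .unwrap_or(&Rank::MAX)
        } else {
            Rank::MAX
        }
    };

    // Merge the lowest-ranked adjacent pair, then rescan for the next minimum.
    while min_rank.0 != Rank::MAX {
        let i = min_rank.1;
        if i > 0 {
            parts[i - 1].1 = get_rank(&parts, i - 1);
        }
        parts[i].1 = get_rank(&parts, i);
        parts.remove(i + 1);

        min_rank = (Rank::MAX, usize::MAX);
        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
            if rank < min_rank.0 {
                min_rank = (rank, i);
            }
        }
    }
    parts
}

fn main() {
    // Hypothetical toy ranks: "ab" merges first (rank 0), then "abc" (rank 1).
    let mut toy_ranks: HashMap<Vec<u8>, Rank> = HashMap::new();
    toy_ranks.insert(b"ab".to_vec(), 0);
    toy_ranks.insert(b"abc".to_vec(), 1);

    let parts = byte_pair_merge_sketch(&toy_ranks, b"abcd");
    let starts: Vec<usize> = parts.iter().map(|&(s, _)| s).collect();
    // Surviving boundaries are [0, 3, 4]: the piece splits into "abc" + "d".
    println!("{:?}", starts);
}
```

As the comments in the diff note, each merge rescans `parts` for the next minimum, so m merges over n parts cost O(mn) rather than the O(m log n) a heap would give; because n is usually very small, the cache-friendly scan over a contiguous vector is the better trade.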
