Skip to content

Commit

Permalink
Implement interval encoding by reusing states (incorrect\!)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikvanantwerpen committed Oct 2, 2024
1 parent b05fb36 commit 7d97391
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
5 changes: 3 additions & 2 deletions crates/bpe/src/byte_pair_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ fn find_token_by_bytes(
}
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub(crate) struct State {
state: u32,
pub(crate) state: u32,
pub(crate) last_token: u32,
pub(crate) count: u32,
}
Expand Down Expand Up @@ -422,6 +422,7 @@ impl BytePairEncoding {
}
}
}
unreachable!()
}

/// Counts the number tokens produced when encoding the text.
Expand Down
27 changes: 25 additions & 2 deletions crates/bpe/src/interval_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,31 @@ impl<'a> IntervalEncoding<'a> {
}

pub(crate) fn encode_interval(&self, states: &mut Vec<State>, range: Range<usize>) {
// TODO Use the precomputed tokens
self.bpe.encode_next_bytes(states, &self.text[range]);
assert!(range.start <= range.end && range.end <= self.text.len());
for pos in range.clone() {
if let (Some(last_state), Some(prev_state)) =
(states.last(), (pos > 0).then(|| &self.states[pos - 1]))
{
// If we have reached the same state and token, copy the remaining states as-is.
if last_state.state == prev_state.state
&& last_state.last_token == prev_state.last_token
{
for next_pos in pos..range.end {
let next_state = &self.states[next_pos];
let next_count =
states[states.len() - self.bpe.token_len(next_state.last_token)].count
+ 1;
states.push(State {
state: next_state.state,
last_token: next_state.last_token,
count: next_count,
});
}
return;
}
}
self.bpe.encode_next_byte(states, self.text[pos]);
}
}
}

Expand Down

0 comments on commit 7d97391

Please sign in to comment.