Commit

Introduce token processor

torymur committed Nov 7, 2024
1 parent 1446c76 commit c3b4430
Showing 4 changed files with 344 additions and 1 deletion.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -7,6 +7,7 @@ license = "Apache-2.0"
repository = "https://github.com/dottxt-ai/outlines-core"

[dependencies]
once_cell = "1.20"
anyhow = "1.0.86"
thiserror = "1.0"
pyo3 = { version = "0.22.0", features = ["extension-module"], optional = true }
26 changes: 26 additions & 0 deletions src/lib.rs
@@ -6,6 +6,7 @@ pub mod regex;
pub mod vocabulary;

mod locator;
mod processor;

#[cfg(feature = "python-bindings")]
mod python_bindings;
@@ -18,6 +19,31 @@ pub enum Error {
IndexError,
}

#[derive(Error, Debug)]
pub enum VocabularyError {
#[error("Unable to create tokenizer for {model}, source {source}")]
UnableToCreateTokenizer {
model: String,
source: tokenizers::Error,
},
#[error("Unable to locate EOS token for {model}")]
UnableToLocateEosTokenId { model: String },
#[error("Unable to process token")]
TokenProcessorError(#[from] TokenProcessorError),
}

#[derive(Error, Debug)]
pub enum TokenProcessorError {
#[error("Tokenizer is not supported")]
UnsupportedTokenizer,
#[error("Decoder unpacking failed")]
DecoderUnpackingFailed,
#[error("Token processing failed for byte level processor")]
ByteProcessorFailed,
#[error("Token processing failed for byte fallback level processor")]
ByteFallbackProcessorFailed,
}

#[cfg(feature = "python-bindings")]
impl From<Error> for pyo3::PyErr {
fn from(e: Error) -> Self {
3 changes: 2 additions & 1 deletion src/locator.rs
@@ -72,7 +72,7 @@ struct Content {
content: String,
}

/// Kind of the json field which will be checked for eos token id.
/// Which part of the config's json to check for the eos token id.
enum EosTokenField {
Id,
Value,
@@ -88,6 +88,7 @@ struct EosTokenLocation {
pub(crate) struct EosTokenLocator;

impl EosTokenLocator {
/// Locates the eos token id by searching the common, predefined locations.
pub(crate) fn locate(
model: &str,
tokenizer: &Tokenizer,
315 changes: 315 additions & 0 deletions src/processor.rs
@@ -0,0 +1,315 @@
use std::collections::HashMap;

use once_cell::sync::Lazy;
use serde::Deserialize;
use tokenizers::normalizers::Replace;
use tokenizers::{DecoderWrapper, Tokenizer};

use crate::TokenProcessorError;

pub type Result<T, E = TokenProcessorError> = std::result::Result<T, E>;

/// GPT2-like tokenizers have multibyte tokens that can contain a mix of full and incomplete
/// UTF-8 characters. For example, b` \xf0` can be one token. These tokenizers map each
/// byte to a valid UTF-8 character, and we need to map those characters back into bytes.
///
/// "ĠO" = [U+0120, U+004F] should be interpreted as [0x20, 0x4F] = " O"
/// or
/// "Ġal" = [U+0120, U+0061, U+006C] should be interpreted as [0x20, 0x61, 0x6C] = " al"
///
/// We'll use the following mapping for this transition:
/// ---
/// 'Ā' == '\u{0100}' -> 0x00 == 0
/// 'ā' == '\u{0101}' -> 0x01 == 1
/// 'Ă' == '\u{0102}' -> 0x02 == 2
/// ...
/// 'Ğ' == '\u{011E}' -> 0x1E == 30
/// 'ğ' == '\u{011F}' -> 0x1F == 31
/// 'Ġ' == '\u{0120}' -> 0x20 == 32
/// ---
/// '!' == '\u{0021}' -> 0x21 == 33
/// '"' == '\u{0022}' -> 0x22 == 34
/// '#' == '\u{0023}' -> 0x23 == 35
/// ...
/// '|' == '\u{007C}' -> 0x7C == 124
/// '}' == '\u{007D}' -> 0x7D == 125
/// '~' == '\u{007E}' -> 0x7E == 126
/// ---
/// 'ġ' == '\u{0121}' -> 0x7F == 127
/// 'Ģ' == '\u{0122}' -> 0x80 == 128
/// 'ģ' == '\u{0123}' -> 0x81 == 129
/// ...
/// 'ŀ' == '\u{0140}' -> 0x9E == 158
/// 'Ł' == '\u{0141}' -> 0x9F == 159
/// 'ł' == '\u{0142}' -> 0xA0 == 160
/// ---
/// '¡' == '\u{00A1}' -> 0xA1 == 161
/// '¢' == '\u{00A2}' -> 0xA2 == 162
/// '£' == '\u{00A3}' -> 0xA3 == 163
/// ...
/// 'ª' == '\u{00AA}' -> 0xAA == 170
/// '«' == '\u{00AB}' -> 0xAB == 171
/// '¬' == '\u{00AC}' -> 0xAC == 172
/// ---
/// 'Ń' == '\u{0143}' -> 0xAD == 173
/// ---
/// '®' == '\u{00AE}' -> 0xAE == 174
/// '¯' == '\u{00AF}' -> 0xAF == 175
/// '°' == '\u{00B0}' -> 0xB0 == 176
/// ...
/// 'ý' == '\u{00FD}' -> 0xFD == 253
/// 'þ' == '\u{00FE}' -> 0xFE == 254
/// 'ÿ' == '\u{00FF}' -> 0xFF == 255
/// ---
static CHAR_MAP: Lazy<HashMap<char, u8>> = Lazy::new(|| {
let mut char_map = HashMap::with_capacity(256);
let mut key = 0x100u32;
for byte in 0..=255u8 {
let char = byte as char;
if matches!(
char, '!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}',
) {
char_map.insert(char, byte);
} else if let Some(ch) = char::from_u32(key) {
char_map.insert(ch, byte);
key += 1;
}
}
char_map
});
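// A quick illustration of the inverse mapping (a sketch, mirroring the doc comment
// above): decoding the byte-level token "ĠO" back into its raw bytes.
//
// let bytes: Vec<u8> = "ĠO".chars().map(|c| CHAR_MAP[&c]).collect();
// assert_eq!(bytes, vec![0x20, 0x4F]); // " O"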

/// Token processor to adjust tokens according to the tokenizer's level.
#[derive(Debug)]
pub(crate) struct TokenProcessor {
level: TokenProcessorLevel,
}

/// Recognized tokenizer levels.
#[derive(Debug, Clone, PartialEq)]
pub enum TokenProcessorLevel {
/// Matches byte level tokenizer (e.g., gpt2).
Byte,
/// Matches byte fallback tokenizers (e.g., llama), which have <0x__> tokens for
/// all __ >= 0x80 to represent incomplete UTF-8 sequences.
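/// For example, <0xC3> stands for the raw byte 0xC3, the lead byte of a two-byte
/// UTF-8 sequence.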
ByteFallback(Mods),
}

impl std::fmt::Display for TokenProcessorLevel {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::Byte => write!(f, "Byte Level"),
Self::ByteFallback(mods) => write!(f, "Byte Fallback Level with mods: {:?}", mods),
}
}
}

/// Modifications to be applied at the `ByteFallback` `TokenProcessorLevel`.
#[derive(Debug, Clone, PartialEq)]
pub struct Mods {
spacechar: char,
}

/// Default string modification to be applied by `TokenProcessor` of `ByteFallback` level.
static DEFAULT_MODS: Mods = Mods { spacechar: ' ' };

impl Mods {
/// Apply default modifications.
fn apply_default(&self, token: String) -> String {
let to = DEFAULT_MODS.spacechar.to_string();
token.replace(self.spacechar, &to)
}
}
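// For example, with `spacechar = '▁'` (as in llama-style tokenizers),
// applying the default mods to "▁▁abc" yields "  abc".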

#[derive(Debug, Deserialize)]
struct ReplaceDecoder {
content: String,
pattern: ReplacePattern,
}

impl ReplaceDecoder {
fn space_replacement(&self) -> Option<char> {
if self.content != " " {
return None;
}
match &self.pattern {
ReplacePattern::String(pattern) => {
let mut chars = pattern.chars();
let char = chars.next();
if let Some(replacement) = char {
if chars.next().is_none() {
return Some(replacement);
}
}
None
}
}
}
}
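// For a llama-style `Replace` decoder with pattern "▁" and content " ",
// `space_replacement` returns `Some('▁')`; a multi-character pattern or a
// non-space content yields `None`.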

#[derive(Debug, Deserialize)]
pub enum ReplacePattern {
String(String),
}

impl TokenProcessor {
/// Create a new `TokenProcessor` with the level determined by the tokenizer's decoders.
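/// Only two configurations are recognized: a `ByteLevel` decoder (gpt2-style), or a
/// `Sequence` decoder containing a `ByteFallback` decoder (llama-style); anything
/// else is rejected as unsupported.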
pub(crate) fn new(tokenizer: &Tokenizer) -> Result<Self> {
match tokenizer.get_decoder() {
None => Err(TokenProcessorError::UnsupportedTokenizer),
Some(decoder) => match decoder {
DecoderWrapper::ByteLevel(_) => Ok(Self {
level: TokenProcessorLevel::Byte,
}),
DecoderWrapper::Sequence(decoding_sequence) => {
let mut is_byte_fallback = false;
let mut spacechar = ' ';

for decoder in decoding_sequence.get_decoders() {
match decoder {
DecoderWrapper::ByteFallback(_) => {
is_byte_fallback = true;
}
DecoderWrapper::Replace(replace) => {
// The `Replace` decoder replaces a pattern in the output with something else,
// which we need to know about.
let decoder = Self::unpack_decoder(replace)?;
if let Some(replacement) = decoder.space_replacement() {
spacechar = replacement;
}
}
_ => {}
}
}

if is_byte_fallback {
Ok(Self {
level: TokenProcessorLevel::ByteFallback(Mods { spacechar }),
})
} else {
Err(TokenProcessorError::UnsupportedTokenizer)
}
}
_ => Err(TokenProcessorError::UnsupportedTokenizer),
},
}
}

/// Process each token based on the level of the `TokenProcessor`.
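/// At `Byte` level each character is mapped back to its raw byte via `CHAR_MAP`; at
/// `ByteFallback` level a `<0x__>` token becomes that single byte, and any other
/// token is re-spaced by `Mods` and returned as its UTF-8 bytes.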
pub(crate) fn process(&self, token: String) -> Result<Vec<u8>> {
match &self.level {
TokenProcessorLevel::Byte => {
let mut bytes = vec![];
for char in token.chars() {
match CHAR_MAP.get(&char) {
None => return Err(TokenProcessorError::ByteProcessorFailed),
Some(b) => bytes.push(*b),
}
}
Ok(bytes)
}
TokenProcessorLevel::ByteFallback(mods) => {
// If the token is of form `<0x__>`:
if token.len() == 6 && token.starts_with("<0x") && token.ends_with('>') {
// Parse the hex digits in the __ part into a single byte.
match u8::from_str_radix(&token[3..5], 16) {
Ok(byte) => Ok([byte].to_vec()),
Err(_) => Err(TokenProcessorError::ByteFallbackProcessorFailed),
}
} else {
Ok(mods.apply_default(token).as_bytes().to_vec())
}
}
}
}

/// Since all fields of `Replace` are private with no getters, we have to unpack it into our own struct.
fn unpack_decoder(decoder: &Replace) -> Result<ReplaceDecoder> {
match serde_json::to_value(decoder) {
Err(_) => Err(TokenProcessorError::DecoderUnpackingFailed),
Ok(value) => match serde_json::from_value(value) {
Ok(d) => Ok(d),
Err(_) => Err(TokenProcessorError::DecoderUnpackingFailed),
},
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn byte_level_processor() {
let model = "openai-community/gpt2";
let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");

assert_eq!(processor.level, TokenProcessorLevel::Byte);

for (ch, byte) in [
('Ā', 0x00),
('ā', 0x01),
('Ă', 0x02),
('Ğ', 0x1E),
('ğ', 0x1F),
('Ġ', 0x20),
('!', 0x21),
('"', 0x22),
('#', 0x23),
('|', 0x7C),
('}', 0x7D),
('~', 0x7E),
('ġ', 0x7F),
('Ģ', 0x80),
('ģ', 0x81),
('ŀ', 0x9E),
('Ł', 0x9F),
('ł', 0xA0),
('¡', 0xA1),
('¢', 0xA2),
('£', 0xA3),
('ª', 0xAA),
('«', 0xAB),
('¬', 0xAC),
('Ń', 0xAD),
('®', 0xAE),
('¯', 0xAF),
('°', 0xB0),
('ý', 0xFD),
('þ', 0xFE),
('ÿ', 0xFF),
] {
let processed = processor.process(ch.to_string()).expect("Not processed");
assert_eq!(processed, [byte]);
}
}

#[test]
fn byte_fallback_level_processor() {
let model = "hf-internal-testing/llama-tokenizer";
let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");
let spacechar = '▁';
let mods = Mods { spacechar };

assert_eq!(processor.level, TokenProcessorLevel::ByteFallback(mods));

for (input, expected) in [
("abc", vec![0x61, 0x62, 0x63]),
("<0x61>", vec![0x61]),
("<0x61>a", vec![0x3C, 0x30, 0x78, 0x36, 0x31, 0x3E, 0x61]),
(&spacechar.to_string(), vec![0x20]),
(
&format!("{}{}abc", spacechar, spacechar),
vec![0x20, 0x20, 0x61, 0x62, 0x63],
),
(
&format!("{}{}{}", spacechar, spacechar, spacechar),
vec![0x20, 0x20, 0x20],
),
] {
let processed = processor.process(input.to_string()).expect("Not processed");
assert_eq!(processed, expected);
}
}
}
