Separate and simplify errors
torymur committed Nov 12, 2024
1 parent d09ea69 commit c7af981
Showing 5 changed files with 62 additions and 91 deletions.
27 changes: 27 additions & 0 deletions src/error.rs
@@ -0,0 +1,27 @@
use thiserror::Error;

#[derive(Error, Debug, PartialEq)]
pub enum Error {
#[error("The vocabulary does not allow us to build a sequence that matches the input")]
IndexError,
#[error("Unable to create tokenizer for {model}")]
UnableToCreateTokenizer { model: String },
#[error("Unable to locate EOS token for {model}")]
UnableToLocateEosTokenId { model: String },
#[error("Tokenizer is not supported by token processor")]
UnsupportedByTokenProcessor,
#[error("Decoder unpacking failed for token processor")]
DecoderUnpackingFailed,
#[error("Token processing failed for byte level processor")]
ByteProcessorFailed,
#[error("Token processing failed for byte fallback level processor")]
ByteFallbackProcessorFailed,
}

#[cfg(feature = "python-bindings")]
impl From<Error> for pyo3::PyErr {
    fn from(e: Error) -> Self {
        use pyo3::{exceptions::PyValueError, PyErr};
        PyErr::new::<PyValueError, _>(e.to_string())
    }
}
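
For illustration, a minimal standalone sketch of how the consolidated error type behaves: thiserror derives Display from the #[error(...)] attributes and interpolates struct-variant fields such as {model}. The sample model string and the main function below are illustrative only, not part of the commit.

use thiserror::Error;

#[derive(Error, Debug, PartialEq)]
pub enum Error {
    #[error("Unable to create tokenizer for {model}")]
    UnableToCreateTokenizer { model: String },
}

fn main() {
    // thiserror interpolates the `model` field into the Display message.
    let e = Error::UnableToCreateTokenizer {
        model: "some-model".to_string(),
    };
    assert_eq!(e.to_string(), "Unable to create tokenizer for some-model");
}

Under the python-bindings feature, the From<Error> for pyo3::PyErr impl above hands exactly this to_string() output to PyValueError.
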
5 changes: 2 additions & 3 deletions src/index.rs
@@ -2,10 +2,9 @@
use crate::prelude::{State, TransitionKey};
use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens};
use crate::vocabulary::Vocabulary;
use crate::{Error, Result};
use std::collections::{HashMap, HashSet};

pub type Result<T, E = crate::Error> = std::result::Result<T, E>;

#[derive(Debug)]
pub struct FSMInfo {
pub(crate) initial: State,
@@ -101,7 +100,7 @@ impl Index {
eos_token_id,
})
} else {
Err(crate::Error::IndexError)
Err(Error::IndexError)
}
}

54 changes: 4 additions & 50 deletions src/lib.rs
@@ -1,60 +1,14 @@
pub mod error;
pub mod index;
pub mod json_schema;
pub mod prelude;
pub mod primitives;
pub mod regex;
pub mod vocabulary;

#[cfg(feature = "python-bindings")]
mod python_bindings;

use thiserror::Error;

#[derive(Error, Debug, PartialEq)]
pub enum Error {
#[error("The vocabulary does not allow us to build a sequence that matches the input")]
IndexError,
}
use error::Error;

#[derive(Error, Debug)]
#[error("Tokenizer error")]
pub struct TokenizerError(tokenizers::Error);

impl PartialEq for TokenizerError {
fn eq(&self, other: &Self) -> bool {
self.0.to_string() == other.0.to_string()
}
}

#[derive(Error, Debug, PartialEq)]
pub enum VocabularyError {
#[error("Unable to create tokenizer for {model}, source {source}")]
UnableToCreateTokenizer {
model: String,
source: TokenizerError,
},
#[error("Unable to locate EOS token for {model}")]
UnableToLocateEosTokenId { model: String },
#[error("Unable to process token")]
TokenProcessorError(#[from] TokenProcessorError),
}

#[derive(Error, Debug, PartialEq)]
pub enum TokenProcessorError {
#[error("Tokenizer is not supported")]
UnsupportedTokenizer,
#[error("Decoder unpacking failed")]
DecoderUnpackingFailed,
#[error("Token processing failed for byte level processor")]
ByteProcessorFailed,
#[error("Token processing failed for byte fallback level processor")]
ByteFallbackProcessorFailed,
}
pub type Result<T, E = Error> = std::result::Result<T, E>;

#[cfg(feature = "python-bindings")]
impl From<Error> for pyo3::PyErr {
fn from(e: Error) -> Self {
use pyo3::{exceptions::PyValueError, PyErr};
PyErr::new::<PyValueError, _>(e.to_string())
}
}
mod python_bindings;
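
As a side note, a small sketch of what the crate-wide alias with a default type parameter buys (the function names below are made up for illustration): most signatures only need to name the success type, while a call site can still override the error type when it differs.

#[derive(Debug, PartialEq)]
pub enum Error {
    IndexError,
}

// Same shape as the alias added to lib.rs above.
pub type Result<T, E = Error> = std::result::Result<T, E>;

// Most functions only spell out the success type...
fn build_index(ok: bool) -> Result<u32> {
    if ok {
        Ok(42)
    } else {
        Err(Error::IndexError)
    }
}

// ...but the error parameter can still be overridden where needed.
fn parse_number(s: &str) -> Result<u32, std::num::ParseIntError> {
    s.parse()
}

fn main() {
    assert_eq!(build_index(true), Ok(42));
    assert!(parse_number("not a number").is_err());
}

This is also why index.rs can drop its local Result alias and import the shared one from the crate root instead.
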
26 changes: 10 additions & 16 deletions src/vocabulary/mod.rs
@@ -3,7 +3,8 @@ use std::collections::HashMap;
use tokenizers::normalizers::Sequence;
use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};

use crate::{prelude::*, TokenizerError, VocabularyError};
use crate::prelude::*;
use crate::{Error, Result};

use locator::EosTokenLocator;
use processor::TokenProcessor;
@@ -44,19 +45,18 @@ impl Vocabulary {
pub fn from_pretrained(
model: &str,
parameters: Option<FromPretrainedParameters>,
) -> Result<Self, VocabularyError> {
) -> Result<Self> {
let mut tokenizer =
Tokenizer::from_pretrained(model, parameters.clone()).map_err(|error| {
VocabularyError::UnableToCreateTokenizer {
Tokenizer::from_pretrained(model, parameters.clone()).map_err(|_| {
Error::UnableToCreateTokenizer {
model: model.to_string(),
source: TokenizerError(error),
}
})?;
Self::filter_prepend_normalizers(&mut tokenizer);

let eos_token_id = EosTokenLocator::locate(model, &tokenizer, &parameters);
let Some(eos_token_id) = eos_token_id else {
return Err(VocabularyError::UnableToLocateEosTokenId {
return Err(Error::UnableToLocateEosTokenId {
model: model.to_string(),
});
};
@@ -106,9 +106,9 @@ impl Vocabulary {
}

impl TryFrom<(&mut Tokenizer, u32)> for Vocabulary {
type Error = VocabularyError;
type Error = Error;

fn try_from(value: (&mut Tokenizer, u32)) -> Result<Vocabulary, VocabularyError> {
fn try_from(value: (&mut Tokenizer, u32)) -> Result<Self> {
let (tokenizer, eos_token_id) = value;

let mut vocabulary = Vocabulary::new(Some(eos_token_id));
@@ -313,12 +313,7 @@ mod tests {

assert!(vocabulary.is_err());
if let Err(e) = vocabulary {
assert_eq!(
e,
VocabularyError::TokenProcessorError(
crate::TokenProcessorError::UnsupportedTokenizer
)
)
assert_eq!(e, Error::UnsupportedByTokenProcessor)
}
}

@@ -328,9 +323,8 @@
let vocabulary = Vocabulary::from_pretrained(model, None);

assert!(vocabulary.is_err());
if let Err(VocabularyError::UnableToCreateTokenizer { model, source }) = vocabulary {
if let Err(Error::UnableToCreateTokenizer { model }) = vocabulary {
assert_eq!(model, model.to_string());
assert_eq!(source.to_string(), "Tokenizer error".to_string());
}
}

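A hypothetical caller-side sketch of how the flattened variants read after this change (load_vocabulary and the model string stand in for Vocabulary::from_pretrained; they are not the crate's API): what used to be a nested VocabularyError::TokenProcessorError(TokenProcessorError::UnsupportedTokenizer) now matches as a single Error variant.

#[derive(Debug, PartialEq)]
pub enum Error {
    UnableToCreateTokenizer { model: String },
    UnsupportedByTokenProcessor,
}

// Stand-in for Vocabulary::from_pretrained, which now returns one flat Error type.
fn load_vocabulary(model: &str) -> Result<(), Error> {
    Err(Error::UnableToCreateTokenizer {
        model: model.to_string(),
    })
}

fn main() {
    match load_vocabulary("unknown/model") {
        // One flat enum instead of VocabularyError wrapping TokenProcessorError.
        Err(Error::UnableToCreateTokenizer { model }) => {
            println!("could not build a tokenizer for {model}");
        }
        Err(Error::UnsupportedByTokenProcessor) => {
            println!("tokenizer has no decoder the token processor supports");
        }
        Ok(()) => {}
    }
}
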
41 changes: 19 additions & 22 deletions src/vocabulary/processor.rs
@@ -5,9 +5,7 @@ use serde::Deserialize;
use tokenizers::normalizers::Replace;
use tokenizers::{DecoderWrapper, Tokenizer};

use crate::TokenProcessorError;

type Result<T, E = TokenProcessorError> = std::result::Result<T, E>;
use crate::{Error, Result};

/// GPT2-like tokenizers have multibyte tokens that can have a mix of full and incomplete
/// UTF-8 characters, for example, byte ` \xf0` can be one token. These tokenizers map each
@@ -157,7 +155,7 @@ impl TokenProcessor {
/// Create new `TokenProcessor` with the level defined based on tokenizer's decoders.
pub(crate) fn new(tokenizer: &Tokenizer) -> Result<Self> {
match tokenizer.get_decoder() {
None => Err(TokenProcessorError::UnsupportedTokenizer),
None => Err(Error::UnsupportedByTokenProcessor),
Some(decoder) => match decoder {
DecoderWrapper::ByteLevel(_) => Ok(Self {
level: TokenProcessorLevel::Byte,
@@ -188,34 +186,33 @@ impl TokenProcessor {
level: TokenProcessorLevel::ByteFallback(Mods { spacechar }),
})
} else {
Err(TokenProcessorError::UnsupportedTokenizer)
Err(Error::UnsupportedByTokenProcessor)
}
}
_ => Err(TokenProcessorError::UnsupportedTokenizer),
_ => Err(Error::UnsupportedByTokenProcessor),
},
}
}

/// Operates on each token based on the level of `TokenProcessor`.
pub(crate) fn process(&self, token: String) -> Result<Vec<u8>> {
match &self.level {
TokenProcessorLevel::Byte => {
let mut bytes = vec![];
for char in token.chars() {
match CHAR_MAP.get(&char) {
None => return Err(TokenProcessorError::ByteProcessorFailed),
Some(b) => bytes.push(*b),
}
}
Ok(bytes)
}
TokenProcessorLevel::Byte => token
.chars()
.map(|char| {
CHAR_MAP
.get(&char)
.copied()
.ok_or(Error::ByteProcessorFailed)
})
.collect(),
TokenProcessorLevel::ByteFallback(mods) => {
// If the token is of form `<0x__>`:
if token.len() == 6 && token.starts_with("<0x") && token.ends_with('>') {
// Get to a single byte specified in the __ part and parse it in base 16 to a byte.
match u8::from_str_radix(&token[3..5], 16) {
Ok(byte) => Ok([byte].to_vec()),
Err(_) => Err(TokenProcessorError::ByteFallbackProcessorFailed),
Err(_) => Err(Error::ByteFallbackProcessorFailed),
}
} else {
Ok(mods.apply_default(token).as_bytes().to_vec())
@@ -228,10 +225,10 @@ impl TokenProcessor {
/// into local `ReplaceDecoder` structure.
fn unpack_decoder(decoder: &Replace) -> Result<ReplaceDecoder> {
match serde_json::to_value(decoder) {
Err(_) => Err(TokenProcessorError::DecoderUnpackingFailed),
Err(_) => Err(Error::DecoderUnpackingFailed),
Ok(value) => match serde_json::from_value(value) {
Ok(d) => Ok(d),
Err(_) => Err(TokenProcessorError::DecoderUnpackingFailed),
Err(_) => Err(Error::DecoderUnpackingFailed),
},
}
}
@@ -324,7 +321,7 @@ mod tests {
let result = TokenProcessor::new(&tokenizer);
assert!(result.is_err());
if let Err(e) = result {
assert_eq!(e, TokenProcessorError::UnsupportedTokenizer)
assert_eq!(e, Error::UnsupportedByTokenProcessor)
}
}

Expand All @@ -338,7 +335,7 @@ mod tests {
let result = processor.process(token.to_string());
assert!(result.is_err());
if let Err(e) = result {
assert_eq!(e, TokenProcessorError::ByteProcessorFailed)
assert_eq!(e, Error::ByteProcessorFailed)
}
}
}
@@ -352,7 +349,7 @@
let result = processor.process("<0x6y>".to_string());
assert!(result.is_err());
if let Err(e) = result {
assert_eq!(e, TokenProcessorError::ByteFallbackProcessorFailed)
assert_eq!(e, Error::ByteFallbackProcessorFailed)
}
}
}
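
Two standalone sketches of the patterns used in process() above, with a simplified stand-in for CHAR_MAP and without the Mods space-character remapping: collecting an iterator of Results into Result<Vec<u8>, _> so the first failed lookup short-circuits, and decoding a <0x__> byte-fallback token with u8::from_str_radix.

use std::collections::HashMap;

#[derive(Debug, PartialEq)]
enum Error {
    ByteProcessorFailed,
    ByteFallbackProcessorFailed,
}

// Byte level: map every char through a lookup table; the first missing
// entry turns the whole `collect` into an Err.
fn process_byte_level(token: &str, char_map: &HashMap<char, u8>) -> Result<Vec<u8>, Error> {
    token
        .chars()
        .map(|c| char_map.get(&c).copied().ok_or(Error::ByteProcessorFailed))
        .collect()
}

// Byte fallback: tokens of the form `<0x__>` carry a single byte in hex.
fn process_byte_fallback(token: &str) -> Result<Vec<u8>, Error> {
    if token.len() == 6 && token.starts_with("<0x") && token.ends_with('>') {
        u8::from_str_radix(&token[3..5], 16)
            .map(|byte| vec![byte])
            .map_err(|_| Error::ByteFallbackProcessorFailed)
    } else {
        Ok(token.as_bytes().to_vec())
    }
}

fn main() {
    let char_map = HashMap::from([('a', 0x61), ('b', 0x62)]);
    assert_eq!(process_byte_level("ab", &char_map), Ok(vec![0x61, 0x62]));
    assert_eq!(process_byte_level("a?", &char_map), Err(Error::ByteProcessorFailed));
    assert_eq!(process_byte_fallback("<0x20>"), Ok(vec![0x20]));
    assert_eq!(process_byte_fallback("<0x6y>"), Err(Error::ByteFallbackProcessorFailed));
}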
