Make RegexGuide pickleable again for vllm and tgi (#99)
I understand that pickling support is not your priority right now, but `RegexGuide` needs to be picklable for `vllm` production use, which is multiprocessing-based.

This PR reintroduces the pickling capability and adds tests for it.

I understand that this means additional effort on your side.
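
For context, here is a rough sketch of the round trip this is meant to support (the mock tokenizer mirrors the one in the tests added here; the worker function and spawn setup are illustrative, not vllm's actual code):

```python
import multiprocessing as mp

from outlines_core.fsm.guide import RegexGuide


class MockTokenizer:
    # same minimal stand-in used by the new tests
    vocabulary = {"1": 1, "a": 2, "eos": 3}
    special_tokens = {"eos"}
    eos_token_id = 3

    def convert_token_to_string(self, token):
        return token


def worker(guide: RegexGuide) -> int:
    # the guide arrives here via pickle, as in a multiprocessing-based server
    return guide.initial_state


if __name__ == "__main__":
    guide = RegexGuide.from_regex("[1-9]", MockTokenizer())
    ctx = mp.get_context("spawn")
    with ctx.Pool(1) as pool:
        # Pool pickles `guide` to ship it to the spawned worker process
        print(pool.apply(worker, (guide,)))
```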

References:
dottxt-ai/outlines#1274
vllm-project/vllm#10490
vllm-project/vllm#10576
vllm-project/vllm#10489

It would also address the current caching issues (see the sketch below):
huggingface/text-generation-inference#2766
dottxt-ai/outlines#1283
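
A rough sketch of the kind of pickle-based disk cache this unblocks (hypothetical cache layout and key; a real cache would also key on the tokenizer):

```python
import pickle
from pathlib import Path

from outlines_core.fsm.guide import RegexGuide


def cached_guide(regex_str: str, tokenizer, cache_dir: Path) -> RegexGuide:
    # hypothetical cache file named after the regex hash
    path = cache_dir / f"guide-{hash(regex_str):x}.pkl"
    if path.exists():
        return pickle.loads(path.read_bytes())
    guide = RegexGuide.from_regex(regex_str, tokenizer)
    path.write_bytes(pickle.dumps(guide))
    return guide
```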

Closes:
#95
joennlae authored Dec 2, 2024
1 parent 9ddf5e7 commit d1a0e8c
Showing 4 changed files with 82 additions and 2 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -15,6 +15,7 @@ regex = "1.10.6"
serde-pyobject = "0.4.0"
serde_json = { version = "1.0", features = ["preserve_order"] }
serde = {version = "1.0", features = ["derive"]}
bincode = "2.0.0-rc.3"
# Fragile dependencies, minor updates often break the code
hf-hub = "=0.3.2"
tokenizers = { version = "=0.20.3", features = ["http"] }
3 changes: 2 additions & 1 deletion src/index.rs
@@ -3,6 +3,7 @@ use crate::prelude::{State, TransitionKey};
use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens};
use crate::vocabulary::Vocabulary;
use crate::{Error, Result};
use bincode::{Decode, Encode};
use std::collections::{HashMap, HashSet};

#[derive(Debug)]
@@ -32,7 +33,7 @@ impl FSMInfo {
    }
}

-#[derive(Debug)]
+#[derive(Debug, Encode, Decode)]
pub struct Index {
    initial: u32,
    finals: HashSet<u32>,
24 changes: 23 additions & 1 deletion src/python_bindings/mod.rs
@@ -5,6 +5,7 @@ use crate::regex::get_token_transition_keys;
use crate::regex::get_vocabulary_transition_keys;
use crate::regex::state_scan_tokens;
use crate::regex::walk_fsm;
use bincode::config;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyDict;
@@ -72,7 +73,7 @@ impl PyFSMInfo {
    }
}

-#[pyclass(name = "Index")]
+#[pyclass(name = "Index", module = "outlines_core.fsm.outlines_core_rs")]
pub struct PyIndex(Index);

#[pymethods]
@@ -89,6 +90,27 @@ impl PyIndex {
            .map_err(Into::into)
    }

    fn __reduce__(&self) -> PyResult<(PyObject, (Vec<u8>,))> {
        Python::with_gil(|py| {
            let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")?
                .getattr("Index")?;
            let binary_data: Vec<u8> = bincode::encode_to_vec(&self.0, config::standard())
                .map_err(|e| {
                    PyErr::new::<PyValueError, _>(format!("Serialization of Index failed: {}", e))
                })?;
            Ok((cls.getattr("from_binary")?.to_object(py), (binary_data,)))
        })
    }

    #[staticmethod]
    fn from_binary(binary_data: Vec<u8>) -> PyResult<Self> {
        let (index, _): (Index, usize) =
            bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| {
                PyErr::new::<PyValueError, _>(format!("Deserialization of Index failed: {}", e))
            })?;
        Ok(PyIndex(index))
    }

    fn get_allowed_tokens(&self, state: u32) -> Option<Vec<u32>> {
        self.0.allowed_tokens(state)
    }
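
The `__reduce__`/`from_binary` pair follows the standard pickle protocol: pickle records the returned `(callable, args)` tuple and calls `callable(*args)` on load. A minimal, self-contained illustration of that contract (toy class, not the real `Index`):

```python
import pickle


class Blob:
    """Toy stand-in for PyIndex: all state lives in an opaque byte payload."""

    def __init__(self, payload: bytes):
        self.payload = payload

    @staticmethod
    def from_binary(payload: bytes) -> "Blob":
        return Blob(payload)

    def __reduce__(self):
        # pickle stores this (callable, args) pair and replays it on load
        return (Blob.from_binary, (self.payload,))


restored = pickle.loads(pickle.dumps(Blob(b"\x01\x02")))
assert restored.payload == b"\x01\x02"
```
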
56 changes: 56 additions & 0 deletions tests/fsm/test_serialization.py
@@ -0,0 +1,56 @@
import pickle

import pytest
from outlines_core.fsm.guide import RegexGuide
from transformers import AutoTokenizer

from tests.fsm.test_regex import TransformerTokenizer


def test_serialization():
    class MockTokenizer:
        vocabulary = {"1": 1, "a": 2, "eos": 3}
        special_tokens = {"eos"}
        eos_token_id = 3

        def convert_token_to_string(self, token):
            return token

    regex_str = "[1-9]"
    tokenizer = MockTokenizer()

    fsm = RegexGuide.from_regex(regex_str, tokenizer)

    serialized = pickle.dumps(fsm)
    deserialized = pickle.loads(serialized)

    assert fsm.eos_tensor == deserialized.eos_tensor
    assert fsm.initial_state == deserialized.initial_state


@pytest.mark.parametrize(
    "hf_tokenizer_uri, revision",
    [
        ("openai-community/gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e"),
        ("microsoft/phi-2", "ef382358ec9e382308935a992d908de099b64c23"),
        ("Qwen/Qwen1.5-0.5B-Chat", "4d14e384a4b037942bb3f3016665157c8bcb70ea"),
        (
            "NousResearch/Hermes-2-Pro-Llama-3-8B",
            "783fd50eb82d7f57758de033861f54d62dde234f",
        ),
    ],
)
def test_complex_serialization(hf_tokenizer_uri, revision):
    # The combined regular expressions of a lexer state in a Python grammar
regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~"

    tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri, revision=revision)
    tokenizer = TransformerTokenizer(tokenizer)

    fsm = RegexGuide.from_regex(regex_str, tokenizer)

    serialized = pickle.dumps(fsm)
    deserialized = pickle.loads(serialized)

    assert fsm.eos_tensor == deserialized.eos_tensor
    assert fsm.initial_state == deserialized.initial_state
