From 3427117cae8ae9865d7fba043e41a7dfd484cbd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= <joennlae@gmail.com> Date: Sat, 30 Nov 2024 20:01:32 +0100 Subject: [PATCH 1/5] feat(pickle): make `Index` pickleable by using `serde` --- src/index.rs | 4 +++- src/python_bindings/mod.rs | 19 ++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/index.rs b/src/index.rs index cc1187e8..c37defdf 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,3 +1,5 @@ +use serde::{Deserialize, Serialize}; + /// Construct an Index. use crate::prelude::{State, TransitionKey}; use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens}; @@ -32,7 +34,7 @@ impl FSMInfo { } } -#[derive(Debug)] +#[derive(Debug, Serialize, Deserialize)] pub struct Index { initial: u32, finals: HashSet<u32>, diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 046d7ce9..5697a00c 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -72,7 +72,7 @@ impl PyFSMInfo { } } -#[pyclass(name = "Index")] +#[pyclass(name = "Index", module = "outlines_core.fsm.outlines_core_rs")] pub struct PyIndex(Index); #[pymethods] @@ -89,6 +89,23 @@ impl PyIndex { .map_err(Into::into) } + fn __reduce__(&self) -> PyResult<(PyObject, (String,))> { + Python::with_gil(|py| { + let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? 
+ .getattr("Index")?; + let json_data = serde_json::to_string(&self.0) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; + Ok((cls.getattr("from_json")?.to_object(py), (json_data,))) + }) + } + + #[staticmethod] + fn from_json(json_data: String) -> PyResult<Self> { + let index: Index = serde_json::from_str(&json_data) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; + Ok(PyIndex(index)) + } + fn get_allowed_tokens(&self, state: u32) -> Option<Vec<u32>> { self.0.allowed_tokens(state) } From b88907533e6fcc74840f84ef84a0a529b88220d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= <joennlae@gmail.com> Date: Sat, 30 Nov 2024 20:04:51 +0100 Subject: [PATCH 2/5] test(pickle): add simple + complex pickle test --- tests/fsm/test_serialization.py | 67 +++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/fsm/test_serialization.py diff --git a/tests/fsm/test_serialization.py b/tests/fsm/test_serialization.py new file mode 100644 index 00000000..33850f8e --- /dev/null +++ b/tests/fsm/test_serialization.py @@ -0,0 +1,67 @@ +import pickle +from timeit import default_timer as timer + +import pytest +from outlines_core.fsm.guide import RegexGuide +from transformers import AutoTokenizer + +from tests.fsm.test_regex import TransformerTokenizer + + +def test_serialization(): + class MockTokenizer: + vocabulary = {"1": 1, "a": 2, "eos": 3} + special_tokens = {"eos"} + eos_token_id = 3 + + def convert_token_to_string(self, token): + return token + + regex_str = "[1-9]" + tokenizer = MockTokenizer() + + fsm = RegexGuide.from_regex(regex_str, tokenizer) + + serialized = pickle.dumps(fsm) + deserialized = pickle.loads(serialized) + + assert fsm.eos_tensor == deserialized.eos_tensor + assert fsm.initial_state == deserialized.initial_state + + +@pytest.mark.parametrize( + "hf_tokenizer_uri, revision", + [ + ("openai-community/gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e"), + 
("microsoft/phi-2", "ef382358ec9e382308935a992d908de099b64c23"), + ("Qwen/Qwen1.5-0.5B-Chat", "4d14e384a4b037942bb3f3016665157c8bcb70ea"), + ( + "NousResearch/Hermes-2-Pro-Llama-3-8B", + "783fd50eb82d7f57758de033861f54d62dde234f", + ), + ], +) +def test_complex_serialization(hf_tokenizer_uri, revision): + # The combined regular expressions of a lexer state in a Python grammar + regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" + + tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri, revision=revision) + tokenizer = TransformerTokenizer(tokenizer) + + fsm = RegexGuide.from_regex(regex_str, tokenizer) + + start = timer() + serialized = pickle.dumps(fsm) + serialization_time = timer() - start + + # Measure deserialization time + start = timer() + deserialized = pickle.loads(serialized) + deserialization_time = timer() - start + + assert fsm.eos_tensor == deserialized.eos_tensor + assert fsm.initial_state == deserialized.initial_state + + # Print or log the timing results + print(f"Serialization time: {serialization_time:.6f} seconds") + print(f"Deserialization time: {deserialization_time:.6f} 
seconds") From 8754128863276c920e0fc4157c33d38125d72127 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= <joennlae@gmail.com> Date: Sat, 30 Nov 2024 20:22:33 +0100 Subject: [PATCH 3/5] feat(pickle): change to `bincode` for slightly faster serialize and deserialize times --- Cargo.toml | 1 + src/index.rs | 5 ++--- src/python_bindings/mod.rs | 14 +++++++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 94eab3a0..2082b140 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ regex = "1.10.6" serde-pyobject = "0.4.0" serde_json = { version = "1.0", features = ["preserve_order"] } serde = {version = "1.0", features = ["derive"]} +bincode = "2.0.0-rc.3" # Fragile dependencies, minor updates often break the code hf-hub = "=0.3.2" tokenizers = { version = "=0.20.3", features = ["http"] } diff --git a/src/index.rs b/src/index.rs index c37defdf..a756445c 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,10 +1,9 @@ -use serde::{Deserialize, Serialize}; - /// Construct an Index. 
use crate::prelude::{State, TransitionKey}; use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens}; use crate::vocabulary::Vocabulary; use crate::{Error, Result}; +use bincode::{Decode, Encode}; use std::collections::{HashMap, HashSet}; #[derive(Debug)] @@ -34,7 +33,7 @@ impl FSMInfo { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Encode, Decode)] pub struct Index { initial: u32, finals: HashSet<u32>, diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 5697a00c..086fdb60 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -5,6 +5,7 @@ use crate::regex::get_token_transition_keys; use crate::regex::get_vocabulary_transition_keys; use crate::regex::state_scan_tokens; use crate::regex::walk_fsm; +use bincode::config; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -89,20 +90,19 @@ impl PyIndex { .map_err(Into::into) } - fn __reduce__(&self) -> PyResult<(PyObject, (String,))> { + fn __reduce__(&self) -> PyResult<(PyObject, (Vec<u8>,))> { Python::with_gil(|py| { let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? 
.getattr("Index")?; - let json_data = serde_json::to_string(&self.0) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; - Ok((cls.getattr("from_json")?.to_object(py), (json_data,))) + let binary_data: Vec<u8> = bincode::encode_to_vec(&self.0, config::standard()).unwrap(); + Ok((cls.getattr("from_binary")?.to_object(py), (binary_data,))) }) } #[staticmethod] - fn from_json(json_data: String) -> PyResult<Self> { - let index: Index = serde_json::from_str(&json_data) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; + fn from_binary(binary_data: Vec<u8>) -> PyResult<Self> { + let (index, _): (Index, usize) = + bincode::decode_from_slice(&binary_data[..], config::standard()).unwrap(); Ok(PyIndex(index)) } From 412ef296392a0814a5490ccc15080e79f98cd411 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= <joennlae@gmail.com> Date: Sat, 30 Nov 2024 20:24:16 +0100 Subject: [PATCH 4/5] refactor(pickle): remove timing infra + remove logs --- tests/fsm/test_serialization.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/fsm/test_serialization.py b/tests/fsm/test_serialization.py index 33850f8e..d3c38365 100644 --- a/tests/fsm/test_serialization.py +++ b/tests/fsm/test_serialization.py @@ -1,5 +1,4 @@ import pickle -from timeit import default_timer as timer import pytest from outlines_core.fsm.guide import RegexGuide @@ -50,18 +49,8 @@ def test_complex_serialization(hf_tokenizer_uri, revision): fsm = RegexGuide.from_regex(regex_str, tokenizer) - start = timer() serialized = pickle.dumps(fsm) - serialization_time = timer() - start - - # Measure deserialization time - start = timer() deserialized = pickle.loads(serialized) - deserialization_time = timer() - start assert fsm.eos_tensor == deserialized.eos_tensor assert fsm.initial_state == deserialized.initial_state - - # Print or log the timing results - print(f"Serialization time: {serialization_time:.6f} seconds") - print(f"Deserialization time: 
{deserialization_time:.6f} seconds") From f0c2f3f7a3e991af15b9ec177116aba485e5eb14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= <joennlae@gmail.com> Date: Mon, 2 Dec 2024 19:59:32 +0100 Subject: [PATCH 5/5] chore(pickle): handle unwraps with error message --- src/python_bindings/mod.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 086fdb60..55d979d1 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -94,7 +94,10 @@ impl PyIndex { Python::with_gil(|py| { let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? .getattr("Index")?; - let binary_data: Vec<u8> = bincode::encode_to_vec(&self.0, config::standard()).unwrap(); + let binary_data: Vec<u8> = bincode::encode_to_vec(&self.0, config::standard()) + .map_err(|e| { + PyErr::new::<PyValueError, _>(format!("Serialization of Index failed: {}", e)) + })?; Ok((cls.getattr("from_binary")?.to_object(py), (binary_data,))) }) } @@ -102,7 +105,9 @@ impl PyIndex { #[staticmethod] fn from_binary(binary_data: Vec<u8>) -> PyResult<Self> { let (index, _): (Index, usize) = - bincode::decode_from_slice(&binary_data[..], config::standard()).unwrap(); + bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| { + PyErr::new::<PyValueError, _>(format!("Deserialization of Index failed: {}", e)) + })?; Ok(PyIndex(index)) }