From 3427117cae8ae9865d7fba043e41a7dfd484cbd0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= <joennlae@gmail.com>
Date: Sat, 30 Nov 2024 20:01:32 +0100
Subject: [PATCH 1/5] feat(pickle): make `Index` pickleable by using `serde`

---
 src/index.rs               |  4 +++-
 src/python_bindings/mod.rs | 19 ++++++++++++++++++-
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/src/index.rs b/src/index.rs
index cc1187e8..c37defdf 100644
--- a/src/index.rs
+++ b/src/index.rs
@@ -1,3 +1,5 @@
+use serde::{Deserialize, Serialize};
+
 /// Construct an Index.
 use crate::prelude::{State, TransitionKey};
 use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens};
@@ -32,7 +34,7 @@ impl FSMInfo {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Serialize, Deserialize)]
 pub struct Index {
     initial: u32,
     finals: HashSet<u32>,
diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs
index 046d7ce9..5697a00c 100644
--- a/src/python_bindings/mod.rs
+++ b/src/python_bindings/mod.rs
@@ -72,7 +72,7 @@ impl PyFSMInfo {
     }
 }
 
-#[pyclass(name = "Index")]
+#[pyclass(name = "Index", module = "outlines_core.fsm.outlines_core_rs")]
 pub struct PyIndex(Index);
 
 #[pymethods]
@@ -89,6 +89,23 @@ impl PyIndex {
             .map_err(Into::into)
     }
 
+    fn __reduce__(&self) -> PyResult<(PyObject, (String,))> {
+        Python::with_gil(|py| {
+            let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")?
+                .getattr("Index")?;
+            let json_data = serde_json::to_string(&self.0)
+                .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
+            Ok((cls.getattr("from_json")?.to_object(py), (json_data,)))
+        })
+    }
+
+    #[staticmethod]
+    fn from_json(json_data: String) -> PyResult<Self> {
+        let index: Index = serde_json::from_str(&json_data)
+            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
+        Ok(PyIndex(index))
+    }
+
     fn get_allowed_tokens(&self, state: u32) -> Option<Vec<u32>> {
         self.0.allowed_tokens(state)
     }

From b88907533e6fcc74840f84ef84a0a529b88220d0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= <joennlae@gmail.com>
Date: Sat, 30 Nov 2024 20:04:51 +0100
Subject: [PATCH 2/5] test(pickle): add simple + complex pickle test

---
 tests/fsm/test_serialization.py | 67 +++++++++++++++++++++++++++++++++
 1 file changed, 67 insertions(+)
 create mode 100644 tests/fsm/test_serialization.py

diff --git a/tests/fsm/test_serialization.py b/tests/fsm/test_serialization.py
new file mode 100644
index 00000000..33850f8e
--- /dev/null
+++ b/tests/fsm/test_serialization.py
@@ -0,0 +1,67 @@
+import pickle
+from timeit import default_timer as timer
+
+import pytest
+from outlines_core.fsm.guide import RegexGuide
+from transformers import AutoTokenizer
+
+from tests.fsm.test_regex import TransformerTokenizer
+
+
+def test_serialization():
+    class MockTokenizer:
+        vocabulary = {"1": 1, "a": 2, "eos": 3}
+        special_tokens = {"eos"}
+        eos_token_id = 3
+
+        def convert_token_to_string(self, token):
+            return token
+
+    regex_str = "[1-9]"
+    tokenizer = MockTokenizer()
+
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
+
+    serialized = pickle.dumps(fsm)
+    deserialized = pickle.loads(serialized)
+
+    assert fsm.eos_tensor == deserialized.eos_tensor
+    assert fsm.initial_state == deserialized.initial_state
+
+
+@pytest.mark.parametrize(
+    "hf_tokenizer_uri, revision",
+    [
+        ("openai-community/gpt2", "607a30d783dfa663caf39e06633721c8d4cfcd7e"),
+        ("microsoft/phi-2", "ef382358ec9e382308935a992d908de099b64c23"),
+        ("Qwen/Qwen1.5-0.5B-Chat", "4d14e384a4b037942bb3f3016665157c8bcb70ea"),
+        (
+            "NousResearch/Hermes-2-Pro-Llama-3-8B",
+            "783fd50eb82d7f57758de033861f54d62dde234f",
+        ),
+    ],
+)
+def test_complex_serialization(hf_tokenizer_uri, revision):
+    # The combined regular expressions of a lexer state in a Python grammar
+    regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~"
+
+    tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri, revision=revision)
+    tokenizer = TransformerTokenizer(tokenizer)
+
+    fsm = RegexGuide.from_regex(regex_str, tokenizer)
+
+    start = timer()
+    serialized = pickle.dumps(fsm)
+    serialization_time = timer() - start
+
+    # Measure deserialization time
+    start = timer()
+    deserialized = pickle.loads(serialized)
+    deserialization_time = timer() - start
+
+    assert fsm.eos_tensor == deserialized.eos_tensor
+    assert fsm.initial_state == deserialized.initial_state
+
+    # Print or log the timing results
+    print(f"Serialization time: {serialization_time:.6f} seconds")
+    print(f"Deserialization time: {deserialization_time:.6f} seconds")

From 8754128863276c920e0fc4157c33d38125d72127 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= <joennlae@gmail.com>
Date: Sat, 30 Nov 2024 20:22:33 +0100
Subject: [PATCH 3/5] feat(pickle): change to `bincode` for slightly faster
 serialize and deserialize times

---
 Cargo.toml                 |  1 +
 src/index.rs               |  5 ++---
 src/python_bindings/mod.rs | 14 +++++++-------
 3 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 94eab3a0..2082b140 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,6 +15,7 @@ regex = "1.10.6"
 serde-pyobject = "0.4.0"
 serde_json = { version = "1.0", features = ["preserve_order"] }
 serde = {version = "1.0", features = ["derive"]}
+bincode = "2.0.0-rc.3"
 # Fragile dependencies, minor updates often break the code
 hf-hub = "=0.3.2"
 tokenizers = { version = "=0.20.3", features = ["http"] }
diff --git a/src/index.rs b/src/index.rs
index c37defdf..a756445c 100644
--- a/src/index.rs
+++ b/src/index.rs
@@ -1,10 +1,9 @@
-use serde::{Deserialize, Serialize};
-
 /// Construct an Index.
 use crate::prelude::{State, TransitionKey};
 use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens};
 use crate::vocabulary::Vocabulary;
 use crate::{Error, Result};
+use bincode::{Decode, Encode};
 use std::collections::{HashMap, HashSet};
 
 #[derive(Debug)]
@@ -34,7 +33,7 @@ impl FSMInfo {
     }
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Encode, Decode)]
 pub struct Index {
     initial: u32,
     finals: HashSet<u32>,
diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs
index 5697a00c..086fdb60 100644
--- a/src/python_bindings/mod.rs
+++ b/src/python_bindings/mod.rs
@@ -5,6 +5,7 @@ use crate::regex::get_token_transition_keys;
 use crate::regex::get_vocabulary_transition_keys;
 use crate::regex::state_scan_tokens;
 use crate::regex::walk_fsm;
+use bincode::config;
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::types::PyDict;
@@ -89,20 +90,19 @@ impl PyIndex {
             .map_err(Into::into)
     }
 
-    fn __reduce__(&self) -> PyResult<(PyObject, (String,))> {
+    fn __reduce__(&self) -> PyResult<(PyObject, (Vec<u8>,))> {
         Python::with_gil(|py| {
             let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")?
                 .getattr("Index")?;
-            let json_data = serde_json::to_string(&self.0)
-                .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
-            Ok((cls.getattr("from_json")?.to_object(py), (json_data,)))
+            let binary_data: Vec<u8> = bincode::encode_to_vec(&self.0, config::standard()).unwrap();
+            Ok((cls.getattr("from_binary")?.to_object(py), (binary_data,)))
         })
     }
 
     #[staticmethod]
-    fn from_json(json_data: String) -> PyResult<Self> {
-        let index: Index = serde_json::from_str(&json_data)
-            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;
+    fn from_binary(binary_data: Vec<u8>) -> PyResult<Self> {
+        let (index, _): (Index, usize) =
+            bincode::decode_from_slice(&binary_data[..], config::standard()).unwrap();
         Ok(PyIndex(index))
     }
 

From 412ef296392a0814a5490ccc15080e79f98cd411 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= <joennlae@gmail.com>
Date: Sat, 30 Nov 2024 20:24:16 +0100
Subject: [PATCH 4/5] refactor(pickle): remove timing infra + remove logs

---
 tests/fsm/test_serialization.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/tests/fsm/test_serialization.py b/tests/fsm/test_serialization.py
index 33850f8e..d3c38365 100644
--- a/tests/fsm/test_serialization.py
+++ b/tests/fsm/test_serialization.py
@@ -1,5 +1,4 @@
 import pickle
-from timeit import default_timer as timer
 
 import pytest
 from outlines_core.fsm.guide import RegexGuide
@@ -50,18 +49,8 @@ def test_complex_serialization(hf_tokenizer_uri, revision):
 
     fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
-    start = timer()
     serialized = pickle.dumps(fsm)
-    serialization_time = timer() - start
-
-    # Measure deserialization time
-    start = timer()
     deserialized = pickle.loads(serialized)
-    deserialization_time = timer() - start
 
     assert fsm.eos_tensor == deserialized.eos_tensor
     assert fsm.initial_state == deserialized.initial_state
-
-    # Print or log the timing results
-    print(f"Serialization time: {serialization_time:.6f} seconds")
-    print(f"Deserialization time: {deserialization_time:.6f} seconds")

From f0c2f3f7a3e991af15b9ec177116aba485e5eb14 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= <joennlae@gmail.com>
Date: Mon, 2 Dec 2024 19:59:32 +0100
Subject: [PATCH 5/5] chore(pickle): handle unwraps with error message

---
 src/python_bindings/mod.rs | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs
index 086fdb60..55d979d1 100644
--- a/src/python_bindings/mod.rs
+++ b/src/python_bindings/mod.rs
@@ -94,7 +94,10 @@ impl PyIndex {
         Python::with_gil(|py| {
             let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")?
                 .getattr("Index")?;
-            let binary_data: Vec<u8> = bincode::encode_to_vec(&self.0, config::standard()).unwrap();
+            let binary_data: Vec<u8> = bincode::encode_to_vec(&self.0, config::standard())
+                .map_err(|e| {
+                    PyErr::new::<PyValueError, _>(format!("Serialization of Index failed: {}", e))
+                })?;
             Ok((cls.getattr("from_binary")?.to_object(py), (binary_data,)))
         })
     }
@@ -102,7 +105,9 @@ impl PyIndex {
     #[staticmethod]
     fn from_binary(binary_data: Vec<u8>) -> PyResult<Self> {
         let (index, _): (Index, usize) =
-            bincode::decode_from_slice(&binary_data[..], config::standard()).unwrap();
+            bincode::decode_from_slice(&binary_data[..], config::standard()).map_err(|e| {
+                PyErr::new::<PyValueError, _>(format!("Deserialization of Index failed: {}", e))
+            })?;
         Ok(PyIndex(index))
     }