diff --git a/Cargo.toml b/Cargo.toml index a648b09bc..324b93d67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ libc = { version = "0.2.82", optional = true } matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm"] } +borsh = { version = "1.2", optional = true, default-features = false } serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] } rawpointer = { version = "0.2" } @@ -66,7 +67,7 @@ serde-1 = ["serde"] test = [] # This feature is used for docs -docs = ["approx", "approx-0_5", "serde", "rayon"] +docs = ["approx", "approx-0_5", "serde", "borsh", "rayon"] std = ["num-traits/std", "matrixmultiply/std"] rayon = ["rayon_", "std"] diff --git a/src/array_borsh.rs b/src/array_borsh.rs new file mode 100644 index 000000000..abe3550de --- /dev/null +++ b/src/array_borsh.rs @@ -0,0 +1,101 @@ +use crate::imp_prelude::*; +use crate::IntoDimension; +use alloc::vec::Vec; +use borsh::{BorshDeserialize, BorshSerialize}; +use core::ops::Deref; + +/// **Requires crate feature `"borsh"`** +impl BorshSerialize for Dim +where + I: BorshSerialize, +{ + fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { + ::serialize(&self.ix(), writer) + } +} + +/// **Requires crate feature `"borsh"`** +impl BorshDeserialize for Dim +where + I: BorshDeserialize, +{ + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + ::deserialize_reader(reader).map(Dim::new) + } +} + +/// **Requires crate feature `"borsh"`** +impl BorshSerialize for IxDyn { + fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { + let elts = self.ix().deref(); + // Output length of dimensions. + ::serialize(&elts.len(), writer)?; + // Followed by actual data. + for elt in elts { + ::serialize(elt, writer)?; + } + Ok(()) + } +} + +/// **Requires crate feature `"borsh"`** +impl BorshDeserialize for IxDyn { + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + // Deserialize the length. + let len = ::deserialize_reader(reader)?; + // Deserialize the given number of elements. We assume the source is + // trusted so we use a capacity hint... + let mut elts = Vec::with_capacity(len); + for _ix in 0..len { + elts.push(::deserialize_reader(reader)?); + } + Ok(elts.into_dimension()) + } +} + +/// **Requires crate feature `"borsh"`** +impl BorshSerialize for ArrayBase +where + A: BorshSerialize, + D: Dimension + BorshSerialize, + S: Data, +{ + fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { + // Dimensions + ::serialize(&self.raw_dim(), writer)?; + // Followed by length of data + let iter = self.iter(); + ::serialize(&iter.len(), writer)?; + // Followed by data itself. + for elt in iter { + ::serialize(elt, writer)?; + } + Ok(()) + } +} + +/// **Requires crate feature `"borsh"`** +impl BorshDeserialize for ArrayBase +where + A: BorshDeserialize, + D: BorshDeserialize + Dimension, + S: DataOwned, +{ + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + // Dimensions + let dim = ::deserialize_reader(reader)?; + // Followed by length of data + let len = ::deserialize_reader(reader)?; + // Followed by data itself. + let mut data = Vec::with_capacity(len); + for _ix in 0..len { + data.push(::deserialize_reader(reader)?); + } + ArrayBase::from_shape_vec(dim, data).map_err(|_shape_err| { + borsh::io::Error::new( + borsh::io::ErrorKind::InvalidData, + "data and dimensions must match in size", + ) + }) + } +} diff --git a/src/lib.rs b/src/lib.rs index 07e5ed680..b09f526b2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -164,6 +164,8 @@ mod aliases; #[macro_use] mod itertools; mod argument_traits; +#[cfg(feature = "borsh")] +mod array_borsh; #[cfg(feature = "serde")] mod array_serde; mod arrayformat; diff --git a/xtest-serialization/Cargo.toml b/xtest-serialization/Cargo.toml index 857e31fe6..8754d598e 100644 --- a/xtest-serialization/Cargo.toml +++ b/xtest-serialization/Cargo.toml @@ -8,7 +8,7 @@ publish = false test = false [dependencies] -ndarray = { path = "..", features = ["serde"] } +ndarray = { path = "..", features = ["serde", "borsh"] } [features] default = ["ron"] @@ -23,6 +23,10 @@ version = "1.0.40" [dev-dependencies.rmp-serde] version = "0.14.0" +[dev-dependencies.borsh] +version = "1.2" +default-features = false + [dependencies.ron] version = "0.5.1" optional = true diff --git a/xtest-serialization/tests/serialize.rs b/xtest-serialization/tests/serialize.rs index efb3bacd9..95d5917d2 100644 --- a/xtest-serialization/tests/serialize.rs +++ b/xtest-serialization/tests/serialize.rs @@ -9,6 +9,8 @@ extern crate rmp_serde; #[cfg(feature = "ron")] extern crate ron; +extern crate borsh; + use ndarray::{arr0, arr1, arr2, s, ArcArray, ArcArray2, ArrayD, IxDyn}; #[test] @@ -218,3 +220,83 @@ fn serial_many_dim_ron() { assert_eq!(a, a_de); } } + +#[test] +fn serial_ixdyn_borsh() { + { + let a = arr0::(2.72).into_dyn(); + let serial = borsh::to_vec(&a).unwrap(); + println!("Serde encode {:?} => {:?}", a, serial); + let res = borsh::from_slice::>(&serial); + println!("{:?}", res); + assert_eq!(a, res.unwrap()); + } + + { + let a = arr1::(&[2.72, 1., 2.]).into_dyn(); + let serial = borsh::to_vec(&a).unwrap(); + println!("Serde encode {:?} => {:?}", a, serial); + let res = borsh::from_slice::>(&serial); + println!("{:?}", res); + assert_eq!(a, res.unwrap()); + } + + { + let a = arr2(&[[3., 1., 2.2], [3.1, 4., 7.]]) + .into_shape(IxDyn(&[3, 1, 1, 1, 2, 1])) + .unwrap(); + let serial = borsh::to_vec(&a).unwrap(); + println!("Serde encode {:?} => {:?}", a, serial); + let res = borsh::from_slice::>(&serial); + println!("{:?}", res); + assert_eq!(a, res.unwrap()); + } +} + +#[test] +fn serial_many_dim_borsh() { + use borsh::from_slice as borsh_deserialize; + use borsh::to_vec as borsh_serialize; + + { + let a = arr0::(2.72); + + let a_s = borsh_serialize(&a).unwrap(); + + let a_de: ArcArray = borsh_deserialize(&a_s).unwrap(); + + assert_eq!(a, a_de); + } + + { + let a = arr1::(&[2.72, 1., 2.]); + + let a_s = borsh_serialize(&a).unwrap(); + + let a_de: ArcArray = borsh_deserialize(&a_s).unwrap(); + + assert_eq!(a, a_de); + } + + { + let a = arr2(&[[3., 1., 2.2], [3.1, 4., 7.]]); + + let a_s = borsh_serialize(&a).unwrap(); + + let a_de: ArcArray = borsh_deserialize(&a_s).unwrap(); + + assert_eq!(a, a_de); + } + + { + // Test a sliced array. + let mut a = ArcArray::linspace(0., 31., 32).reshape((2, 2, 2, 4)); + a.slice_collapse(s![..;-1, .., .., ..2]); + + let a_s = borsh_serialize(&a).unwrap(); + + let a_de: ArcArray = borsh_deserialize(&a_s).unwrap(); + + assert_eq!(a, a_de); + } +}