diff --git a/Cargo.toml b/Cargo.toml index 5921105..a2d0bdb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,4 @@ pyo3 = { version = "0.21.2", features = [ "experimental-declarative-modules", ] } validated_struct = "2.1.0" -zenoh = { version = "1.0.0-dev", git = "https://github.com/eclipse-zenoh/zenoh.git", branch = "main", features = ["unstable", - "internal", -], default-features = false } +zenoh = { version = "1.0.0-dev", git = "https://github.com/eclipse-zenoh/zenoh.git", branch = "main", features = ["unstable", "internal"], default-features = false } diff --git a/examples/z_pub_thr.py b/examples/z_pub_thr.py index 8230c60..68b7f96 100644 --- a/examples/z_pub_thr.py +++ b/examples/z_pub_thr.py @@ -21,7 +21,7 @@ def main(conf: zenoh.Config, payload_size: int): data = bytearray() for i in range(0, payload_size): data.append(i % 10) - data = zenoh.ZBytes(bytes(data)) + data = zenoh.ZBytes(data) congestion_control = zenoh.CongestionControl.BLOCK with zenoh.open(conf) as session: diff --git a/src/bytes.rs b/src/bytes.rs index 34bb9d5..b2b534b 100644 --- a/src/bytes.rs +++ b/src/bytes.rs @@ -11,31 +11,88 @@ // Contributors: // ZettaScale Zenoh Team, // -use std::borrow::Cow; +use std::{borrow::Cow, io::Read}; use pyo3::{ exceptions::{PyTypeError, PyValueError}, prelude::*, sync::GILOnceCell, types::{ - PyBool, PyBytes, PyCFunction, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple, PyType, + PyBool, PyByteArray, PyBytes, PyCFunction, PyDict, PyFloat, PyFrozenSet, PyInt, PyList, + PySet, PyString, PyTuple, PyType, }, PyTypeInfo, }; -use zenoh::internal::buffers::{SplitBuffer, ZBuf}; use crate::{ macros::{downcast_or_new, import, try_import, wrapper}, - utils::{try_process, IntoPyResult, IntoPython, MapInto}, + utils::{IntoPyResult, MapInto}, }; +#[derive(Clone, Copy)] +#[repr(u8)] +enum SupportedType { + ZBytes, + Bytes, + ByteArray, + Str, + Int, + Float, + Bool, + List, + Tuple, + Dict, + Set, + FrozenSet, +} + +impl SupportedType { + fn init_dict(py: Python) -> Py { + fn add_type(py: Python, dict: &Bound, tp: SupportedType) { + dict.set_item(T::type_object_bound(py), tp as u8).unwrap() + } + let dict = PyDict::new_bound(py); + add_type::(py, &dict, SupportedType::ZBytes); + add_type::(py, &dict, SupportedType::Bytes); + add_type::(py, &dict, SupportedType::ByteArray); + add_type::(py, &dict, SupportedType::Str); + add_type::(py, &dict, SupportedType::Int); + add_type::(py, &dict, SupportedType::Float); + add_type::(py, &dict, SupportedType::Bool); + add_type::(py, &dict, SupportedType::List); + add_type::(py, &dict, SupportedType::Tuple); + add_type::(py, &dict, SupportedType::Dict); + add_type::(py, &dict, SupportedType::Set); + add_type::(py, &dict, SupportedType::FrozenSet); + dict.unbind() + } + + fn try_from_py(obj: &Bound) -> Option { + match u8::extract_bound(obj).ok()? { + n if n == Self::ZBytes as u8 => Some(Self::ZBytes), + n if n == Self::Bytes as u8 => Some(Self::Bytes), + n if n == Self::ByteArray as u8 => Some(Self::ByteArray), + n if n == Self::Str as u8 => Some(Self::Str), + n if n == Self::Int as u8 => Some(Self::Int), + n if n == Self::Float as u8 => Some(Self::Float), + n if n == Self::Bool as u8 => Some(Self::Bool), + n if n == Self::List as u8 => Some(Self::List), + n if n == Self::Tuple as u8 => Some(Self::Tuple), + n if n == Self::Dict as u8 => Some(Self::Dict), + n if n == Self::Set as u8 => Some(Self::Set), + n if n == Self::FrozenSet as u8 => Some(Self::FrozenSet), + _ => unreachable!(), + } + } +} + fn serializers(py: Python) -> &'static Py { static SERIALIZERS: GILOnceCell> = GILOnceCell::new(); - SERIALIZERS.get_or_init(py, || PyDict::new_bound(py).unbind()) + SERIALIZERS.get_or_init(py, || SupportedType::init_dict(py)) } fn deserializers(py: Python) -> &'static Py { static DESERIALIZERS: GILOnceCell> = GILOnceCell::new(); - DESERIALIZERS.get_or_init(py, || PyDict::new_bound(py).unbind()) + DESERIALIZERS.get_or_init(py, || SupportedType::init_dict(py)) } fn get_type<'py>(func: &Bound<'py, PyAny>, name: impl ToPyObject) -> PyResult> { @@ -93,61 +150,6 @@ pub(crate) fn serializer( } } -impl ZBytes { - fn serialize_impl(obj: &Bound) -> PyResult { - if let Ok(obj) = Self::extract_bound(obj) { - return Ok(obj); - } - let py = obj.py(); - Ok(Self(if let Ok(b) = obj.downcast::() { - zenoh::bytes::ZBytes::new(b.as_bytes().to_vec()) - } else if let Ok(s) = String::extract_bound(obj) { - zenoh::bytes::ZBytes::serialize(s) - } else if let Ok(i) = i128::extract_bound(obj) { - zenoh::bytes::ZBytes::serialize(i) - } else if let Ok(f) = f64::extract_bound(obj) { - zenoh::bytes::ZBytes::serialize(f) - } else if let Ok(b) = bool::extract_bound(obj) { - zenoh::bytes::ZBytes::serialize(b) - } else if let Ok(list) = obj.downcast::() { - try_process( - list.iter() - .map(|elt| PyResult::Ok(Self::serialize_impl(&elt)?.0)), - |iter| iter.collect(), - )? - } else if let Ok(dict) = obj.downcast::() { - try_process( - dict.iter().map(|(k, v)| { - PyResult::Ok((Self::serialize_impl(&k)?.0, Self::serialize_impl(&v)?.0)) - }), - |iter| iter.collect(), - )? - } else if let Ok(tuple) = obj.downcast::() { - if tuple.len() != 2 { - return Err(PyValueError::new_err( - "only two-elements tuple are supported", - )); - } - zenoh::bytes::ZBytes::serialize(( - Self::serialize_impl(&tuple.get_item(0)?)?, - Self::serialize_impl(&tuple.get_item(1)?)?, - )) - } else if let Ok(Some(ser)) = serializers(py).bind(py).get_item(obj.get_type()) { - return match ZBytes::extract_bound(&ser.call1((obj,))?) { - Ok(b) => Ok(b), - _ => Err(PyTypeError::new_err(format!( - "serializer {} didn't return ZBytes", - ser.repr()? - ))), - }; - } else { - return Err(PyValueError::new_err( - format!("no serializer registered for type {type}", type = obj.get_type().name()?), - )); - })) - } -} - #[pyfunction] #[pyo3(signature = (func = None, /, *, target = None))] pub(crate) fn deserializer( @@ -183,102 +185,196 @@ pub(crate) fn deserializer( wrapper!(zenoh::bytes::ZBytes: Clone, Default); downcast_or_new!(serialize_impl: ZBytes); -#[pymethods] impl ZBytes { - #[new] - fn new(bytes: Option<&Bound>) -> Self { - bytes.map_or_else(Self::default, |b| Self(b.as_bytes().into())) - } - - #[classmethod] - fn serialize(_cls: &Bound, obj: &Bound) -> PyResult { - Self::serialize_impl(obj) + fn serialize_impl(obj: &Bound) -> PyResult { + if let Ok(obj) = Self::extract_bound(obj) { + return Ok(obj); + } + let py = obj.py(); + let Ok(Some(serializer)) = serializers(py).bind(py).get_item(obj.get_type()) else { + return Err(PyValueError::new_err( + format!("no serializer registered for type {type}", type = obj.get_type().name()?), + )); + }; + let Some(tp) = SupportedType::try_from_py(&serializer) else { + return match ZBytes::extract_bound(&serializer.call1((obj,))?) { + Ok(b) => Ok(b), + _ => Err(PyTypeError::new_err(format!( + "serializer {} didn't return ZBytes", + serializer.repr()? + ))), + }; + }; + let serialize_item = |elt| PyResult::Ok(Self::serialize_impl(&elt)?.0); + let serialize_pair = + |(k, v)| PyResult::Ok((Self::serialize_impl(&k)?.0, Self::serialize_impl(&v)?.0)); + Ok(Self(match tp { + SupportedType::ZBytes => ZBytes::extract_bound(obj)?.0, + SupportedType::Bytes | SupportedType::ByteArray => { + >::extract_bound(obj)?.into() + } + SupportedType::Str => String::extract_bound(obj)?.into(), + SupportedType::Int => i64::extract_bound(obj)?.into(), + SupportedType::Float => f64::extract_bound(obj)?.into(), + SupportedType::Bool => bool::extract_bound(obj)?.into(), + SupportedType::List => obj + .downcast::()? + .into_iter() + .map(serialize_item) + .collect::>()?, + SupportedType::Tuple => obj + .downcast::()? + .into_iter() + .map(serialize_item) + .collect::>()?, + SupportedType::Dict => obj + .downcast::()? + .into_iter() + .map(serialize_pair) + .collect::>()?, + SupportedType::Set => obj + .downcast::()? + .into_iter() + .map(serialize_item) + .collect::>()?, + SupportedType::FrozenSet => obj + .downcast::()? + .into_iter() + .map(serialize_item) + .collect::>()?, + })) } - fn deserialize(this: PyRef, tp: &Bound) -> PyResult { + fn deserialize_impl(this: PyRef, tp: &Bound) -> PyResult { let py = tp.py(); - Ok(if tp.eq(PyBytes::type_object_bound(py))? { - this.__bytes__(py).into_any().unbind() - } else if tp.eq(PyString::type_object_bound(py))? { - this.0.deserialize::>().into_pyres()?.into_py(py) - } else if tp.eq(PyInt::type_object_bound(py))? { - this.0.deserialize::().into_pyres()?.into_py(py) - } else if tp.eq(PyFloat::type_object_bound(py))? { - this.0.deserialize::().into_pyres()?.into_py(py) - } else if tp.eq(PyBool::type_object_bound(py))? { - this.0.deserialize::().into_pyres()?.into_py(py) - } else if tp.eq(PyList::type_object_bound(py))? { - let list = PyList::empty_bound(py); - for elt in this.0.iter::() { - list.append(Self(elt.into_pyres()?).into_py(py))?; - } - list.into_py(py) - } else if tp.eq(PyDict::type_object_bound(py))? { - let dict = PyDict::new_bound(py); - for kv in this - .0 - .iter::<(zenoh::bytes::ZBytes, zenoh::bytes::ZBytes)>() + let Ok(Some(deserializer)) = deserializers(py).bind(py).get_item(tp) else { + if try_import!(py, types.GenericAlias) + .is_ok_and(|alias| tp.is_instance(alias).unwrap_or(false)) { - let (k, v) = kv.into_pyres()?; - dict.set_item(k.into_pyobject(py), v.into_pyobject(py))?; + return this.deserialize_generic(tp); } - dict.into_py(py) - } else if try_import!(py, types.GenericAlias) - .is_ok_and(|alias| tp.is_instance(alias).unwrap_or(false)) - { - let origin = import!(py, typing.get_origin).call1((tp,))?; - let args = import!(py, typing.get_args) - .call1((tp,))? - .downcast_into::()?; - let deserialize = - |bytes, tp| Self::deserialize(Py::new(py, Self(bytes)).unwrap().borrow(py), tp); - if origin.eq(PyList::type_object_bound(py))? { - let tp = args.get_item(0)?; - let list = PyList::empty_bound(py); - for elt in this.0.iter::() { - list.append(deserialize(elt.into_pyres()?, &tp)?)?; - } - list.into_py(py) - } else if origin.eq(PyTuple::type_object_bound(py))? - && args.len() == 2 - && args.get_item(1).is_ok_and(|item| !item.is(&py.Ellipsis())) - { - let tp_k = args.get_item(0)?; - let tp_v = args.get_item(1)?; - let (k, v): (zenoh::bytes::ZBytes, zenoh::bytes::ZBytes) = - this.0.deserialize().into_pyres()?; - PyTuple::new_bound(py, [deserialize(k, &tp_k)?, deserialize(v, &tp_v)?]).into_py(py) - } else if origin.eq(PyDict::type_object_bound(py))? { - let tp_k = args.get_item(0)?; - let tp_v = args.get_item(1)?; + return Err(PyValueError::new_err(format!( + "no deserializer registered for {tp:?}" + ))); + }; + let Some(tp) = SupportedType::try_from_py(&deserializer) else { + return Ok(deserializer.call1((this,))?.unbind()); + }; + let into_py = |zbytes| Self(zbytes).into_py(py); + let to_vec = || Vec::from_iter(this.0.iter().map(Result::unwrap).map(into_py)); + Ok(match tp { + SupportedType::ZBytes => this.into_py(py), + SupportedType::Bytes => this.__bytes__(py)?.into_py(py), + SupportedType::ByteArray => PyByteArray::new_bound_with(py, this.0.len(), |bytes| { + this.0.reader().read_exact(bytes).into_pyres() + })? + .into_py(py), + SupportedType::Str => this.0.deserialize::>().into_pyres()?.into_py(py), + SupportedType::Int => this.0.deserialize::().into_pyres()?.into_py(py), + SupportedType::Float => this.0.deserialize::().into_pyres()?.into_py(py), + SupportedType::Bool => this.0.deserialize::().into_pyres()?.into_py(py), + SupportedType::List => PyList::new_bound(py, to_vec()).into_py(py), + SupportedType::Tuple => PyTuple::new_bound(py, to_vec()).into_py(py), + SupportedType::Dict => { let dict = PyDict::new_bound(py); - for kv in this - .0 - .iter::<(zenoh::bytes::ZBytes, zenoh::bytes::ZBytes)>() - { + for kv in this.0.iter() { let (k, v) = kv.into_pyres()?; - dict.set_item(deserialize(k, &tp_k)?, deserialize(v, &tp_v)?)?; + dict.set_item(Self(k).into_py(py), Self(v).into_py(py))?; } dict.into_py(py) - } else { - return Err(PyValueError::new_err( - "only list[Any], dict[Any, Any] or tuple[Any, Any] are supported as generic type", - )); } - } else if tp.eq(Self::type_object_bound(py))? { - this.into_py(py) - } else if let Ok(Some(de)) = deserializers(py).bind(py).get_item(tp) { - de.call1((this,))?.unbind() - } else if let Ok(tp) = tp.downcast::() { - return Err(PyValueError::new_err( - format!("no deserializer registered for type {type}", type = tp.name()?), - )); + SupportedType::Set => PySet::new_bound(py, &to_vec())?.into_py(py), + SupportedType::FrozenSet => PyFrozenSet::new_bound(py, &to_vec())?.into_py(py), + }) + } + + fn deserialize_generic(&self, tp: &Bound) -> PyResult { + let py = tp.py(); + let origin = import!(py, typing.get_origin).call1((tp,))?; + let args = import!(py, typing.get_args) + .call1((tp,))? + .downcast_into::()?; + let deserialize = |tp| { + move |zbytes: Result<_, _>| { + Self::deserialize_impl(Py::new(py, Self(zbytes.unwrap())).unwrap().borrow(py), &tp) + } + }; + Ok(if origin.eq(PyList::type_object_bound(py))? { + let vec: Vec<_> = Result::from_iter(self.0.iter().map(deserialize(args.get_item(0)?)))?; + PyList::new_bound(py, vec).into_py(py) + } else if origin.eq(PyTuple::type_object_bound(py))? + && args.len() == 2 + && args.get_item(1).is_ok_and(|item| item.is(&py.Ellipsis())) + { + let vec: Vec<_> = Result::from_iter(self.0.iter().map(deserialize(args.get_item(0)?)))?; + PyTuple::new_bound(py, vec).into_py(py) + } else if origin.eq(PyTuple::type_object_bound(py))? { + let mut zbytes_iter = self.0.iter(); + let mut tp_iter = args.iter(); + let vec = zbytes_iter + .by_ref() + .zip(tp_iter.by_ref()) + .map(|(zbytes, tp)| deserialize(tp)(zbytes)) + .collect::, _>>()?; + let remaining = zbytes_iter.count(); + if remaining > 0 || tp_iter.next().is_some() { + return Err(PyTypeError::new_err(format!( + "tuple length doesn't match, found {}", + vec.len() + remaining + ))); + } + PyTuple::new_bound(py, vec).into_py(py) + } else if origin.eq(PyDict::type_object_bound(py))? { + let deserialize_key = deserialize(args.get_item(0)?); + let deserialize_value = deserialize(args.get_item(1)?); + let dict = PyDict::new_bound(py); + for kv in self.0.iter() { + let (k, v) = kv.into_pyres()?; + dict.set_item(deserialize_key(Ok(k))?, deserialize_value(Ok(v))?)?; + } + dict.into_py(py) + } else if origin.eq(PySet::type_object_bound(py))? { + let vec: Vec<_> = Result::from_iter(self.0.iter().map(deserialize(args.get_item(0)?)))?; + PySet::new_bound(py, &vec)?.into_py(py) + } else if origin.eq(PyFrozenSet::type_object_bound(py))? { + let vec: Vec<_> = Result::from_iter(self.0.iter().map(deserialize(args.get_item(0)?)))?; + PyFrozenSet::new_bound(py, &vec)?.into_py(py) } else { - return Err(PyTypeError::new_err( - format!("expected a type, found {type}", type = tp.get_type().name()?), + return Err(PyValueError::new_err( + "only `list`/`tuple`/`dict`/`set`/`frozenset` are supported as generic type", )); }) } +} + +#[pymethods] +impl ZBytes { + #[new] + fn new(obj: Option<&Bound>) -> PyResult { + let Some(obj) = obj else { + return Ok(Self::default()); + }; + if let Ok(bytes) = obj.downcast::() { + // SAFETY: bytes is immediately copied + Ok(Self(unsafe { bytes.as_bytes() }.into())) + } else if let Ok(bytes) = obj.downcast::() { + Ok(Self(bytes.as_bytes().into())) + } else { + Err(PyTypeError::new_err(format!( + "expected buffer type, found '{}'", + obj.get_type().name().unwrap() + ))) + } + } + + #[classmethod] + fn serialize(_cls: &Bound, obj: &Bound) -> PyResult { + Self::serialize_impl(obj) + } + + fn deserialize(this: PyRef, tp: &Bound) -> PyResult { + Self::deserialize_impl(this, tp) + } fn __len__(&self) -> usize { self.0.len() @@ -288,16 +384,10 @@ impl ZBytes { !self.0.is_empty() } - fn __bytes__<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> { - PyBytes::new_bound_with(py, self.0.len(), |mut bytes| { - for slice in ZBuf::from(&self.0).slices() { - let len = slice.len(); - bytes[..len].copy_from_slice(slice); - bytes = &mut bytes[len..]; - } - Ok(()) + fn __bytes__<'py>(&self, py: Python<'py>) -> PyResult> { + PyBytes::new_bound_with(py, self.0.len(), |bytes| { + self.0.reader().read_exact(bytes).into_pyres() }) - .unwrap() } fn __eq__(&self, other: &Bound) -> PyResult { @@ -305,7 +395,7 @@ impl ZBytes { } fn __hash__(&self, py: Python) -> PyResult { - self.__bytes__(py).hash() + self.__bytes__(py)?.hash() } fn __repr__(&self) -> String { diff --git a/src/utils.rs b/src/utils.rs index fa70c11..fd9c250 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -121,25 +121,6 @@ impl>, T, E> Iterator for TryProcessIter<'_, I, } } -pub(crate) fn try_process( - iter: I, - process: impl FnOnce(TryProcessIter<'_, I::IntoIter, E>) -> R, -) -> Result -where - I: IntoIterator>, -{ - let mut error = None; - let iter = TryProcessIter { - iter: iter.into_iter(), - error: &mut error, - }; - let res = process(iter); - if let Some(err) = error { - return Err(err); - } - Ok(res) -} - pub(crate) fn short_type_name() -> &'static str { let name = std::any::type_name::(); name.rsplit_once("::").map_or(name, |(_, name)| name) diff --git a/tests/test_override_serializer.py b/tests/test_override_serializer.py new file mode 100644 index 0000000..89d89c3 --- /dev/null +++ b/tests/test_override_serializer.py @@ -0,0 +1,30 @@ +# +# Copyright (c) 2024 ZettaScale Technology +# +# This program and the accompanying materials are made available under the +# terms of the Eclipse Public License 2.0 which is available at +# http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0 +# which is available at https://www.apache.org/licenses/LICENSE-2.0. +# +# SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +# +# Contributors: +# ZettaScale Zenoh Team, +# + +from zenoh import ZBytes, deserializer, serializer + + +def test_override_serializer(): + assert ZBytes.serialize(42) != ZBytes.serialize("42") + + @deserializer + def deserialize_int_from_str(zbytes: ZBytes) -> int: + return int(zbytes.deserialize(str)) + + @serializer + def serialize_int_as_str(foo: int) -> ZBytes: + return ZBytes.serialize(str(foo)) + + assert ZBytes.serialize(42).deserialize(int) == 42 + assert ZBytes.serialize(42) == ZBytes.serialize("42") diff --git a/tests/test_serializer.py b/tests/test_serializer.py index fd35ff3..9516562 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -19,21 +19,28 @@ from zenoh import ZBytes, deserializer, serializer default_serializer_tests = [ + (ZBytes, ZBytes.serialize(b"foo")), (bytes, b"foo"), + (bytearray, bytearray(b"foo")), (str, "foo"), (int, 42), (float, 0.5), (bool, True), - (ZBytes, ZBytes.serialize(b"foo")), (list, [ZBytes.serialize(0), ZBytes.serialize(1)]), + (tuple, (ZBytes.serialize(0), ZBytes.serialize(1))), (dict, {ZBytes.serialize("foo"): ZBytes.serialize("bar")}), + (set, {ZBytes.serialize(0), ZBytes.serialize(1)}), + (frozenset, frozenset([ZBytes.serialize(0), ZBytes.serialize(1)])), ] if sys.version_info >= (3, 9): default_serializer_tests = [ *default_serializer_tests, (list[int], [0, 1, 2]), - (dict[str, str], {"foo": "bar"}), (tuple[int, int], (0, 1)), + (tuple[int, ...], (0, 1, 2)), + (dict[str, str], {"foo": "bar"}), + (set[int], {0, 1, 2}), + (frozenset[int], frozenset([0, 1, 2])), (list[tuple[int, int]], [(0, 1), (2, 3)]), ] diff --git a/zenoh/__init__.pyi b/zenoh/__init__.pyi index 2fa1088..c8c8ee3 100644 --- a/zenoh/__init__.pyi +++ b/zenoh/__init__.pyi @@ -953,7 +953,7 @@ _IntoWhatAmIMatcher = WhatAmIMatcher | str class ZBytes: """ZBytes contains the serialized bytes of user data.""" - def __new__(cls, bytes: bytes = None) -> Self: ... + def __new__(cls, bytes: bytes | bytearray = None) -> Self: ... @classmethod def serialize(cls, obj: Any) -> Self: """Serialize object according to its type,