From e0a1e78beb6b77db2c68e83d85660f1b101a35a4 Mon Sep 17 00:00:00 2001 From: arvidn Date: Sun, 29 Oct 2023 19:53:16 +0100 Subject: [PATCH] make from_bytes() accept python's buffer API (i.e. memoryview) --- chia_py_streamable_macro/src/lib.rs | 10 ++++++++-- tests/test_streamable.py | 4 ++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/chia_py_streamable_macro/src/lib.rs b/chia_py_streamable_macro/src/lib.rs index 99d37f6fe..724ce1593 100644 --- a/chia_py_streamable_macro/src/lib.rs +++ b/chia_py_streamable_macro/src/lib.rs @@ -120,8 +120,14 @@ pub fn py_streamable_macro(input: proc_macro::TokenStream) -> proc_macro::TokenS impl #ident { #[staticmethod] #[pyo3(name = "from_bytes")] - pub fn py_from_bytes(blob: &[u8]) -> pyo3::PyResult { - let mut input = std::io::Cursor::<&[u8]>::new(blob); + pub fn py_from_bytes(blob: pyo3::buffer::PyBuffer) -> pyo3::PyResult { + if !blob.is_c_contiguous() { + panic!("from_bytes() must be called with a contiguous buffer"); + } + let slice = unsafe { + std::slice::from_raw_parts(blob.buf_ptr() as *const u8, blob.len_bytes()) + }; + let mut input = std::io::Cursor::<&[u8]>::new(slice); ::parse(&mut input).map_err(|e| <#crate_name::chia_error::Error as Into>::into(e)) } diff --git a/tests/test_streamable.py b/tests/test_streamable.py index 0ca46a468..af90f973d 100644 --- a/tests/test_streamable.py +++ b/tests/test_streamable.py @@ -395,6 +395,10 @@ def test_program() -> None: assert str(p) == "Program(ff8080)" assert p.to_bytes() == bytes.fromhex("ff8080") + # make sure we can pass in a slice/memoryview + p = Program.from_bytes(bytes.fromhex("00ff8080")[1:]) + assert str(p) == "Program(ff8080)" + # truncated serialization with pytest.raises(ValueError, match="unexpected end of buffer"): Program.from_bytes(bytes.fromhex("ff80"))