diff --git a/chia_py_streamable_macro/src/lib.rs b/chia_py_streamable_macro/src/lib.rs index 99d37f6fe..724ce1593 100644 --- a/chia_py_streamable_macro/src/lib.rs +++ b/chia_py_streamable_macro/src/lib.rs @@ -120,8 +120,14 @@ pub fn py_streamable_macro(input: proc_macro::TokenStream) -> proc_macro::TokenS impl #ident { #[staticmethod] #[pyo3(name = "from_bytes")] - pub fn py_from_bytes(blob: &[u8]) -> pyo3::PyResult { - let mut input = std::io::Cursor::<&[u8]>::new(blob); + pub fn py_from_bytes(blob: pyo3::buffer::PyBuffer) -> pyo3::PyResult { + if !blob.is_c_contiguous() { + panic!("from_bytes() must be called with a contiguous buffer"); + } + let slice = unsafe { + std::slice::from_raw_parts(blob.buf_ptr() as *const u8, blob.len_bytes()) + }; + let mut input = std::io::Cursor::<&[u8]>::new(slice); ::parse(&mut input).map_err(|e| <#crate_name::chia_error::Error as Into>::into(e)) } diff --git a/tests/test_streamable.py b/tests/test_streamable.py index 0ca46a468..af90f973d 100644 --- a/tests/test_streamable.py +++ b/tests/test_streamable.py @@ -395,6 +395,10 @@ def test_program() -> None: assert str(p) == "Program(ff8080)" assert p.to_bytes() == bytes.fromhex("ff8080") + # make sure we can pass in a slice/memoryview + p = Program.from_bytes(bytes.fromhex("00ff8080")[1:]) + assert str(p) == "Program(ff8080)" + # truncated serialization with pytest.raises(ValueError, match="unexpected end of buffer"): Program.from_bytes(bytes.fromhex("ff80"))