diff --git a/tests/merkle_set.py b/tests/merkle_set.py index e71ffad9e..75fa456b2 100644 --- a/tests/merkle_set.py +++ b/tests/merkle_set.py @@ -45,7 +45,7 @@ MIDDLE = bytes([2]) TRUNCATED = bytes([3]) -BLANK = bytes32([0] * 32) +BLANK = bytes32.zeros prehashed: Dict[bytes, _Hash] = {} diff --git a/tests/test_merkle_set.py b/tests/test_merkle_set.py index 08034a285..595abd987 100644 --- a/tests/test_merkle_set.py +++ b/tests/test_merkle_set.py @@ -66,7 +66,7 @@ def check_tree(leafs: List[bytes32]) -> None: ) for i in range(256): - item = bytes32([i] + [2] * 31) + item = bytes32.fill(i.to_bytes(), fill=b"\x02", align="<") py_included, py_proof = py_tree.is_included_already_hashed(item) assert not py_included ru_included, ru_proof = ru_tree.is_included_already_hashed(item) diff --git a/tests/test_sized_bytes.py b/tests/test_sized_bytes.py new file mode 100644 index 000000000..26c2484ea --- /dev/null +++ b/tests/test_sized_bytes.py @@ -0,0 +1,37 @@ +import pytest + +from chia_rs.sized_bytes import bytes8 + + +def test_fill_empty() -> None: + assert bytes8.fill(b"", b"\x01") == bytes8([1, 1, 1, 1, 1, 1, 1, 1]) + + +def test_fill_non_empty_with_single() -> None: + assert bytes8.fill(b"\x02", b"\x01") == bytes8([1, 1, 1, 1, 1, 1, 1, 2]) + + +def test_fill_non_empty_with_double() -> None: + assert bytes8.fill(b"\x02\x02", b"\x01\x01") == bytes8([1, 1, 1, 1, 1, 1, 2, 2]) + + +def test_fill_needed_with_0_length_fill_raises() -> None: + with pytest.raises(ValueError): + bytes8.fill(b"\x00", fill=b"") + + +def test_fill_not_needed_with_0_length_fill_works() -> None: + blob = b"\x00" * 8 + assert bytes8.fill(blob, fill=b"") == bytes8(blob) + + +def test_fill_not_multiple_raises() -> None: + with pytest.raises(ValueError): + bytes8.fill(b"\x00", fill=b"\x01\x01") + +def test_align_left() -> None: + assert bytes8.fill(b"\x01", fill=b"\x02", align="<") == bytes8([1, 2, 2, 2, 2, 2, 2, 2]) + +def test_invalid_alignment() -> None: + with pytest.raises(ValueError): + bytes8.fill(b"", fill=b"\x00", align="|") diff --git a/wheel/python/chia_rs/sized_byte_class.py b/wheel/python/chia_rs/sized_byte_class.py index 6dcbeb11e..8467b88ee 100644 --- a/wheel/python/chia_rs/sized_byte_class.py +++ b/wheel/python/chia_rs/sized_byte_class.py @@ -78,6 +78,27 @@ def random( def secret(cls: Type[_T_SizedBytes]) -> _T_SizedBytes: return cls.random(r=system_random) + @classmethod + def fill(cls: Type[_T_SizedBytes], blob: bytes, fill: bytes, align: Literal["<", ">"] = ">") -> _T_SizedBytes: + if len(blob) == cls._size: + return cls(blob) + + fill_length = len(fill) + if fill_length == 0: + raise ValueError("fill required but length is zero") + + div, mod = divmod(cls._size - len(blob), fill_length) + if mod != 0: + raise ValueError("invalid fill value, range to be filled must be multiple of fil size") + + all_fill = fill * div + if align == "<": + return cls(blob + all_fill) + elif align == ">": + return cls(all_fill + blob) + + raise ValueError(f"invalid alignment: {align!r}") + def __str__(self) -> str: return self.hex()