Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add datastruct test suite #6

Merged
merged 11 commits into from
Oct 12, 2024
12 changes: 12 additions & 0 deletions .github/workflows/push-dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
name: Push (dev), Pull Request
on:
push:
branches: ["**"]
pull_request:
jobs:
lint-python:
name: Run Python lint
uses: kuba2k2/kuba2k2/.github/workflows/lint-python.yml@master
test-python:
name: Run Python tests
uses: kuba2k2/kuba2k2/.github/workflows/test-python.yml@master
4 changes: 4 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ jobs:
lint-python:
name: Run Python lint
uses: kuba2k2/kuba2k2/.github/workflows/lint-python.yml@master
test-python:
name: Run Python tests
uses: kuba2k2/kuba2k2/.github/workflows/test-python.yml@master
publish-pypi:
name: Publish PyPI package
needs:
- lint-python
- test-python
uses: kuba2k2/kuba2k2/.github/workflows/publish-pypi.yml@master
secrets:
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
2 changes: 1 addition & 1 deletion datastruct/adapters/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def filetime_field(*, default=...):
"<Q", int((value.timestamp() + 11644473600) * 10000000)
),
decode=lambda value, ctx: datetime.fromtimestamp(
int(unpack("<Q", value)[0] / 10000000) - 11644473600
(unpack("<Q", value)[0] / 10000000) - 11644473600
),
)(field(8, default=default))

Expand Down
12 changes: 4 additions & 8 deletions datastruct/fields/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,14 @@ def decode(self, value: bytes, ctx: Context) -> bytes:

return adapter(ByteStr())(field(length, default=default))

def varbytes(
length: Value[int],
*,
default: bytes = ...
):

def varbytes(length: Value[int], *, default: bytes = ...):
return field(
lambda ctx: (
len(ctx.P.self) if ctx.G.packing else evaluate(ctx, length)
),
lambda ctx: (len(ctx.P.self) if ctx.G.packing else evaluate(ctx, length)),
default=default,
)


def text(
length: Value[int],
*,
Expand Down
600 changes: 452 additions & 148 deletions poetry.lock

Large diffs are not rendered by default.

14 changes: 11 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,19 @@ packages = [
]

[tool.poetry.dependencies]
python = "^3.7"
python = "^3.8"

[tool.poetry.group.dev.dependencies]
black = "^22.12.0"
isort = "^5.11.4"
black = "^24.1.0"
isort = "^5.12.0"
autoflake = "^2.1.1"

[tool.poetry.group.test.dependencies]
pytest = "^8.3.3"
macaddress = "^2.0.2"
pycryptodome = "^3.21.0"
requests = "^2.32.3"
bitstruct = "^8.19.0"

[build-system]
requires = ["poetry-core"]
Expand Down
142 changes: 142 additions & 0 deletions tests/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright (c) Kuba Szczodrzyński 2024-10-11.

import re
from dataclasses import dataclass
from pprint import pformat
from typing import Type

import pytest

from datastruct import DataStruct


@dataclass
class TestData:
__test__ = False

cls: Type[DataStruct] | None = None
data: bytes | None = None
obj_full: DataStruct | None = None
obj_simple: DataStruct | None = None
context: dict = None

full_after_packing: bool = True
unpack_then_pack: bool = True
pack_then_unpack: bool = True

def __post_init__(self) -> None:
if self.context is None:
self.context = {}


class TestBase:
test: TestData

@pytest.fixture(scope="function", autouse=True)
def setup_and_teardown(self, test: TestData) -> None:
self.test = test

def get_cls(self) -> Type[DataStruct]:
return self.test.cls or type(self.test.obj_full) or type(self.test.obj_simple)

def obj_to_str(self, obj: DataStruct) -> str:
pp = pformat(obj)
# fix enum representation
pp = re.sub(r"<([^.]+\.[^:]+?):.+?>", "\\1", pp)
pp = re.sub(r"([A-Za-z][A-Za-z0-9]+?)\.0", "\\1(0)", pp)
return pp

def bytes_to_hex_repr(self, data: bytes) -> str:
out = ""
for i in range(0, len(data), 16):
line = data[i : i + 16]
out += 'b"\\x' + line.hex(" ").replace(" ", "\\x") + '"\n'
return out

def bytes_to_hex_str(self, data: bytes) -> str:
out = ""
for i in range(0, len(data), 16):
line = data[i : i + 16]
out += line.hex(" ") + "\n"
return out

def test_unpack_from_bytes(self) -> None:
if self.test.data is None:
pytest.skip()
unpacked = self.get_cls().unpack(self.test.data, **self.test.context)
if self.test.obj_full is None:
print("Unpacked (from bytes):")
print(self.obj_to_str(unpacked))
return
if unpacked != self.test.obj_full:
print()
print(unpacked)
print(self.test.obj_full)
assert unpacked == self.test.obj_full

def test_pack_full_to_bytes(self) -> None:
if self.test.obj_full is None:
pytest.skip()
packed = self.test.obj_full.pack(**self.test.context)
if self.test.data is None:
print("Packed (full):")
print(self.bytes_to_hex_repr(packed))
return
if packed != self.test.data:
print()
print(packed.hex(" "))
print(self.test.data.hex(" "))
assert packed == self.test.data

def test_pack_simple_to_bytes(self) -> None:
if self.test.obj_simple is None:
pytest.skip()
packed = self.test.obj_simple.pack(**self.test.context)
if self.test.obj_full is None:
print("Unpacked (from simple):")
print(self.obj_to_str(self.test.obj_simple))
if self.test.data is None:
print("Packed (simple):")
print(self.bytes_to_hex_repr(packed))
return
if packed != self.test.data:
print()
print(packed.hex(" "))
print(self.test.data.hex(" "))
assert packed == self.test.data

def test_full_after_packing(self) -> None:
if (
not self.test.full_after_packing
or self.test.obj_full is None
or self.test.obj_simple is None
):
pytest.skip()
self.test.obj_simple.pack(**self.test.context)
if self.test.obj_full != self.test.obj_simple:
print()
print(self.test.obj_full)
print(self.test.obj_simple)
assert self.test.obj_full == self.test.obj_simple

def test_unpack_then_pack(self) -> None:
if not self.test.unpack_then_pack or self.test.data is None:
pytest.skip()
unpacked = self.get_cls().unpack(self.test.data, **self.test.context)
packed = unpacked.pack(**self.test.context)
if packed != self.test.data:
print()
print(packed.hex(" "))
print(self.test.data.hex(" "))
assert packed == self.test.data

def test_pack_then_unpack(self) -> None:
if not self.test.pack_then_unpack or self.test.obj_full is None:
pytest.skip()
packed = self.test.obj_full.pack(**self.test.context)
unpacked = self.get_cls().unpack(packed, **self.test.context)
if unpacked != self.test.obj_full:
print()
print(unpacked)
print(self.test.obj_full)
assert unpacked == self.test.obj_full
113 changes: 113 additions & 0 deletions tests/test_ambz2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) Kuba Szczodrzyński 2024-10-11.

import pytest
from base import TestBase, TestData
from test_ambz2_structs import *
from util import read_data_file

config = get_image_config()


def build_firmware():
# defaults from libretiny/boards/bw15
board_flash = {
"part_table": (0x000000, 0x1000, 0x000000 + 0x1000),
"system": (0x001000, 0x1000, 0x001000 + 0x1000),
"calibration": (0x002000, 0x1000, 0x002000 + 0x1000),
"boot": (0x004000, 0x8000, 0x004000 + 0x8000),
"ota1": (0x00C000, 0xF8000, 0x00C000 + 0xF8000),
"ota2": (0x104000, 0xF8000, 0x104000 + 0xF8000),
"kvs": (0x1FC000, 0x4000, 0x1FC000 + 0x400),
}

ptab_offset, _, ptab_end = board_flash["part_table"]
boot_offset, _, boot_end = board_flash["boot"]
ota1_offset, _, ota1_end = board_flash["ota1"]

# build the partition table
ptable = PartitionTable(user_data=b"\xFF" * 256)
for region, type in config.ptable.items():
offset, length, _ = board_flash[region]
hash_key = config.keys.hash_keys[region]
ptable.partitions.append(
PartitionRecord(offset, length, type, hash_key=hash_key),
)
ptable = Image(
keyblock=build_keyblock(config, "part_table"),
header=ImageHeader(
type=ImageType.PARTAB,
),
data=ptable,
)

# build boot image
region = "boot"
boot = Image(
keyblock=build_keyblock(config, region),
header=ImageHeader(
type=ImageType.BOOT,
user_keys=[config.keys.user_keys[region], FF_32],
),
data=build_section(config.boot),
)

# build firmware (sub)images
firmware = []
region = "ota1"
for idx, image in enumerate(config.fw):
obj = Image(
keyblock=build_keyblock(config, region),
header=ImageHeader(
type=image.type,
# use FF to allow recalculating by OTA code
serial=0xFFFFFFFF if idx == 0 else 0,
user_keys=(
[FF_32, config.keys.user_keys[region]]
if idx == 0
else [FF_32, FF_32]
),
),
data=Firmware(
sections=[build_section(section) for section in image.sections],
),
)
# remove empty sections
obj.data.sections = [s for s in obj.data.sections if s.data]
firmware.append(obj)
if image.type != ImageType.XIP:
continue
# update SCE keys for XIP images
for section in obj.data.sections:
section.header.sce_key = config.keys.xip_sce_key
section.header.sce_iv = config.keys.xip_sce_iv

# build main flash image
return Flash(
ptable=ptable,
boot=boot,
firmware=firmware,
)


TEST_DATA = [
pytest.param(
TestData(
cls=Flash,
data=read_data_file(TEST_DATA_URLS["image_flash_is.bin"]),
obj_full=None,
obj_simple=build_firmware(),
context=dict(
hash_key=config.keys.hash_keys["part_table"],
),
),
id="dummy",
),
]


@pytest.mark.parametrize("test", TEST_DATA)
class TestAmbZ2(TestBase):
pass


del TestBase
Loading
Loading