Skip to content

Commit

Permalink
Add tests for capnp to AnnotatedValue
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasah committed Aug 31, 2023
1 parent 4d97789 commit c04d19b
Show file tree
Hide file tree
Showing 4 changed files with 492 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/labone/core/shf_vector_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def from_binary(
)
msg = str(
f"Unsupported extra header version: {version} for "
"ShfScopeVectorExtraHeader",
"ShfDemodulatorVectorExtraHeader",
)
raise NotImplementedError(msg)

Expand Down Expand Up @@ -249,14 +249,14 @@ def from_binary(
Raises:
NotImplementedError: If the version is not supported.
"""
# To be correct, these values should be read from
# /dev.../system/properties/timebase and
# /dev.../system/properties/maxdemodrate
# Here we have read them once and hardcoded for simplicity
timebase = 2.5e-10
max_demod_rate = 5e7
if version.minor >= 2: # noqa: PLR2004
timestamp_diff = struct.unpack("I", binary[8:12])[0]
# To be correct, these values should be read from
# /dev.../system/properties/timebase and
# /dev.../system/properties/maxdemodrate
# Here we have read them once and hardcoded for simplicity
timebase = 2.5e-10
max_demod_rate = 5e7
timestamp_diff *= 1 / (timebase * max_demod_rate)
return ShfDemodulatorVectorExtraHeader(
timestamp=struct.unpack("q", binary[0:8])[0],
Expand All @@ -273,6 +273,8 @@ def from_binary(
signal_source=struct.unpack("H", binary[50:52])[0],
)
if version.minor >= 1:
timestamp_diff = struct.unpack("I", binary[8:12])[0]
timestamp_diff *= 1 / (timebase * max_demod_rate)
return ShfDemodulatorVectorExtraHeader(
timestamp=struct.unpack("q", binary[0:8])[0],
timestamp_diff=timestamp_diff,
Expand Down Expand Up @@ -312,6 +314,9 @@ def _parse_extra_header_version(extra_header_info: int) -> _HeaderVersion:
Returns:
The header version.
"""
if extra_header_info == 0:
msg = "Vector data does not contain extra header."
raise ValueError(msg)
version = extra_header_info >> 16
return _HeaderVersion(major=(version & 0xE0) >> 5, minor=version & 0x1F)

Expand Down Expand Up @@ -357,7 +362,7 @@ def _deserialize_shf_result_logger_vector(
extra_header_info: int,
header_length: int,
element_type: int,
) -> tuple[np.ndarray, ExtraHeader]:
) -> tuple[np.ndarray, ShfResultLoggerVectorExtraHeader]:
"""Deserialize the vector data for result logger vector.
Args:
Expand Down Expand Up @@ -390,7 +395,7 @@ def _deserialize_shf_scope_vector(
raw_data: bytes,
extra_header_info: int,
header_length: int,
) -> tuple[np.ndarray, ExtraHeader]:
) -> tuple[np.ndarray, ShfScopeVectorExtraHeader]:
"""Deserialize the vector data for waveform vectors.
Args:
Expand Down Expand Up @@ -423,7 +428,7 @@ def _deserialize_shf_demodulator_vector(
raw_data: bytes,
extra_header_info: int,
header_length: int,
) -> tuple[SHFDemodSample, ExtraHeader]:
) -> tuple[SHFDemodSample, ShfDemodulatorVectorExtraHeader]:
"""Deserialize the vector data for waveform vectors.
Args:
Expand Down
3 changes: 3 additions & 0 deletions src/labone/core/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def from_capnp(raw: session_protocol_capnp.CntSample) -> CntSample:
SHFDemodSample,
TriggerSample,
CntSample,
None,
]


Expand Down Expand Up @@ -257,6 +258,8 @@ def _capnp_value_to_python_value(
return CntSample.from_capnp(capnp_value.cntSample), None
if capnp_type == "triggerSample":
return TriggerSample.from_capnp(capnp_value.triggerSample), None
if capnp_type == "none":
return None, None
msg = f"Unknown capnp type: {capnp_type}"
raise ValueError(msg)

Expand Down
230 changes: 230 additions & 0 deletions tests/core/test_annotated_value_from_capnp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
"""Tests the conversion from capnp to AnnotatedValue"""
from unittest.mock import patch

import labone.core.value as value_module
import numpy as np
import pytest
from labone.core.resources import session_protocol_capnp


class IllegalAnnotatedValue:
class IllegalValue:
def which(self):
return "illegal"

@property
def value(self):
return IllegalAnnotatedValue.IllegalValue()


def test_illegal_type():
with pytest.raises(ValueError):
value_module.AnnotatedValue.from_capnp(IllegalAnnotatedValue())


def test_void():
input_dict = {
"metadata": {"timestamp": 42, "path": "/non/of/your/business"},
"value": {"none": {}},
}
msg = session_protocol_capnp.AnnotatedValue.new_message()
msg.from_dict(input_dict)
parsed_value = value_module.AnnotatedValue.from_capnp(msg)
assert parsed_value.timestamp == input_dict["metadata"]["timestamp"]
assert parsed_value.path == input_dict["metadata"]["path"]
assert parsed_value.extra_header is None
assert parsed_value.value is None


def test_trigger_sample():
input_dict = {
"metadata": {"timestamp": 42, "path": "/non/of/your/business"},
"value": {
"triggerSample": {
"timestamp": 1,
"sampleTick": 2,
"trigger": 3,
"missedTriggers": 4,
"awgTrigger": 5,
"dio": 6,
"sequenceIndex": 7,
},
},
}
msg = session_protocol_capnp.AnnotatedValue.new_message()
msg.from_dict(input_dict)
parsed_value = value_module.AnnotatedValue.from_capnp(msg)
assert parsed_value.timestamp == input_dict["metadata"]["timestamp"]
assert parsed_value.path == input_dict["metadata"]["path"]
assert parsed_value.extra_header is None
assert (
parsed_value.value.timestamp
== input_dict["value"]["triggerSample"]["timestamp"]
)
assert (
parsed_value.value.sample_tick
== input_dict["value"]["triggerSample"]["sampleTick"]
)
assert parsed_value.value.trigger == input_dict["value"]["triggerSample"]["trigger"]
assert (
parsed_value.value.missed_triggers
== input_dict["value"]["triggerSample"]["missedTriggers"]
)
assert (
parsed_value.value.awg_trigger
== input_dict["value"]["triggerSample"]["awgTrigger"]
)
assert parsed_value.value.dio == input_dict["value"]["triggerSample"]["dio"]
assert (
parsed_value.value.sequence_index
== input_dict["value"]["triggerSample"]["sequenceIndex"]
)


def test_cnt_sample():
input_dict = {
"metadata": {"timestamp": 42, "path": "/non/of/your/business"},
"value": {
"cntSample": {
"timestamp": 1,
"counter": 2,
"trigger": 3,
},
},
}
msg = session_protocol_capnp.AnnotatedValue.new_message()
msg.from_dict(input_dict)
parsed_value = value_module.AnnotatedValue.from_capnp(msg)
assert parsed_value.timestamp == input_dict["metadata"]["timestamp"]
assert parsed_value.path == input_dict["metadata"]["path"]
assert parsed_value.extra_header is None
assert parsed_value.value.timestamp == input_dict["value"]["cntSample"]["timestamp"]
assert parsed_value.value.counter == input_dict["value"]["cntSample"]["counter"]
assert parsed_value.value.trigger == input_dict["value"]["cntSample"]["trigger"]


@pytest.mark.parametrize(
("type_name", "input_val", "output_val"),
[
("int64", 42, 42),
("double", 42.0, 42.0),
("complex", {"real": 42, "imag": 66}, 42 + 66j),
("string", "42", "42"),
],
)
def test_generic_types(type_name, input_val, output_val):
input_dict = {
"metadata": {"timestamp": 42, "path": "/non/of/your/business"},
"value": {type_name: input_val},
}
msg = session_protocol_capnp.AnnotatedValue.new_message()
msg.from_dict(input_dict)
parsed_value = value_module.AnnotatedValue.from_capnp(msg)
assert parsed_value.timestamp == input_dict["metadata"]["timestamp"]
assert parsed_value.path == input_dict["metadata"]["path"]
assert parsed_value.extra_header is None
assert parsed_value.value == output_val


def test_string_vector():
input_dict = {
"metadata": {"timestamp": 42, "path": "/non/of/your/business"},
"value": {
"vectorData": {
"valueType": 7,
"vectorElementType": 6,
"extraHeaderInfo": 0,
"data": b"Hello World",
},
},
}
msg = session_protocol_capnp.AnnotatedValue.new_message()
msg.from_dict(input_dict)
parsed_value = value_module.AnnotatedValue.from_capnp(msg)
assert parsed_value.timestamp == input_dict["metadata"]["timestamp"]
assert parsed_value.path == input_dict["metadata"]["path"]
assert parsed_value.extra_header is None
assert parsed_value.value == "Hello World"


def test_generic_vector():
input_array = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.uint32)
input_dict = {
"metadata": {"timestamp": 42, "path": "/non/of/your/business"},
"value": {
"vectorData": {
"valueType": 67,
"vectorElementType": 2,
"extraHeaderInfo": 0,
"data": input_array.tobytes(),
},
},
}
msg = session_protocol_capnp.AnnotatedValue.new_message()
msg.from_dict(input_dict)
parsed_value = value_module.AnnotatedValue.from_capnp(msg)
assert parsed_value.timestamp == input_dict["metadata"]["timestamp"]
assert parsed_value.path == input_dict["metadata"]["path"]
assert parsed_value.extra_header is None
assert np.array_equal(parsed_value.value, input_array)


def test_shf_vector():
input_dict = {
"metadata": {"timestamp": 42, "path": "/non/of/your/business"},
"value": {
"vectorData": {
"valueType": 69,
"vectorElementType": 2,
"extraHeaderInfo": 0,
"data": b"",
},
},
}
msg = session_protocol_capnp.AnnotatedValue.new_message()
msg.from_dict(input_dict)
with patch.object(
value_module,
"parse_shf_vector_data_struct",
autospec=True,
) as mock_method:
mock_method.return_value = "array", "extra_header"
parsed_value = value_module.AnnotatedValue.from_capnp(msg)
mock_method.assert_called_once()
assert mock_method.call_args[0][0].to_dict() == input_dict["value"]["vectorData"]
assert parsed_value.timestamp == input_dict["metadata"]["timestamp"]
assert parsed_value.path == input_dict["metadata"]["path"]
assert parsed_value.extra_header == "extra_header"
assert parsed_value.value == "array"


@pytest.mark.parametrize("vector_length", range(0, 200, 32))
@pytest.mark.parametrize("header_length", range(0, 200, 32))
def test_unknown_shf_vector(vector_length, header_length):
input_array = np.linspace(0, 1, vector_length, dtype=np.uint32)
input_dict = {
"metadata": {"timestamp": 42, "path": "/non/of/your/business"},
"value": {
"vectorData": {
"valueType": 69,
"vectorElementType": 2,
"extraHeaderInfo": header_length,
"data": input_array.tobytes(),
},
},
}
msg = session_protocol_capnp.AnnotatedValue.new_message()
msg.from_dict(input_dict)
with patch.object(
value_module,
"parse_shf_vector_data_struct",
autospec=True,
) as mock_method:
mock_method.side_effect = ValueError("Unknown SHF vector type")
parsed_value = value_module.AnnotatedValue.from_capnp(msg)
mock_method.assert_called_once()
assert mock_method.call_args[0][0].to_dict() == input_dict["value"]["vectorData"]
assert parsed_value.timestamp == input_dict["metadata"]["timestamp"]
assert parsed_value.path == input_dict["metadata"]["path"]
assert parsed_value.extra_header is None
assert np.array_equal(parsed_value.value, input_array[header_length:])
Loading

0 comments on commit c04d19b

Please sign in to comment.