Skip to content

Commit

Permalink
electrical_protocol: Allow string and character packets to be constru…
Browse files Browse the repository at this point in the history
…cted
  • Loading branch information
cbrxyz committed Nov 3, 2024
1 parent d28e9d2 commit cd629ec
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,15 @@ def __init_subclass__(

def __post_init__(self):
for name, field_type in get_cache_hints(self.__class__).items():
if (
name
not in [
"class_id",
"subclass_id",
"payload_format",
]
and not isinstance(self.__dict__[name], field_type)
and issubclass(field_type, Enum)
):
setattr(self, name, field_type(self.__dict__[name]))
if name not in [
"class_id",
"subclass_id",
"payload_format",
] and not isinstance(self.__dict__[name], field_type):
if issubclass(field_type, Enum):
setattr(self, name, field_type(self.__dict__[name]))
elif issubclass(field_type, str):
setattr(self, name, self.__dict__[name].rstrip(b"\x00").decode())
if self.payload_format and not self.payload_format.startswith(
("<", ">", "=", "!"),
):
Expand Down Expand Up @@ -163,7 +161,21 @@ def _calculate_checksum(cls, data: bytes) -> tuple[int, int]:
return sum1, sum2

def __bytes__(self):
payload = struct.pack(self.payload_format, *self.__dict__.values())
ready_values = []
for value in self.__dict__.values():
if isinstance(value, Enum):
ready_values.append(
(
value.value
if not isinstance(value.value, str)
else value.value.encode()
),
)
elif isinstance(value, str):
ready_values.append(value.encode())
else:
ready_values.append(value)
payload = struct.pack(self.payload_format, *ready_values)
data = struct.pack(
f"<BBBBH{len(payload)}s",
SYNC_CHAR_1,
Expand Down
47 changes: 42 additions & 5 deletions mil_common/drivers/electrical_protocol/test/calculator_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum, IntEnum
from threading import Event
from typing import Union

Expand All @@ -12,6 +13,18 @@
from std_srvs.srv import Empty, EmptyRequest, EmptyResponse


class CalculatorMode(IntEnum):
ADD = 0
SUB = 1
MUL = 2


class Sign(Enum):
PLUS = "+"
MINUS = "-"
MULTIPLY = "*"


@dataclass
class RequestAddPacket(Packet, class_id=0x37, subclass_id=0x00, payload_format="<ff"):
number_one: float
Expand All @@ -29,13 +42,30 @@ class AnswerPacket(Packet, class_id=0x37, subclass_id=0x02, payload_format="<f")
result: float


@dataclass
class CharacterPacket(Packet, class_id=0x37, subclass_id=0x03, payload_format="<c10s"):
single_char: str
big_str: str


@dataclass
class EnumPacket(Packet, class_id=0x37, subclass_id=0x04, payload_format="<cb"):
symbol: Sign
number: CalculatorMode


class CalculatorDevice(
ROSSerialDevice[Union[RequestAddPacket, RequestSubPacket], AnswerPacket],
ROSSerialDevice[
Union[RequestAddPacket, RequestSubPacket, CharacterPacket],
Union[AnswerPacket, EnumPacket],
],
):
def __init__(self):
self.port_topic = rospy.Subscriber("~port", String, self.port_callback)
self.start_service = rospy.Service("~trigger", Empty, self.trigger)
self.answer_topic = rospy.Publisher("~answer", Float32, queue_size=10)
self.start_one_service = rospy.Service("~trigger_one", Empty, self.trigger)
self.start_two_service = rospy.Service("~trigger_two", Empty, self.trigger_two)
self.answer_one_topic = rospy.Publisher("~answer_one", Float32, queue_size=10)
self.answer_two_topic = rospy.Publisher("~answer_two", Float32, queue_size=10)
self.next_packet = Event()
self.i = 0
super().__init__(None, 115200)
Expand All @@ -51,9 +81,16 @@ def trigger(self, _: EmptyRequest):
)
return EmptyResponse()

def trigger_two(self, _: EmptyRequest):
self.send_packet(CharacterPacket("a", "small"))
return EmptyResponse()

def on_packet_received(self, packet) -> None:
self.answer_topic.publish(Float32(data=packet.result))
self.next_packet.set()
if isinstance(packet, AnswerPacket):
self.answer_one_topic.publish(Float32(data=packet.result))
self.next_packet.set()
elif isinstance(packet, EnumPacket):
self.answer_two_topic.publish(Float32(data=packet.number.value))


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pty
import unittest
from dataclasses import dataclass
from enum import Enum, IntEnum

import rospy
import rostest
Expand All @@ -12,6 +13,18 @@
from std_srvs.srv import Empty


class CalculatorMode(IntEnum):
ADD = 0
SUB = 1
MUL = 2


class Sign(Enum):
PLUS = "+"
MINUS = "-"
MULTIPLY = "*"


@dataclass
class RequestAddPacket(Packet, class_id=0x37, subclass_id=0x00, payload_format="<ff"):
number_one: float
Expand All @@ -29,6 +42,18 @@ class AnswerPacket(Packet, class_id=0x37, subclass_id=0x02, payload_format="<f")
result: float


@dataclass
class CharacterPacket(Packet, class_id=0x37, subclass_id=0x03, payload_format="<c10s"):
single_char: str
big_str: str


@dataclass
class EnumPacket(Packet, class_id=0x37, subclass_id=0x04, payload_format="<cb"):
symbol: Sign
number: CalculatorMode


class SimulatedBasicTest(unittest.TestCase):
def __init__(self, *args):
super().__init__(*args)
Expand All @@ -37,10 +62,15 @@ def __init__(self, *args):
String,
queue_size=1,
)
self.answer_subscriber = rospy.Subscriber(
"/calculator_device/answer",
self.answer_one_subscriber = rospy.Subscriber(
"/calculator_device/answer_one",
Float32,
self.answer_callback_one,
)
self.answer_two_subscriber = rospy.Subscriber(
"/calculator_device/answer_two",
Float32,
self.answer_callback,
self.answer_callback_two,
)
self.count = 0

Expand All @@ -52,7 +82,7 @@ def test_simulated(self):
rospy.sleep(0.1)
self.port_publisher.publish(String(serial_name))
self.trigger_service_caller = rospy.ServiceProxy(
"/calculator_device/trigger",
"/calculator_device/trigger_one",
Empty,
)
for i in range(1000):
Expand All @@ -74,12 +104,37 @@ def test_simulated(self):
bytes(AnswerPacket(packet.number_one + packet.number_two)),
)
rospy.sleep(2)
self.assertEqual(self.count, 1000)
self.assertGreaterEqual(self.count, 900)
self.trigger_two_service_caller = rospy.ServiceProxy(
"/calculator_device/trigger_two",
Empty,
)
self.trigger_two_service_caller()
packet_bytes = os.read(self.master, 100)
packet = CharacterPacket.from_bytes(packet_bytes)
self.assertEqual(
packet.single_char,
"a",
f"packet.single_char: {packet.single_char}",
)
self.assertEqual(
packet.big_str,
"small",
f"packet.big_str: {packet.big_str}",
)
os.write(
self.master,
bytes(EnumPacket(Sign.MULTIPLY, CalculatorMode.ADD)),
)

def answer_callback(self, msg: Float32):
self.assertEqual(msg.data, 1000)
def answer_callback_one(self, msg: Float32):
# at least 900 packets gotten (sometimes lower due to performance)
self.assertGreaterEqual(msg.data, 900)
self.count += 1

def answer_callback_two(self, msg: Float32):
self.assertEqual(msg.data, 5)

def tearDown(self):
os.close(self.master)
os.close(self.slave)
Expand Down

0 comments on commit cd629ec

Please sign in to comment.