diff --git a/lib/logitech_receiver/receiver.py b/lib/logitech_receiver/receiver.py index 05a6057ef..983d20ea6 100644 --- a/lib/logitech_receiver/receiver.py +++ b/lib/logitech_receiver/receiver.py @@ -83,6 +83,53 @@ class Pairing: error: Optional[any] = None +def extract_serial(response: bytes) -> str: + """Extracts serial number from receiver response.""" + return response.hex().upper() + + +def extract_max_devices(response: bytes) -> int: + """Extracts maximum number of supported devices from response.""" + max_devices = response[6] + return int(max_devices) + + +def extract_remaining_pairings(response: bytes) -> int: + ps = ord(response[2:3]) + remaining_pairings = ps - 5 if ps >= 5 else -1 + return int(remaining_pairings) + + +def extract_codename(response: bytes) -> str: + codename = response[2 : 2 + ord(response[1:2])] + return codename.decode("ascii") + + +def extract_power_switch_location(response: bytes) -> str: + """Extracts power switch location from response.""" + index = response[9] & 0x0F + return hidpp10_constants.PowerSwitchLocation(index).name.lower() + + +def extract_connection_count(response: bytes) -> int: + """Extract connection count from receiver response.""" + return ord(response[1:2]) + + +def extract_wpid(response: bytes) -> str: + """Extract wpid from receiver response.""" + return response.hex().upper() + + +def extract_polling_rate(response: bytes) -> int: + """Returns polling rate in milliseconds.""" + return int(response[2]) + + +def extract_device_kind(response: int) -> str: + return hidpp10_constants.DEVICE_KIND[response] + + class Receiver: """A generic Receiver instance, mostly implementing the interface used on Unifying, Nano, and LightSpeed receivers" The paired devices are available through the sequence interface. @@ -129,9 +176,9 @@ def initialize(self, product_info: dict): # read the receiver information subregister, so we can find out max_devices serial_reply = self.read_register(Registers.RECEIVER_INFO, InfoSubRegisters.RECEIVER_INFORMATION) if serial_reply: - self.serial = serial_reply[1:5].hex().upper() - self.max_devices = serial_reply[6] - if self.max_devices <= 0 or self.max_devices > 6: + self.serial = extract_serial(serial_reply[1:5]) + self.max_devices = extract_max_devices(serial_reply) + if not (1 <= self.max_devices <= 6): self.max_devices = product_info.get("max_devices", 1) else: # handle receivers that don't have a serial number specially (i.e., c534) self.serial = None @@ -164,8 +211,7 @@ def remaining_pairings(self, cache=True): if self._remaining_pairings is None or not cache: ps = self.read_register(Registers.RECEIVER_CONNECTION) if ps is not None: - ps = ord(ps[2:3]) - self._remaining_pairings = ps - 5 if ps >= 5 else -1 + self._remaining_pairings = extract_remaining_pairings(ps) return self._remaining_pairings def enable_connection_notifications(self, enable=True): @@ -195,8 +241,7 @@ def enable_connection_notifications(self, enable=True): def device_codename(self, n): codename = self.read_register(Registers.RECEIVER_INFO, InfoSubRegisters.DEVICE_NAME + n - 1) if codename: - codename = codename[2 : 2 + ord(codename[1:2])] - return codename.decode("ascii") + return extract_codename(codename) def notify_devices(self): """Scan all devices.""" @@ -209,8 +254,8 @@ def notification_information(self, number, notification: HIDPPNotification) -> t assert notification.address != 0x02 online = not bool(notification.data[0] & 0x40) encrypted = bool(notification.data[0] & 0x20) or notification.address == 0x10 - kind = hidpp10_constants.DEVICE_KIND[notification.data[0] & 0x0F] - wpid = (notification.data[2:3] + notification.data[1:2]).hex().upper() + kind = extract_device_kind(notification.data[0] & 0x0F) + wpid = extract_wpid(notification.data[2:3] + notification.data[1:2]) return online, encrypted, wpid, kind def device_pairing_information(self, n: int) -> dict: @@ -220,28 +265,29 @@ def device_pairing_information(self, n: int) -> dict: power_switch = "(unknown)" pair_info = self.read_register(Registers.RECEIVER_INFO, InfoSubRegisters.PAIRING_INFORMATION + n - 1) if pair_info: # a receiver that uses Unifying-style pairing registers - wpid = pair_info[3:5].hex().upper() - kind = hidpp10_constants.DEVICE_KIND[pair_info[7] & 0x0F] - polling_rate = str(pair_info[2]) + "ms" + wpid = extract_wpid(pair_info[3:5]) + kind = extract_device_kind(pair_info[7] & 0x0F) + polling_rate_ms = extract_polling_rate(pair_info) + polling_rate = f"{polling_rate_ms}ms" elif not self.receiver_kind == "unifying": # may be an old Nano receiver device_info = self.read_register(Registers.RECEIVER_INFO, 0x04) # undocumented if device_info: logger.warning("using undocumented register for device wpid") - wpid = device_info[3:5].hex().upper() - kind = hidpp10_constants.DEVICE_KIND[0x00] # unknown kind + wpid = extract_wpid(device_info[3:5]) + kind = extract_device_kind(0x00) # unknown kind else: raise exceptions.NoSuchDevice(number=n, receiver=self, error="read pairing information - non-unifying") else: raise exceptions.NoSuchDevice(number=n, receiver=self, error="read pairing information") pair_info = self.read_register(Registers.RECEIVER_INFO, InfoSubRegisters.EXTENDED_PAIRING_INFORMATION + n - 1) if pair_info: - power_switch = hidpp10_constants.PowerSwitchLocation(pair_info[9] & 0x0F) - serial = pair_info[1:5].hex().upper() + power_switch = extract_power_switch_location(pair_info) + serial = extract_serial(pair_info[1:5]) else: # some Nano receivers? pair_info = self.read_register(0x2D5) # undocumented and questionable if pair_info: logger.warning("using undocumented register for device serial number") - serial = pair_info[1:5].hex().upper() + serial = extract_serial(pair_info[1:5]) return {"wpid": wpid, "kind": kind, "polling": polling_rate, "serial": serial, "power_switch": power_switch} def register_new_device(self, number, notification=None): @@ -287,7 +333,9 @@ def set_lock(self, lock_closed=True, device=0, timeout=0): def count(self): count = self.read_register(Registers.RECEIVER_CONNECTION) - return 0 if count is None else ord(count[1:2]) + if count is None: + return 0 + return extract_connection_count(count) def request(self, request_id, *params): if bool(self): @@ -412,7 +460,7 @@ def __init__(self, *args, **kwargs): def initialize(self, product_info: dict): serial_reply = self.read_register(Registers.BOLT_UNIQUE_ID) - self.serial = serial_reply.hex().upper() + self.serial = extract_serial(serial_reply) self.max_devices = product_info.get("max_devices", 1) def device_codename(self, n): @@ -424,9 +472,9 @@ def device_codename(self, n): def device_pairing_information(self, n: int) -> dict: pair_info = self.read_register(Registers.RECEIVER_INFO, InfoSubRegisters.BOLT_PAIRING_INFORMATION + n) if pair_info: - wpid = (pair_info[3:4] + pair_info[2:3]).hex().upper() - kind = hidpp10_constants.DEVICE_KIND[pair_info[1] & 0x0F] - serial = pair_info[4:8].hex().upper() + wpid = extract_wpid(pair_info[3:4] + pair_info[2:3]) + kind = extract_device_kind(pair_info[1] & 0x0F) + serial = extract_serial(pair_info[4:8]) return {"wpid": wpid, "kind": kind, "polling": None, "serial": serial, "power_switch": "(unknown)"} else: raise exceptions.NoSuchDevice(number=n, receiver=self, error="can't read Bolt pairing register") @@ -484,8 +532,8 @@ def notification_information(self, number, notification): assert notification.address == 0x02 online = True encrypted = bool(notification.data[0] & 0x80) - kind = hidpp10_constants.DEVICE_KIND[_get_kind_from_index(self, number)] - wpid = "00" + notification.data[2:3].hex().upper() + kind = extract_device_kind(_get_kind_from_index(self, number)) + wpid = extract_wpid("00" + notification.data[2:3]) return online, encrypted, wpid, kind def device_pairing_information(self, number: int) -> dict: @@ -494,11 +542,11 @@ def device_pairing_information(self, number: int) -> dict: if not wpid: logger.error("Unable to get wpid from udev for device %d of %s", number, self) raise exceptions.NoSuchDevice(number=number, receiver=self, error="Not present 27Mhz device") - kind = hidpp10_constants.DEVICE_KIND[_get_kind_from_index(self, number)] + kind = extract_device_kind(_get_kind_from_index(self, number)) return {"wpid": wpid, "kind": kind, "polling": "", "serial": None, "power_switch": "(unknown)"} -def _get_kind_from_index(receiver, index): +def _get_kind_from_index(receiver, index: int) -> int: """Get device kind from 27Mhz device index""" # From drivers/hid/hid-logitech-dj.c if index == 1: # mouse diff --git a/tests/logitech_receiver/test_receiver.py b/tests/logitech_receiver/test_receiver.py index 37bff645a..10e915582 100644 --- a/tests/logitech_receiver/test_receiver.py +++ b/tests/logitech_receiver/test_receiver.py @@ -220,3 +220,85 @@ def test_notification_information_nano_receiver(nano_recv, address, data, expect assert encrypted == expected_encrypted assert wpid == "0302" assert kind == "keyboard" + + +def test_extract_serial_number(): + response = b'\x03\x16\xcc\x9c\xb4\x05\x06"\x00\x00\x00\x00\x00\x00\x00\x00' + + serial_number = receiver.extract_serial(response[1:5]) + + assert serial_number == "16CC9CB4" + + +def test_extract_max_devices(): + response = b'\x03\x16\xcc\x9c\xb4\x05\x06"\x00\x00\x00\x00\x00\x00\x00\x00' + + max_devices = receiver.extract_max_devices(response) + + assert max_devices == 6 + + +@pytest.mark.parametrize( + "response, expected_remaining_pairings", + [ + (b"\x00\x03\x00", -1), + (b"\x00\x02\t", 4), + ], +) +def test_extract_remaining_pairings(response, expected_remaining_pairings): + remaining_pairings = receiver.extract_remaining_pairings(response) + + assert remaining_pairings == expected_remaining_pairings + + +def test_extract_codename(): + response = b"A\x04K520" + + codename = receiver.extract_codename(response) + + assert codename == "K520" + + +def test_extract_power_switch_location(): + response = b"0\x19\x8e>\xb8\x06\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00" + + ps_location = receiver.extract_power_switch_location(response) + + assert ps_location == "base" + + +def test_extract_connection_count(): + response = b"\x00\x03\x00" + + connection_count = receiver.extract_connection_count(response) + + assert connection_count == 3 + + +def test_extract_wpid(): + response = b"@\x82" + + res = receiver.extract_wpid(response) + + assert res == "4082" + + +def test_extract_polling_rate(): + response = b"\x08@\x82\x04\x02\x02\x07\x00\x00\x00\x00\x00\x00\x00" + + polling_rate = receiver.extract_polling_rate(response) + + assert polling_rate == 130 + + +@pytest.mark.parametrize( + "data, expected_device_kind", + [ + (0x00, "unknown"), + (0x03, "numpad"), + ], +) +def test_extract_device_kind(data, expected_device_kind): + device_kind = receiver.extract_device_kind(data) + + assert str(device_kind) == expected_device_kind