Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zigpy serial protocol #160

Merged
merged 8 commits into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ license = {text = "GPL-3.0"}
requires-python = ">=3.8"
dependencies = [
"voluptuous",
"zigpy>=0.66.0",
"zigpy>=0.70.0",
"pyusb>=1.1.0",
"gpiozero",
'async-timeout; python_version<"3.11"',
Expand Down
11 changes: 7 additions & 4 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from unittest.mock import MagicMock, patch, sentinel
from unittest.mock import AsyncMock, MagicMock, patch, sentinel

import pytest
import serial_asyncio
Expand Down Expand Up @@ -37,10 +37,13 @@ async def mock_conn(loop, protocol_factory, **kwargs):
await api.connect()


def test_close(api):
@pytest.mark.asyncio
async def test_disconnect(api):
uart = api._uart
api.close()
assert uart.close.call_count == 1
uart.disconnect = AsyncMock()

await api.disconnect()
assert uart.disconnect.call_count == 1
assert api._uart is None


Expand Down
12 changes: 6 additions & 6 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,28 +102,28 @@ async def mock_get_network_state():

@pytest.mark.asyncio
async def test_disconnect_success(app):
api = MagicMock()
api = AsyncMock()

app._api = api
await app.disconnect()

api.close.assert_called_once()
api.disconnect.assert_called_once()
assert app._api is None


@pytest.mark.asyncio
async def test_disconnect_failure(app, caplog):
api = MagicMock()
api.disconnect = MagicMock(side_effect=RuntimeError("Broken"))
api = AsyncMock()
api.reset = AsyncMock(side_effect=RuntimeError("Broken"))

app._api = api

with caplog.at_level(logging.WARNING):
await app.disconnect()

assert "disconnect" in caplog.text
assert "Failed to reset before disconnect" in caplog.text

api.close.assert_called_once()
api.disconnect.assert_called_once()
assert app._api is None


Expand Down
15 changes: 7 additions & 8 deletions tests/test_uart.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, call

import gpiozero
import pytest
Expand Down Expand Up @@ -52,6 +52,12 @@ def test_close(gw):
assert gw._transport.close.call_count == 1


def test_connection_lost(gw):
exc = RuntimeError()
gw.connection_lost(exc)
assert gw._api.connection_lost.mock_calls == [call(exc)]


def test_data_received_chunk_frame(gw):
data = b"\x01\x80\x10\x02\x10\x02\x15\xaa\x02\x10\x02\x1f?\xf0\xff\x03"
gw.data_received(data[:-4])
Expand Down Expand Up @@ -108,13 +114,6 @@ def test_escape(gw):
assert r == data_escaped


def test_length(gw):
data = b"\x80\x10\x00\x05\xaa\x00\x0f?\xf0\xff"
length = 5
r = gw._length(data)
assert r == length


def test_checksum(gw):
data = b"\x00\x0f?\xf0"
checksum = 0xAA
Expand Down
6 changes: 3 additions & 3 deletions zigpy_zigate/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,9 @@ def connection_lost(self, exc: Exception) -> None:
if self._app is not None:
self._app.connection_lost(exc)

def close(self):
if self._uart:
self._uart.close()
async def disconnect(self):
if self._uart is not None:
await self._uart.disconnect()
self._uart = None

def set_application(self, app):
Expand Down
70 changes: 21 additions & 49 deletions zigpy_zigate/uart.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import asyncio
import binascii
import logging
import struct
from typing import Any, Dict
from typing import Any

import zigpy.config
import zigpy.serial
Expand All @@ -12,39 +14,24 @@
LOGGER = logging.getLogger(__name__)


class Gateway(asyncio.Protocol):
class Gateway(zigpy.serial.SerialProtocol):
START = b"\x01"
END = b"\x03"

def __init__(self, api, connected_future=None):
self._buffer = b""
self._connected_future = connected_future
def __init__(self, api):
super().__init__()
self._api = api

def connection_lost(self, exc) -> None:
"""Port was closed expecteddly or unexpectedly."""
if self._connected_future and not self._connected_future.done():
if exc is None:
self._connected_future.set_result(True)
else:
self._connected_future.set_exception(exc)
if exc is None:
LOGGER.debug("Closed serial connection")
return

LOGGER.error("Lost serial connection: %s", exc)
self._api.connection_lost(exc)
def connection_lost(self, exc: Exception | None) -> None:
"""Port was closed expectedly or unexpectedly."""
super().connection_lost(exc)

def connection_made(self, transport):
"""Callback when the uart is connected"""
LOGGER.debug("Connection made")
self._transport = transport
if self._connected_future:
self._connected_future.set_result(True)
if self._api is not None:
self._api.connection_lost(exc)

def close(self):
if self._transport:
self._transport.close()
super().close()
self._api = None

def send(self, cmd, data=b""):
"""Send data, taking care of escaping and framing"""
Expand All @@ -60,8 +47,7 @@ def send(self, cmd, data=b""):

def data_received(self, data):
"""Callback when there is data received from the uart"""
self._buffer += data
# LOGGER.debug('data_received %s', self._buffer)
super().data_received(data)
endpos = self._buffer.find(self.END)
while endpos != -1:
startpos = self._buffer.rfind(self.START, 0, endpos)
Expand All @@ -71,7 +57,7 @@ def data_received(self, data):
cmd, length, checksum, f_data, lqi = struct.unpack(
"!HHB%dsB" % (len(frame) - 6), frame
)
if self._length(frame) != length:
if len(frame) - 5 != length:
LOGGER.warning(
"Invalid length: %s, data: %s", length, len(frame) - 6
)
Expand Down Expand Up @@ -126,42 +112,28 @@ def _checksum(self, *args):
chcksum ^= x
return chcksum

def _length(self, frame):
length = len(frame) - 5
return length


async def connect(device_config: Dict[str, Any], api, loop=None):
if loop is None:
loop = asyncio.get_event_loop()

connected_future = asyncio.Future()
protocol = Gateway(api, connected_future)

async def connect(device_config: dict[str, Any], api, loop=None):
loop = asyncio.get_running_loop()
port = device_config[zigpy.config.CONF_DEVICE_PATH]
if port == "auto":
port = await loop.run_in_executor(None, c.discover_port)

if await c.async_is_pizigate(port):
LOGGER.debug("PiZiGate detected")
await c.async_set_pizigate_running_mode()
# in case of pizigate:/dev/ttyAMA0 syntax
if port.startswith("pizigate:"):
port = port.replace("pizigate:", "", 1)
port = port.replace("pizigate:", "", 1)
elif await c.async_is_zigate_din(port):
LOGGER.debug("ZiGate USB DIN detected")
await c.async_set_zigatedin_running_mode()
elif c.is_zigate_wifi(port):
LOGGER.debug("ZiGate WiFi detected")

protocol = Gateway(api)
_, protocol = await zigpy.serial.create_serial_connection(
loop,
lambda: protocol,
url=port,
baudrate=device_config[zigpy.config.CONF_DEVICE_BAUDRATE],
xonxoff=False,
flow_control=device_config[zigpy.config.CONF_DEVICE_FLOW_CONTROL],
)

await connected_future
await protocol.wait_until_connected()

return protocol
2 changes: 1 addition & 1 deletion zigpy_zigate/zigbee/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def disconnect(self):
except Exception as e:
LOGGER.warning("Failed to reset before disconnect: %s", e)
finally:
self._api.close()
await self._api.disconnect()
self._api = None

async def start_network(self):
Expand Down