Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanly shut down the serial port on disconnect #633

Merged
merged 10 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion bellows/ash.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ def __init__(self, code: t.NcpResetCode) -> None:
def __repr__(self) -> str:
return f"<{self.__class__.__name__}(code={self.code})>"

def __eq__(self, other: object) -> bool | NotImplemented:
if not isinstance(other, NcpFailure):
return NotImplemented

return self.code == other.code


class AshFrame(abc.ABC, BaseDataclassMixin):
MASK: t.uint8_t
Expand Down Expand Up @@ -368,7 +374,7 @@ def connection_made(self, transport):
self._transport = transport
self._ezsp_protocol.connection_made(self)

def connection_lost(self, exc):
def connection_lost(self, exc: Exception | None) -> None:
self._transport = None
self._cancel_pending_data_frames()
self._ezsp_protocol.connection_lost(exc)
Expand Down
2 changes: 1 addition & 1 deletion bellows/cli/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def dump(ctx, channel, outfile):
finally:
if "ezsp" in ctx.obj:
loop.run_until_complete(ctx.obj["ezsp"].mfglibEnd())
ctx.obj["ezsp"].close()
loop.run_until_complete(ctx.obj["ezsp"].disconnect())


def ieee_15_4_fcs(data: bytes) -> bytes:
Expand Down
10 changes: 5 additions & 5 deletions bellows/cli/ncp.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def config(ctx, config, all_):
if v[0] == t.EzspStatus.ERROR_INVALID_ID:
continue
click.echo(f"{config.name}={v[1]}")
s.close()
await s.disconnect()
return

if "=" in config:
Expand All @@ -54,7 +54,7 @@ async def config(ctx, config, all_):

v = await s.setConfigurationValue(config, value)
click.echo(v)
s.close()
await s.disconnect()
return

v = await s.getConfigurationValue(config)
Expand Down Expand Up @@ -86,7 +86,7 @@ async def info(ctx):
click.echo(f"Board name: {brd_name}")
click.echo(f"EmberZNet version: {version}")

s.close()
await s.disconnect()


@main.command()
Expand All @@ -105,7 +105,7 @@ async def bootloader(ctx):
version, plat, micro, phy = await ezsp.getStandaloneBootloaderVersionPlatMicroPhy()
if version == 0xFFFF:
click.echo("No boot loader installed")
ezsp.close()
await ezsp.disconnect()
return

click.echo(
Expand All @@ -118,4 +118,4 @@ async def bootloader(ctx):
click.echo(f"Couldn't launch bootloader: {res[0]}")
else:
click.echo("bootloader launched successfully")
ezsp.close()
await ezsp.disconnect()
6 changes: 3 additions & 3 deletions bellows/cli/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def cb(fut, frame_name, response):

s.remove_callback(cbid)

s.close()
await s.disconnect()


@main.command()
Expand All @@ -126,7 +126,7 @@ async def leave(ctx):
expected=t.EmberStatus.NETWORK_DOWN,
)

s.close()
await s.disconnect()


@main.command()
Expand Down Expand Up @@ -157,4 +157,4 @@ async def scan(ctx, channels, duration_ms, energy_scan):
for network in v:
click.echo(network)

s.close()
await s.disconnect()
2 changes: 1 addition & 1 deletion bellows/cli/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def stream(ctx, channel, power):
s = ctx.obj["ezsp"]
loop.run_until_complete(s.mfglibStopStream())
loop.run_until_complete(s.mfglibEnd())
s.close()
loop.run_until_complete(s.disconnect())


async def _stream(ctx, channel, power):
Expand Down
2 changes: 1 addition & 1 deletion bellows/cli/tone.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def tone(ctx, channel, power):
s = ctx.obj["ezsp"]
loop.run_until_complete(s.mfglibStopTone())
loop.run_until_complete(s.mfglibEnd())
s.close()
loop.run_until_complete(s.disconnect())


async def _tone(ctx, channel, power):
Expand Down
25 changes: 7 additions & 18 deletions bellows/cli/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,28 +59,17 @@ async def async_inner(ctx, *args, **kwargs):
if extra_config:
app_config.update(extra_config)
application = await setup_application(app_config, startup=app_startup)
ctx.obj["app"] = application
await f(ctx, *args, **kwargs)
await asyncio.sleep(0.5)
await application.shutdown()

def shutdown():
with contextlib.suppress(Exception):
application._ezsp.close()
try:
ctx.obj["app"] = application
await f(ctx, *args, **kwargs)
finally:
with contextlib.suppress(Exception):
await application.shutdown()

@functools.wraps(f)
def inner(*args, **kwargs):
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(async_inner(*args, **kwargs))
except: # noqa: E722
# It seems that often errors like a message send will try to send
# two messages, and not reading all of them will leave the NCP in
# a bad state. This seems to mitigate this somewhat. Better way?
loop.run_until_complete(asyncio.sleep(0.5))
raise
finally:
shutdown()
loop.run_until_complete(async_inner(*args, **kwargs))

return inner

Expand Down
52 changes: 18 additions & 34 deletions bellows/ezsp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from typing import Any, Callable, Generator
import urllib.parse

from bellows.ash import NcpFailure

if sys.version_info[:2] < (3, 11):
from async_timeout import timeout as asyncio_timeout # pragma: no cover
else:
Expand Down Expand Up @@ -55,13 +57,14 @@ class EZSP:
v14.EZSPv14.VERSION: v14.EZSPv14,
}

def __init__(self, device_config: dict):
def __init__(self, device_config: dict, application: Any | None = None):
self._config = device_config
self._callbacks = {}
self._ezsp_event = asyncio.Event()
self._ezsp_version = v4.EZSPv4.VERSION
self._gw = None
self._protocol = None
self._application = application

self._stack_status_listeners: collections.defaultdict[
t.sl_Status, list[asyncio.Future]
Expand Down Expand Up @@ -122,25 +125,17 @@ async def startup_reset(self) -> None:

await self.version()

@classmethod
async def initialize(cls, zigpy_config: dict) -> EZSP:
"""Return initialized EZSP instance."""
ezsp = cls(zigpy_config[conf.CONF_DEVICE])
await ezsp.connect(use_thread=zigpy_config[conf.CONF_USE_THREAD])
async def connect(self, *, use_thread: bool = True) -> None:
assert self._gw is None
self._gw = await bellows.uart.connect(self._config, self, use_thread=use_thread)

try:
await ezsp.startup_reset()
self._protocol = v4.EZSPv4(self.handle_callback, self._gw)
await self.startup_reset()
except Exception:
ezsp.close()
await self.disconnect()
raise

return ezsp

async def connect(self, *, use_thread: bool = True) -> None:
assert self._gw is None
self._gw = await bellows.uart.connect(self._config, self, use_thread=use_thread)
self._protocol = v4.EZSPv4(self.handle_callback, self._gw)

async def reset(self):
LOGGER.debug("Resetting EZSP")
self.stop_ezsp()
Expand Down Expand Up @@ -179,10 +174,10 @@ async def version(self):
ver,
)

def close(self):
async def disconnect(self):
self.stop_ezsp()
if self._gw:
self._gw.close()
await self._gw.disconnect()
self._gw = None

async def _command(self, name: str, *args: Any, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -264,23 +259,12 @@ async def leaveNetwork(self, timeout: float | int = NETWORK_OPS_TIMEOUT) -> None

def connection_lost(self, exc):
"""Lost serial connection."""
LOGGER.debug(
"%s connection lost unexpectedly: %s",
self._config[conf.CONF_DEVICE_PATH],
exc,
)
self.enter_failed_state(f"Serial connection loss: {exc!r}")

def enter_failed_state(self, error):
"""UART received error frame."""
if len(self._callbacks) > 1:
LOGGER.error("NCP entered failed state. Requesting APP controller restart")
self.close()
self.handle_callback("_reset_controller_application", (error,))
else:
LOGGER.info(
"NCP entered failed state. No application handler registered, ignoring..."
)
if self._application is not None:
self._application.connection_lost(exc)

def enter_failed_state(self, code: t.NcpResetCode) -> None:
"""UART received reset code."""
self.connection_lost(NcpFailure(code=code))

def __getattr__(self, name: str) -> Callable:
if name not in self._protocol.COMMANDS:
Expand Down
56 changes: 19 additions & 37 deletions bellows/uart.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,29 @@
RESET_TIMEOUT = 5


class Gateway(asyncio.Protocol):
def __init__(self, application, connected_future=None, connection_done_future=None):
self._application = application
class Gateway(zigpy.serial.SerialProtocol):
def __init__(self, api, connection_done_future=None):
super().__init__()
self._api = api

self._reset_future = None
self._startup_reset_future = None
self._connected_future = connected_future
self._connection_done_future = connection_done_future

self._transport = None

def close(self):
self._transport.close()

def connection_made(self, transport):
"""Callback when the uart is connected"""
self._transport = transport
if self._connected_future is not None:
self._connected_future.set_result(True)

async def send_data(self, data: bytes) -> None:
await self._transport.send_data(data)

def data_received(self, data):
"""Callback when there is data received from the uart"""
self._application.frame_received(data)

# We intentionally do not call `SerialProtocol.data_received`
self._api.frame_received(data)

def reset_received(self, code: t.NcpResetCode) -> None:
"""Reset acknowledgement frame receive handler"""
# not a reset we've requested. Signal application reset
# not a reset we've requested. Signal api reset
if code is not t.NcpResetCode.RESET_SOFTWARE:
self._application.enter_failed_state(code)
self._api.enter_failed_state(code)
return

if self._reset_future and not self._reset_future.done():
Expand All @@ -61,7 +52,7 @@ def reset_received(self, code: t.NcpResetCode) -> None:

def error_received(self, code: t.NcpResetCode) -> None:
"""Error frame receive handler."""
self._application.enter_failed_state(code)
self._api.enter_failed_state(code)

async def wait_for_startup_reset(self) -> None:
"""Wait for the first reset frame on startup."""
Expand All @@ -77,12 +68,9 @@ def _reset_cleanup(self, future):
"""Delete reset future."""
self._reset_future = None

def eof_received(self):
"""Server gracefully closed its side of the connection."""
self.connection_lost(ConnectionResetError("Remote server closed connection"))

def connection_lost(self, exc):
"""Port was closed unexpectedly."""
super().connection_lost(exc)

LOGGER.debug("Connection lost: %r", exc)
reason = exc or ConnectionResetError("Remote server closed connection")
Expand All @@ -102,12 +90,7 @@ def connection_lost(self, exc):
self._reset_future.set_exception(reason)
self._reset_future = None

if exc is None:
LOGGER.debug("Closed serial connection")
return

LOGGER.error("Lost serial connection: %r", exc)
self._application.connection_lost(exc)
self._api.connection_lost(exc)

async def reset(self):
"""Send a reset frame and init internal state."""
Expand All @@ -126,13 +109,12 @@ async def reset(self):
return await self._reset_future


async def _connect(config, application):
async def _connect(config, api):
loop = asyncio.get_event_loop()

connection_future = loop.create_future()
connection_done_future = loop.create_future()

gateway = Gateway(application, connection_future, connection_done_future)
gateway = Gateway(api, connection_done_future)
protocol = AshProtocol(gateway)

if config[zigpy.config.CONF_DEVICE_FLOW_CONTROL] is None:
Expand All @@ -149,25 +131,25 @@ async def _connect(config, application):
rtscts=rtscts,
)

await connection_future
await gateway.wait_until_connected()

thread_safe_protocol = ThreadsafeProxy(gateway, loop)
return thread_safe_protocol, connection_done_future


async def connect(config, application, use_thread=True):
async def connect(config, api, use_thread=True):
if use_thread:
application = ThreadsafeProxy(application, asyncio.get_event_loop())
api = ThreadsafeProxy(api, asyncio.get_event_loop())
thread = EventLoopThread()
await thread.start()
try:
protocol, connection_done = await thread.run_coroutine_threadsafe(
_connect(config, application)
_connect(config, api)
)
except Exception:
thread.force_stop()
raise
connection_done.add_done_callback(lambda _: thread.force_stop())
else:
protocol, _ = await _connect(config, application)
protocol, _ = await _connect(config, api)
return protocol
Loading