diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index 220f22c4..9b4521b8 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -5,6 +5,7 @@ import asyncio import collections import contextlib +import dataclasses import functools import logging import sys @@ -20,6 +21,7 @@ import bellows.config as conf from bellows.exception import EzspError, InvalidCommandError +from bellows.ezsp.config import DEFAULT_CONFIG, RuntimeConfig, ValueConfig import bellows.types as t import bellows.uart @@ -133,15 +135,14 @@ async def probe(cls, device_config: dict) -> bool | dict[str, int | str | bool]: async def _probe(self) -> None: """Open port and try sending a command""" await self.connect(use_thread=False) - await self._startup_reset() - await self.version() + await self.startup_reset() @property def is_tcp_serial_port(self) -> bool: parsed_path = urllib.parse.urlparse(self._config[conf.CONF_DEVICE_PATH]) return parsed_path.scheme == "socket" - async def _startup_reset(self): + async def startup_reset(self) -> None: """Start EZSP and reset the stack.""" # `zigbeed` resets on startup if self.is_tcp_serial_port: @@ -157,6 +158,8 @@ async def _startup_reset(self): if not self.is_ezsp_running: await self.reset() + await self.version() + @classmethod async def initialize(cls, zigpy_config: dict) -> EZSP: """Return initialized EZSP instance.""" @@ -164,12 +167,7 @@ async def initialize(cls, zigpy_config: dict) -> EZSP: await ezsp.connect(use_thread=zigpy_config[conf.CONF_USE_THREAD]) try: - await ezsp._startup_reset() - await ezsp.version() - await ezsp._protocol.initialize(zigpy_config) - - if zigpy_config[zigpy.config.CONF_SOURCE_ROUTING]: - await ezsp.set_source_routing() + await ezsp.startup_reset() except Exception: ezsp.close() raise @@ -419,6 +417,20 @@ async def can_rewrite_custom_eui64(self) -> bool: """Checks if the device EUI64 can be written any number of times.""" return await self._get_nv3_restored_eui64_key() is not None + async def reset_custom_eui64(self) -> None: + """Reset the custom EUI64, if possible.""" + + nv3_eui64_key = await self._get_nv3_restored_eui64_key() + if nv3_eui64_key is None: + return + + (status,) = await self.setTokenData( + nv3_eui64_key, + 0, + t.LVBytes32(t.EmberEUI64.convert("FF:FF:FF:FF:FF:FF:FF:FF").serialize()), + ) + assert status == t.EmberStatus.SUCCESS + async def write_custom_eui64( self, ieee: t.EUI64, *, burn_into_userdata: bool = False ) -> None: @@ -488,7 +500,9 @@ async def set_source_routing(self) -> None: LOGGER.debug("Set concentrator type: %s", res) if res[0] != self.types.EmberStatus.SUCCESS: LOGGER.warning("Couldn't set concentrator type %s: %s", True, res) - await self._protocol.set_source_routing() + + if self._ezsp_version >= 8: + await self.setSourceRouteDiscoveryMode(1) def start_ezsp(self): """Mark EZSP as running.""" @@ -512,3 +526,104 @@ def ezsp_version(self): def types(self): """Return EZSP types for this specific version.""" return self._protocol.types + + async def write_config(self, config: dict) -> None: + """Initialize EmberZNet Stack.""" + config = self._protocol.SCHEMAS[conf.CONF_EZSP_CONFIG](config) + + # Not all config will be present in every EZSP version so only use valid keys + ezsp_config = {} + ezsp_values = {} + + for cfg in DEFAULT_CONFIG[self._ezsp_version]: + if isinstance(cfg, RuntimeConfig): + ezsp_config[cfg.config_id.name] = dataclasses.replace( + cfg, config_id=self.types.EzspConfigId[cfg.config_id.name] + ) + elif isinstance(cfg, ValueConfig): + ezsp_values[cfg.value_id.name] = dataclasses.replace( + cfg, value_id=self.types.EzspValueId[cfg.value_id.name] + ) + + # Override the defaults with user-specified values (or `None` for deletions) + for name, value in config.items(): + if value is None: + ezsp_config.pop(name) + continue + + ezsp_config[name] = RuntimeConfig( + config_id=self.types.EzspConfigId[name], + value=value, + ) + + # Make sure CONFIG_PACKET_BUFFER_COUNT is always set last + if self.types.EzspConfigId.CONFIG_PACKET_BUFFER_COUNT.name in ezsp_config: + ezsp_config = { + **ezsp_config, + self.types.EzspConfigId.CONFIG_PACKET_BUFFER_COUNT.name: ezsp_config[ + self.types.EzspConfigId.CONFIG_PACKET_BUFFER_COUNT.name + ], + } + + # First, set the values + for cfg in ezsp_values.values(): + # XXX: A read failure does not mean the value is not writeable! + status, current_value = await self.getValue(cfg.value_id) + + if status == self.types.EmberStatus.SUCCESS: + current_value, _ = type(cfg.value).deserialize(current_value) + else: + current_value = None + + LOGGER.debug( + "Setting value %s = %s (old value %s)", + cfg.value_id.name, + cfg.value, + current_value, + ) + + (status,) = await self.setValue(cfg.value_id, cfg.value.serialize()) + + if status != self.types.EmberStatus.SUCCESS: + LOGGER.debug( + "Could not set value %s = %s: %s", + cfg.value_id.name, + cfg.value, + status, + ) + continue + + # Finally, set the config + for cfg in ezsp_config.values(): + (status, current_value) = await self.getConfigurationValue(cfg.config_id) + + # Only grow some config entries, all others should be set + if ( + status == self.types.EmberStatus.SUCCESS + and cfg.minimum + and current_value >= cfg.value + ): + LOGGER.debug( + "Current config %s = %s exceeds the default of %s, skipping", + cfg.config_id.name, + current_value, + cfg.value, + ) + continue + + LOGGER.debug( + "Setting config %s = %s (old value %s)", + cfg.config_id.name, + cfg.value, + current_value, + ) + + (status,) = await self.setConfigurationValue(cfg.config_id, cfg.value) + if status != self.types.EmberStatus.SUCCESS: + LOGGER.debug( + "Could not set config %s = %s: %s", + cfg.config_id, + cfg.value, + status, + ) + continue diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index 459501d1..fe1e82d6 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -1,18 +1,17 @@ import abc import asyncio import binascii -import dataclasses import functools import logging import sys -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Tuple if sys.version_info[:2] < (3, 11): from async_timeout import timeout as asyncio_timeout # pragma: no cover else: from asyncio import timeout as asyncio_timeout # pragma: no cover -from bellows.config import CONF_EZSP_CONFIG, CONF_EZSP_POLICIES +from bellows.config import CONF_EZSP_POLICIES from bellows.exception import InvalidCommandError from bellows.typing import GatewayType @@ -55,120 +54,6 @@ def _ezsp_frame_tx(self, name: str) -> bytes: async def pre_permit(self, time_s: int) -> None: """Schedule task before allowing new joins.""" - async def initialize(self, zigpy_config: Dict) -> None: - """Initialize EmberZNet Stack.""" - - # Prevent circular import - from bellows.ezsp.config import DEFAULT_CONFIG, RuntimeConfig, ValueConfig - - # Not all config will be present in every EZSP version so only use valid keys - ezsp_config = {} - ezsp_values = {} - - for cfg in DEFAULT_CONFIG[self.VERSION]: - if isinstance(cfg, RuntimeConfig): - ezsp_config[cfg.config_id.name] = dataclasses.replace( - cfg, config_id=self.types.EzspConfigId[cfg.config_id.name] - ) - elif isinstance(cfg, ValueConfig): - ezsp_values[cfg.value_id.name] = dataclasses.replace( - cfg, value_id=self.types.EzspValueId[cfg.value_id.name] - ) - - # Override the defaults with user-specified values (or `None` for deletions) - for name, value in self.SCHEMAS[CONF_EZSP_CONFIG]( - zigpy_config[CONF_EZSP_CONFIG] - ).items(): - if value is None: - ezsp_config.pop(name) - continue - - ezsp_config[name] = RuntimeConfig( - config_id=self.types.EzspConfigId[name], - value=value, - ) - - # Make sure CONFIG_PACKET_BUFFER_COUNT is always set last - if self.types.EzspConfigId.CONFIG_PACKET_BUFFER_COUNT.name in ezsp_config: - ezsp_config = { - **ezsp_config, - self.types.EzspConfigId.CONFIG_PACKET_BUFFER_COUNT.name: ezsp_config[ - self.types.EzspConfigId.CONFIG_PACKET_BUFFER_COUNT.name - ], - } - - # First, set the values - for cfg in ezsp_values.values(): - # XXX: A read failure does not mean the value is not writeable! - status, current_value = await self.getValue(cfg.value_id) - - if status == self.types.EmberStatus.SUCCESS: - current_value, _ = type(cfg.value).deserialize(current_value) - else: - current_value = None - - LOGGER.debug( - "Setting value %s = %s (old value %s)", - cfg.value_id.name, - cfg.value, - current_value, - ) - - (status,) = await self.setValue(cfg.value_id, cfg.value.serialize()) - - if status != self.types.EmberStatus.SUCCESS: - LOGGER.debug( - "Could not set value %s = %s: %s", - cfg.value_id.name, - cfg.value, - status, - ) - continue - - # Finally, set the config - for cfg in ezsp_config.values(): - (status, current_value) = await self.getConfigurationValue(cfg.config_id) - - # Only grow some config entries, all others should be set - if ( - status == self.types.EmberStatus.SUCCESS - and cfg.minimum - and current_value >= cfg.value - ): - LOGGER.debug( - "Current config %s = %s exceeds the default of %s, skipping", - cfg.config_id.name, - current_value, - cfg.value, - ) - continue - - LOGGER.debug( - "Setting config %s = %s (old value %s)", - cfg.config_id.name, - cfg.value, - current_value, - ) - - (status,) = await self.setConfigurationValue(cfg.config_id, cfg.value) - if status != self.types.EmberStatus.SUCCESS: - LOGGER.debug( - "Could not set config %s = %s: %s", - cfg.config_id, - cfg.value, - status, - ) - continue - - async def get_free_buffers(self) -> Optional[int]: - status, value = await self.getValue(self.types.EzspValueId.VALUE_FREE_BUFFERS) - - if status != self.types.EzspStatus.SUCCESS: - LOGGER.debug("Couldn't get free buffers: %s", status) - return None - - return int.from_bytes(value, byteorder="little") - async def command(self, name, *args) -> Any: """Serialize command and send it.""" LOGGER.debug("Send command %s: %s", name, args) @@ -182,13 +67,10 @@ async def command(self, name, *args) -> Any: async with asyncio_timeout(EZSP_CMD_TIMEOUT): return await future - async def set_source_routing(self) -> None: - """Enable source routing on NCP.""" - - async def update_policies(self, zigpy_config: dict) -> None: + async def update_policies(self, policy_config: dict) -> None: """Set up the policies for what the NCP should do.""" - policies = self.SCHEMAS[CONF_EZSP_POLICIES](zigpy_config[CONF_EZSP_POLICIES]) + policies = self.SCHEMAS[CONF_EZSP_POLICIES](policy_config) self.tc_policy = policies[self.types.EzspPolicyId.TRUST_CENTER_POLICY.name] for policy, value in policies.items(): diff --git a/bellows/ezsp/v10/__init__.py b/bellows/ezsp/v10/__init__.py index 09bdab3d..147303a9 100644 --- a/bellows/ezsp/v10/__init__.py +++ b/bellows/ezsp/v10/__init__.py @@ -1,19 +1,17 @@ """"EZSP Protocol version 10 protocol handler.""" -import asyncio import logging -from typing import Tuple import voluptuous import bellows.config from . import commands, config, types as v10_types -from .. import protocol +from ..v9 import EZSPv9 LOGGER = logging.getLogger(__name__) -class EZSPv10(protocol.ProtocolHandler): +class EZSPv10(EZSPv9): """EZSP Version 10 Protocol version handler.""" VERSION = 10 @@ -23,36 +21,3 @@ class EZSPv10(protocol.ProtocolHandler): bellows.config.CONF_EZSP_POLICIES: voluptuous.Schema(config.EZSP_POLICIES_SCH), } types = v10_types - - def _ezsp_frame_tx(self, name: str) -> bytes: - """Serialize the frame id.""" - cmd_id = self.COMMANDS[name][0] - hdr = [self._seq, 0x00, 0x01] - return bytes(hdr) + self.types.uint16_t(cmd_id).serialize() - - def _ezsp_frame_rx(self, data: bytes) -> Tuple[int, int, bytes]: - """Handler for received data frame.""" - seq, data = data[0], data[3:] - frame_id, data = self.types.uint16_t.deserialize(data) - - return seq, frame_id, data - - async def pre_permit(self, time_s: int) -> None: - """Temporarily change TC policy while allowing new joins.""" - wild_card_ieee = v10_types.EmberEUI64([0xFF] * 8) - tc_link_key = v10_types.EmberKeyData(b"ZigBeeAlliance09") - await self.addTransientLinkKey(wild_card_ieee, tc_link_key) - await self.setPolicy( - v10_types.EzspPolicyId.TRUST_CENTER_POLICY, - v10_types.EzspDecisionBitmask.ALLOW_JOINS - | v10_types.EzspDecisionBitmask.ALLOW_UNSECURED_REJOINS, - ) - await asyncio.sleep(time_s + 2) - await self.setPolicy( - v10_types.EzspPolicyId.TRUST_CENTER_POLICY, - self.tc_policy, - ) - - async def set_source_routing(self) -> None: - """Enable source routing on NCP.""" - await self.setSourceRouteDiscoveryMode(1) diff --git a/bellows/ezsp/v4/__init__.py b/bellows/ezsp/v4/__init__.py index f786b76f..ed2ec6b6 100644 --- a/bellows/ezsp/v4/__init__.py +++ b/bellows/ezsp/v4/__init__.py @@ -31,3 +31,6 @@ def _ezsp_frame_tx(self, name: str) -> bytes: def _ezsp_frame_rx(self, data: bytes) -> Tuple[int, int, bytes]: """Handler for received data frame.""" return data[0], data[2], data[3:] + + async def pre_permit(self, time_s: int) -> None: + pass diff --git a/bellows/ezsp/v7/__init__.py b/bellows/ezsp/v7/__init__.py index 755711f8..da0a0018 100644 --- a/bellows/ezsp/v7/__init__.py +++ b/bellows/ezsp/v7/__init__.py @@ -6,12 +6,12 @@ import bellows.config from . import commands, config, types as v7_types -from ..v5 import EZSPv5 +from ..v6 import EZSPv6 LOGGER = logging.getLogger(__name__) -class EZSPv7(EZSPv5): +class EZSPv7(EZSPv6): """EZSP Version 7 Protocol version handler.""" VERSION = 7 diff --git a/bellows/ezsp/v8/__init__.py b/bellows/ezsp/v8/__init__.py index e731043e..98a3ab82 100644 --- a/bellows/ezsp/v8/__init__.py +++ b/bellows/ezsp/v8/__init__.py @@ -8,12 +8,12 @@ import bellows.config from . import commands, config, types as v8_types -from .. import protocol +from ..v7 import EZSPv7 LOGGER = logging.getLogger(__name__) -class EZSPv8(protocol.ProtocolHandler): +class EZSPv8(EZSPv7): """EZSP Version 8 Protocol version handler.""" VERSION = 8 @@ -39,9 +39,7 @@ def _ezsp_frame_rx(self, data: bytes) -> Tuple[int, int, bytes]: async def pre_permit(self, time_s: int) -> None: """Temporarily change TC policy while allowing new joins.""" - wild_card_ieee = v8_types.EmberEUI64([0xFF] * 8) - tc_link_key = v8_types.EmberKeyData(b"ZigBeeAlliance09") - await self.addTransientLinkKey(wild_card_ieee, tc_link_key) + await super().pre_permit(time_s) await self.setPolicy( v8_types.EzspPolicyId.TRUST_CENTER_POLICY, v8_types.EzspDecisionBitmask.ALLOW_JOINS @@ -52,7 +50,3 @@ async def pre_permit(self, time_s: int) -> None: v8_types.EzspPolicyId.TRUST_CENTER_POLICY, self.tc_policy, ) - - async def set_source_routing(self) -> None: - """Enable source routing on NCP.""" - await self.setSourceRouteDiscoveryMode(1) diff --git a/bellows/ezsp/v9/__init__.py b/bellows/ezsp/v9/__init__.py index dba7b3a6..1bd04afa 100644 --- a/bellows/ezsp/v9/__init__.py +++ b/bellows/ezsp/v9/__init__.py @@ -1,19 +1,17 @@ -""""EZSP Protocol version 8 protocol handler.""" -import asyncio +""""EZSP Protocol version 9 protocol handler.""" import logging -from typing import Tuple import voluptuous import bellows.config from . import commands, config, types as v9_types -from .. import protocol +from ..v8 import EZSPv8 LOGGER = logging.getLogger(__name__) -class EZSPv9(protocol.ProtocolHandler): +class EZSPv9(EZSPv8): """EZSP Version 9 Protocol version handler.""" VERSION = 9 @@ -23,36 +21,3 @@ class EZSPv9(protocol.ProtocolHandler): bellows.config.CONF_EZSP_POLICIES: voluptuous.Schema(config.EZSP_POLICIES_SCH), } types = v9_types - - def _ezsp_frame_tx(self, name: str) -> bytes: - """Serialize the frame id.""" - cmd_id = self.COMMANDS[name][0] - hdr = [self._seq, 0x00, 0x01] - return bytes(hdr) + self.types.uint16_t(cmd_id).serialize() - - def _ezsp_frame_rx(self, data: bytes) -> Tuple[int, int, bytes]: - """Handler for received data frame.""" - seq, data = data[0], data[3:] - frame_id, data = self.types.uint16_t.deserialize(data) - - return seq, frame_id, data - - async def pre_permit(self, time_s: int) -> None: - """Temporarily change TC policy while allowing new joins.""" - wild_card_ieee = v9_types.EmberEUI64([0xFF] * 8) - tc_link_key = v9_types.EmberKeyData(b"ZigBeeAlliance09") - await self.addTransientLinkKey(wild_card_ieee, tc_link_key) - await self.setPolicy( - v9_types.EzspPolicyId.TRUST_CENTER_POLICY, - v9_types.EzspDecisionBitmask.ALLOW_JOINS - | v9_types.EzspDecisionBitmask.ALLOW_UNSECURED_REJOINS, - ) - await asyncio.sleep(time_s + 2) - await self.setPolicy( - v9_types.EzspPolicyId.TRUST_CENTER_POLICY, - self.tc_policy, - ) - - async def set_source_routing(self) -> None: - """Enable source routing on NCP.""" - await self.setSourceRouteDiscoveryMode(1) diff --git a/bellows/types/struct.py b/bellows/types/struct.py index f5073902..a56c3793 100644 --- a/bellows/types/struct.py +++ b/bellows/types/struct.py @@ -349,3 +349,11 @@ class EmberGpAddress(EzspStruct): applicationId: basic.uint8_t # The GPD endpoint. endpoint: basic.uint8_t + + +class NV3StackTrustCenterToken(EzspStruct): + """NV3 stack trust center token value.""" + + mode: basic.uint16_t + eui64: named.EmberEUI64 + key: named.EmberKeyData diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index d6620e78..538fc74e 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -25,7 +25,9 @@ import bellows from bellows.config import ( + CONF_EZSP_POLICIES, CONF_PARAM_MAX_WATCHDOG_FAILURES, + CONF_USE_THREAD, CONFIG_SCHEMA, SCHEMA_DEVICE, ) @@ -34,6 +36,7 @@ from bellows.ezsp.v8.types.named import EmberDeviceUpdate import bellows.multicast import bellows.types as t +from bellows.zigbee import repairs from bellows.zigbee.device import EZSPEndpoint import bellows.zigbee.util as util @@ -129,17 +132,17 @@ async def _get_board_info(self) -> tuple[str, str, str] | tuple[None, None, None return None, None, None - async def connect(self): - self._ezsp = await bellows.ezsp.EZSP.initialize(self.config) - ezsp = self._ezsp + async def connect(self) -> None: + ezsp = bellows.ezsp.EZSP(self.config[zigpy.config.CONF_DEVICE]) + await ezsp.connect(use_thread=self.config[CONF_USE_THREAD]) - self._multicast = bellows.multicast.Multicast(ezsp) - await self.register_endpoints() + try: + await ezsp.startup_reset() + except Exception: + ezsp.close() + raise - brd_manuf, brd_name, version = await self._get_board_info() - LOGGER.info("EZSP Radio manufacturer: %s", brd_manuf) - LOGGER.info("EZSP Radio board name: %s", brd_name) - LOGGER.info("EmberZNet version: %s", version) + self._ezsp = ezsp async def _ensure_network_running(self) -> bool: """Ensures the network is currently running and returns whether or not the network @@ -168,7 +171,16 @@ async def start_network(self): await self._ensure_network_running() - await ezsp.update_policies(self.config) + if await repairs.fix_invalid_tclk_partner_ieee(ezsp): + # Reboot the stack after modifying NV3 + ezsp.stop_ezsp() + await ezsp.startup_reset() + await self._ensure_network_running() + + if self.config[zigpy.config.CONF_SOURCE_ROUTING]: + await ezsp.set_source_routing() + + await ezsp._protocol.update_policies(self.config[CONF_EZSP_POLICIES]) await self.load_network_info(load_devices=False) for cnt_group in self.state.counters: @@ -206,7 +218,8 @@ async def start_network(self): if db_device is not None and 1 in db_device.endpoints: ezsp_device.endpoints[1].member_of.update(db_device.endpoints[1].member_of) - await self.multicast.startup(ezsp_device) + self._multicast = bellows.multicast.Multicast(ezsp) + await self._multicast.startup(ezsp_device) async def load_network_info(self, *, load_devices=False) -> None: ezsp = self._ezsp @@ -353,6 +366,7 @@ async def write_network_info( stack_specific = network_info.stack_specific.get("ezsp", {}) (current_eui64,) = await ezsp.getEui64() + wrote_eui64 = False if ( node_info.ieee != zigpy.types.EUI64.UNKNOWN @@ -360,6 +374,7 @@ async def write_network_info( ): if await ezsp.can_rewrite_custom_eui64(): await ezsp.write_custom_eui64(node_info.ieee) + wrote_eui64 = True elif not stack_specific.get( "i_understand_i_can_update_eui64_only_once_and_i_still_want_to_do_it" ): @@ -375,6 +390,13 @@ async def write_network_info( ) else: await ezsp.write_custom_eui64(node_info.ieee, burn_into_userdata=True) + wrote_eui64 = True + + # If we cannot write the new EUI64, don't mess up key entries with the unwritten + # EUI64 address + if not wrote_eui64: + node_info.ieee = current_eui64 + network_info.tc_link_key.partner_ieee = current_eui64 use_hashed_tclk = ezsp.ezsp_version > 4 @@ -386,7 +408,6 @@ async def write_network_info( initial_security_state = util.zha_security( network_info=network_info, - node_info=node_info, use_hashed_tclk=use_hashed_tclk, ) (status,) = await ezsp.setInitialSecurityState(initial_security_state) @@ -444,6 +465,9 @@ async def reset_network_info(self): (status,) = await self._ezsp.clearKeyTable() assert status == t.EmberStatus.SUCCESS + # Reset the custom EUI64 + await self._ezsp.reset_custom_eui64() + async def disconnect(self): # TODO: how do you shut down the stack? self.controller_event.clear() diff --git a/bellows/zigbee/repairs.py b/bellows/zigbee/repairs.py new file mode 100644 index 00000000..0f4863d2 --- /dev/null +++ b/bellows/zigbee/repairs.py @@ -0,0 +1,51 @@ +"""Coordinator state repairs.""" + +import logging + +import zigpy.types + +from bellows.exception import InvalidCommandError +from bellows.ezsp import EZSP +import bellows.types as t + +LOGGER = logging.getLogger(__name__) + + +async def fix_invalid_tclk_partner_ieee(ezsp: EZSP) -> bool: + """Fix invalid TCLK partner IEEE address.""" + (ieee,) = await ezsp.getEui64() + ieee = zigpy.types.EUI64(ieee) + + (status, state) = await ezsp.getCurrentSecurityState() + assert status == t.EmberStatus.SUCCESS + + if state.trustCenterLongAddress == ieee: + return False + + LOGGER.warning( + "Fixing invalid TCLK partner IEEE (%s => %s)", + state.trustCenterLongAddress, + ieee, + ) + + try: + (status, value) = await ezsp.getTokenData( + t.NV3KeyId.NVM3KEY_STACK_TRUST_CENTER, 0 + ) + assert status == t.EmberStatus.SUCCESS + except InvalidCommandError: + LOGGER.warning("NV3 interface not available in this firmware, please upgrade!") + return False + + token, remaining = t.NV3StackTrustCenterToken.deserialize(value) + assert not remaining + assert token.eui64 == state.trustCenterLongAddress + + (status,) = await ezsp.setTokenData( + t.NV3KeyId.NVM3KEY_STACK_TRUST_CENTER, + 0, + token.replace(eui64=ieee).serialize(), + ) + assert status == t.EmberStatus.SUCCESS + + return True diff --git a/bellows/zigbee/util.py b/bellows/zigbee/util.py index 1d3779d0..38ff9b25 100644 --- a/bellows/zigbee/util.py +++ b/bellows/zigbee/util.py @@ -3,7 +3,6 @@ import zigpy.state import zigpy.types as zigpy_t -import zigpy.zdo.types as zdo_t import bellows.types as t @@ -21,7 +20,6 @@ def zha_security( *, network_info: zigpy.state.NetworkInfo, - node_info: zigpy.state.NodeInfo, use_hashed_tclk: bool, ) -> t.EmberInitialSecurityState: """Construct an `EmberInitialSecurityState` out of zigpy network state.""" @@ -35,16 +33,15 @@ def zha_security( isc.networkKey = t.EmberKeyData(network_info.network_key.key) isc.networkKeySequenceNumber = t.uint8_t(network_info.network_key.seq) - if ( - node_info.logical_type != zdo_t.LogicalType.Coordinator - and network_info.tc_link_key.partner_ieee != zigpy_t.EUI64.UNKNOWN - ): + if network_info.tc_link_key.partner_ieee != zigpy_t.EUI64.UNKNOWN: isc.bitmask |= t.EmberInitialSecurityBitmask.HAVE_TRUST_CENTER_EUI64 isc.preconfiguredTrustCenterEui64 = t.EmberEUI64( network_info.tc_link_key.partner_ieee ) else: - isc.preconfiguredTrustCenterEui64 = t.EmberEUI64([0x00] * 8) + isc.preconfiguredTrustCenterEui64 = t.EmberEUI64.convert( + "00:00:00:00:00:00:00:00" + ) if use_hashed_tclk: if network_info.tc_link_key.key != zigpy_t.KeyData(b"ZigBeeAlliance09"): diff --git a/tests/test_application.py b/tests/test_application.py index 1ab899e1..87fa09ee 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,5 +1,6 @@ import asyncio import logging +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch, sentinel import pytest import zigpy.config @@ -19,11 +20,10 @@ import bellows.types.struct import bellows.uart as uart import bellows.zigbee.application +from bellows.zigbee.application import ControllerApplication import bellows.zigbee.device from bellows.zigbee.util import map_rssi_to_energy -from .async_mock import AsyncMock, MagicMock, PropertyMock, patch, sentinel - APP_CONFIG = { config.CONF_DEVICE: { config.CONF_DEVICE_PATH: "/dev/null", @@ -34,11 +34,31 @@ @pytest.fixture -def ezsp_mock(): +def ieee(init=0): + return t.EmberEUI64(map(t.uint8_t, range(init, init + 8))) + + +@pytest.fixture +def ezsp_mock(ieee): """EZSP fixture""" mock_ezsp = MagicMock(spec=ezsp.EZSP) mock_ezsp.ezsp_version = 7 mock_ezsp.setManufacturerCode = AsyncMock() + mock_ezsp.getEui64 = AsyncMock(return_value=[ieee]) + mock_ezsp.getConfigurationValue = AsyncMock(return_value=[t.EmberStatus.SUCCESS, 0]) + mock_ezsp.getCurrentSecurityState = AsyncMock( + return_value=[ + t.EmberStatus.SUCCESS, + t.EmberCurrentSecurityState( + bitmask=( + t.EmberCurrentSecurityBitmask.GLOBAL_LINK_KEY + | t.EmberCurrentSecurityBitmask.HAVE_TRUST_CENTER_LINK_KEY + | 224 + ), + trustCenterLongAddress=ieee, + ), + ] + ) mock_ezsp.set_source_route = AsyncMock(return_value=[t.EmberStatus.SUCCESS]) mock_ezsp.addTransientLinkKey = AsyncMock(return_value=[0]) mock_ezsp.readCounters = AsyncMock(return_value=[[0] * 10]) @@ -50,6 +70,7 @@ def ezsp_mock(): mock_ezsp.wait_for_stack_status.return_value.__enter__ = AsyncMock( return_value=t.EmberStatus.NETWORK_UP ) + mock_ezsp._protocol = AsyncMock() type(mock_ezsp).types = ezsp_t7 type(mock_ezsp).is_ezsp_running = PropertyMock(return_value=True) @@ -60,10 +81,8 @@ def ezsp_mock(): @pytest.fixture def make_app(monkeypatch, event_loop, ezsp_mock): def inner(config): - app_cfg = bellows.zigbee.application.ControllerApplication.SCHEMA( - {**APP_CONFIG, **config} - ) - app = bellows.zigbee.application.ControllerApplication(app_cfg) + app_cfg = ControllerApplication.SCHEMA({**APP_CONFIG, **config}) + app = ControllerApplication(app_cfg) app._ezsp = ezsp_mock monkeypatch.setattr(bellows.zigbee.application, "APS_ACK_TIMEOUT", 0.05) @@ -93,11 +112,6 @@ def aps(): return f -@pytest.fixture -def ieee(init=0): - return t.EmberEUI64(map(t.uint8_t, range(init, init + 8))) - - @patch("zigpy.device.Device._initialize", new=AsyncMock()) @patch("bellows.zigbee.application.ControllerApplication._watchdog", new=AsyncMock()) async def _test_startup( @@ -126,12 +140,14 @@ async def mock_leave(*args, **kwargs): return [t.EmberStatus.NETWORK_DOWN] app._in_flight_msg = None - ezsp_mock = MagicMock() + ezsp_mock = MagicMock(spec=ezsp.EZSP) ezsp_mock.types = ezsp_t7 type(ezsp_mock).ezsp_version = PropertyMock(return_value=ezsp_version) ezsp_mock.initialize = AsyncMock(return_value=ezsp_mock) ezsp_mock.connect = AsyncMock() + ezsp_mock._protocol = AsyncMock() ezsp_mock.setConcentrator = AsyncMock() + ezsp_mock.getTokenData = AsyncMock(return_value=[t.EmberStatus.ERR_FATAL, b""]) ezsp_mock._command = AsyncMock(return_value=t.EmberStatus.SUCCESS) ezsp_mock.addEndpoint = AsyncMock(return_value=t.EmberStatus.SUCCESS) ezsp_mock.setConfigurationValue = AsyncMock(return_value=t.EmberStatus.SUCCESS) @@ -179,12 +195,15 @@ async def mock_leave(*args, **kwargs): return_value=[ t.EmberStatus.SUCCESS, t.EmberCurrentSecurityState( - bitmask=t.EmberCurrentSecurityBitmask.TRUST_CENTER_USES_HASHED_LINK_KEY, - trustCenterLongAddress=t.EmberEUI64.convert("ff:ff:ff:ff:ff:ff:ff:ff"), + bitmask=( + t.EmberCurrentSecurityBitmask.GLOBAL_LINK_KEY + | t.EmberCurrentSecurityBitmask.HAVE_TRUST_CENTER_LINK_KEY + | 224 + ), + trustCenterLongAddress=ieee, ), ] ) - ezsp_mock.pre_permit = AsyncMock() app.permit = AsyncMock() def form_network(): @@ -196,7 +215,7 @@ def form_network(): app.form_network = AsyncMock(side_effect=form_network) - p1 = patch.object(bellows.ezsp, "EZSP", new=ezsp_mock) + p1 = patch("bellows.ezsp.EZSP", return_value=ezsp_mock) p2 = patch.object(bellows.multicast.Multicast, "startup") with p1, p2 as multicast_mock: @@ -1641,8 +1660,6 @@ async def test_startup_coordinator_existing_groups_joined(app, ieee): app._ensure_network_running = AsyncMock() app._ezsp.update_policies = AsyncMock() app.load_network_info = AsyncMock() - - app._multicast = bellows.multicast.Multicast(app._ezsp) app.state.node_info.ieee = ieee db_device = app.add_device(ieee, 0x0000) @@ -1666,8 +1683,6 @@ async def test_startup_new_coordinator_no_groups_joined(app, ieee): app._ensure_network_running = AsyncMock() app._ezsp.update_policies = AsyncMock() app.load_network_info = AsyncMock() - - app._multicast = bellows.multicast.Multicast(app._ezsp) app.state.node_info.ieee = ieee p1 = patch.object(bellows.multicast.Multicast, "_initialize") @@ -1692,9 +1707,6 @@ async def test_startup_source_routing(make_app, ieee, enable_source_routing): app.load_network_info = AsyncMock() app.state.node_info.ieee = ieee - app._multicast = bellows.multicast.Multicast(app._ezsp) - app._multicast._initialize = AsyncMock() - mock_device = MagicMock() mock_device.relays = sentinel.relays mock_device.initialize = AsyncMock() @@ -1774,3 +1786,34 @@ async def test_energy_scanning_partial(app): assert len(app._ezsp.startScan.mock_calls) == 6 assert set(results.keys()) == {11, 13, 14, 15, 20, 25, 26} assert results == {c: map_rssi_to_energy(c) for c in [11, 13, 14, 15, 20, 25, 26]} + + +async def test_connect_failure( + app: ControllerApplication, ezsp_mock: ezsp.EZSP +) -> None: + """Test that a failure to connect propagates.""" + ezsp_mock.startup_reset = AsyncMock(side_effect=OSError()) + ezsp_mock.connect = AsyncMock() + app._ezsp = None + + with patch("bellows.ezsp.EZSP", return_value=ezsp_mock): + with pytest.raises(OSError): + await app.connect() + + assert app._ezsp is None + + assert len(ezsp_mock.close.mock_calls) == 1 + + +async def test_repair_tclk_partner_ieee(app: ControllerApplication) -> None: + """Test that EZSP is reset after repairing TCLK.""" + app._ensure_network_running = AsyncMock() + app.load_network_info = AsyncMock() + + with patch( + "bellows.zigbee.repairs.fix_invalid_tclk_partner_ieee", + AsyncMock(return_value=True), + ): + await app.start_network() + + assert len(app._ensure_network_running.mock_calls) == 2 diff --git a/tests/test_application_network_state.py b/tests/test_application_network_state.py index 7bb370e4..04171b0e 100644 --- a/tests/test_application_network_state.py +++ b/tests/test_application_network_state.py @@ -9,7 +9,7 @@ import bellows.types as t from tests.async_mock import AsyncMock, PropertyMock -from tests.test_application import app, ezsp_mock, make_app +from tests.test_application import app, ezsp_mock, ieee, make_app @pytest.fixture diff --git a/tests/test_ezsp.py b/tests/test_ezsp.py index 795876d4..02f37321 100644 --- a/tests/test_ezsp.py +++ b/tests/test_ezsp.py @@ -1,11 +1,13 @@ import asyncio import functools +import logging import sys import pytest from bellows import config, ezsp, uart -from bellows.exception import EzspError +from bellows.exception import EzspError, InvalidCommandError +import bellows.ezsp.v4.types as v4_t import bellows.types as t if sys.version_info[:2] < (3, 11): @@ -13,7 +15,7 @@ else: from asyncio import timeout as asyncio_timeout # pragma: no cover -from .async_mock import AsyncMock, MagicMock, call, patch, sentinel +from unittest.mock import ANY, AsyncMock, MagicMock, call, patch, sentinel DEVICE_CONFIG = { config.CONF_DEVICE_PATH: "/dev/null", @@ -312,29 +314,16 @@ async def test_probe_fail(exception): assert mock_connect.return_value.close.call_count == 2 -@patch.object(ezsp.EZSP, "set_source_routing", new_callable=AsyncMock) -@patch("bellows.ezsp.v4.EZSPv4.initialize", new_callable=AsyncMock) @patch.object(ezsp.EZSP, "version", new_callable=AsyncMock) @patch.object(ezsp.EZSP, "reset", new_callable=AsyncMock) @patch("bellows.uart.connect", return_value=MagicMock(spec_set=uart.Gateway)) -async def test_ezsp_init( - conn_mock, reset_mock, version_mock, prot_handler_mock, src_mock -): +async def test_ezsp_init(conn_mock, reset_mock, version_mock): """Test initialize method.""" zigpy_config = config.CONFIG_SCHEMA({"device": DEVICE_CONFIG}) await ezsp.EZSP.initialize(zigpy_config) assert conn_mock.await_count == 1 assert reset_mock.await_count == 1 assert version_mock.await_count == 1 - assert prot_handler_mock.await_count == 1 - assert src_mock.call_count == 0 - assert src_mock.await_count == 0 - - zigpy_config = config.CONFIG_SCHEMA( - {"device": DEVICE_CONFIG, "source_routing": "yes"} - ) - await ezsp.EZSP.initialize(zigpy_config) - assert src_mock.await_count == 1 @patch.object(ezsp.EZSP, "version", side_effect=RuntimeError("Uh oh")) @@ -458,7 +447,7 @@ async def test_update_policies(ezsp_f): assert pol_mock.await_count == 1 -async def test_set_concentrator(ezsp_f): +async def test_set_source_routing_set_concentrator(ezsp_f): """Test enabling source routing.""" with patch.object(ezsp_f, "setConcentrator", new=AsyncMock()) as cnc_mock: cnc_mock.return_value = (ezsp_f.types.EmberStatus.SUCCESS,) @@ -470,6 +459,17 @@ async def test_set_concentrator(ezsp_f): assert cnc_mock.await_count == 2 +async def test_set_source_routing_ezsp_v8(ezsp_f): + """Test enabling source routing on EZSPv8.""" + + ezsp_f._ezsp_version = 8 + ezsp_f.setConcentrator = AsyncMock(return_value=(ezsp_f.types.EmberStatus.SUCCESS,)) + ezsp_f.setSourceRouteDiscoveryMode = AsyncMock() + + await ezsp_f.set_source_routing() + assert len(ezsp_f.setSourceRouteDiscoveryMode.mock_calls) == 1 + + async def test_leave_network_error(ezsp_f): """Test EZSP leaveNetwork command failure.""" @@ -627,14 +627,10 @@ async def test_write_custom_eui64(ezsp_f): ezsp_f.setTokenData.assert_not_called() -@patch.object(ezsp.EZSP, "set_source_routing", new_callable=AsyncMock) -@patch("bellows.ezsp.v4.EZSPv4.initialize", new_callable=AsyncMock) @patch.object(ezsp.EZSP, "version", new_callable=AsyncMock) @patch.object(ezsp.EZSP, "reset", new_callable=AsyncMock) @patch("bellows.uart.connect", return_value=MagicMock(spec_set=uart.Gateway)) -async def test_ezsp_init_zigbeed( - conn_mock, reset_mock, version_mock, prot_handler_mock, src_mock -): +async def test_ezsp_init_zigbeed(conn_mock, reset_mock, version_mock): """Test initialize method with a received startup reset frame.""" zigpy_config = config.CONFIG_SCHEMA( { @@ -653,20 +649,13 @@ async def test_ezsp_init_zigbeed( assert reset_mock.await_count == 0 # Reset is not called assert gw_wait_reset_mock.await_count == 1 assert version_mock.await_count == 1 - assert prot_handler_mock.await_count == 1 - assert src_mock.call_count == 0 - assert src_mock.await_count == 0 -@patch.object(ezsp.EZSP, "set_source_routing", new_callable=AsyncMock) -@patch("bellows.ezsp.v4.EZSPv4.initialize", new_callable=AsyncMock) @patch.object(ezsp.EZSP, "version", new_callable=AsyncMock) @patch.object(ezsp.EZSP, "reset", new_callable=AsyncMock) @patch("bellows.uart.connect", return_value=MagicMock(spec_set=uart.Gateway)) @patch("bellows.ezsp.NETWORK_COORDINATOR_STARTUP_RESET_WAIT", 0.01) -async def test_ezsp_init_zigbeed_timeout( - conn_mock, reset_mock, version_mock, prot_handler_mock, src_mock -): +async def test_ezsp_init_zigbeed_timeout(conn_mock, reset_mock, version_mock): """Test initialize method with a received startup reset frame.""" zigpy_config = config.CONFIG_SCHEMA( { @@ -690,9 +679,6 @@ async def wait_forever(*args, **kwargs): assert reset_mock.await_count == 1 # Reset will be called assert gw_wait_reset_mock.await_count == 1 assert version_mock.await_count == 1 - assert prot_handler_mock.await_count == 1 - assert src_mock.call_count == 0 - assert src_mock.await_count == 0 async def test_wait_for_stack_status(ezsp_f): @@ -723,3 +709,143 @@ def test_ezsp_versions(ezsp_f): assert version in ezsp_f._BY_VERSION assert ezsp_f._BY_VERSION[version].__name__ == f"EZSPv{version}" assert ezsp_f._BY_VERSION[version].VERSION == version + + +async def test_config_initialize_husbzb1(ezsp_f): + """Test timeouts are properly set for HUSBZB-1.""" + + ezsp_f._ezsp_version = 4 + + ezsp_f.getConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, 0)) + ezsp_f.setConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS,)) + + await ezsp_f.write_config({}) + ezsp_f.setConfigurationValue.assert_has_calls( + [ + call(v4_t.EzspConfigId.CONFIG_SOURCE_ROUTE_TABLE_SIZE, 16), + call(v4_t.EzspConfigId.CONFIG_END_DEVICE_POLL_TIMEOUT, 60), + call(v4_t.EzspConfigId.CONFIG_END_DEVICE_POLL_TIMEOUT_SHIFT, 8), + call(v4_t.EzspConfigId.CONFIG_INDIRECT_TRANSMISSION_TIMEOUT, 7680), + call(v4_t.EzspConfigId.CONFIG_STACK_PROFILE, 2), + call(v4_t.EzspConfigId.CONFIG_SUPPORTED_NETWORKS, 1), + call(v4_t.EzspConfigId.CONFIG_MULTICAST_TABLE_SIZE, 16), + call(v4_t.EzspConfigId.CONFIG_TRUST_CENTER_ADDRESS_CACHE_SIZE, 2), + call(v4_t.EzspConfigId.CONFIG_SECURITY_LEVEL, 5), + call(v4_t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE, 16), + call(v4_t.EzspConfigId.CONFIG_PAN_ID_CONFLICT_REPORT_THRESHOLD, 2), + call(v4_t.EzspConfigId.CONFIG_KEY_TABLE_SIZE, 4), + call(v4_t.EzspConfigId.CONFIG_MAX_END_DEVICE_CHILDREN, 32), + call( + v4_t.EzspConfigId.CONFIG_APPLICATION_ZDO_FLAGS, + ( + v4_t.EmberZdoConfigurationFlags.APP_HANDLES_UNSUPPORTED_ZDO_REQUESTS + | v4_t.EmberZdoConfigurationFlags.APP_RECEIVES_SUPPORTED_ZDO_REQUESTS + ), + ), + call(v4_t.EzspConfigId.CONFIG_PACKET_BUFFER_COUNT, 255), + ] + ) + + +@pytest.mark.parametrize("version", ezsp.EZSP._BY_VERSION) +async def test_config_initialize(version: int, ezsp_f, caplog): + """Test config initialization for all protocol versions.""" + + assert ezsp_f.ezsp_version == 4 + + with patch.object(ezsp_f, "_command", AsyncMock(return_value=[version, 2, 2046])): + await ezsp_f.version() + + assert ezsp_f.ezsp_version == version + + ezsp_f.getConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, 0)) + ezsp_f.setConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS,)) + + ezsp_f.setValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS,)) + ezsp_f.getValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, b"\xFF")) + + await ezsp_f.write_config({}) + + with caplog.at_level(logging.DEBUG): + ezsp_f.setConfigurationValue.return_value = (t.EzspStatus.ERROR_OUT_OF_MEMORY,) + await ezsp_f.write_config({}) + + assert "Could not set config" in caplog.text + ezsp_f.setConfigurationValue.return_value = (t.EzspStatus.SUCCESS,) + caplog.clear() + + # EZSPv6 does not set any values on startup + if version < 7: + return + + ezsp_f.setValue.reset_mock() + ezsp_f.getValue.return_value = (t.EzspStatus.ERROR_INVALID_ID, b"") + await ezsp_f.write_config({}) + assert len(ezsp_f.setValue.mock_calls) == 1 + + ezsp_f.getValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, b"\xFF")) + caplog.clear() + + with caplog.at_level(logging.DEBUG): + ezsp_f.setValue.return_value = (t.EzspStatus.ERROR_INVALID_ID,) + await ezsp_f.write_config({}) + + assert "Could not set value" in caplog.text + ezsp_f.setValue.return_value = (t.EzspStatus.SUCCESS,) + caplog.clear() + + +async def test_cfg_initialize_skip(ezsp_f): + """Test initialization.""" + + p1 = patch.object( + ezsp_f, + "setConfigurationValue", + new=AsyncMock(return_value=(t.EzspStatus.SUCCESS,)), + ) + p2 = patch.object( + ezsp_f, + "getConfigurationValue", + new=AsyncMock(return_value=(t.EzspStatus.SUCCESS, 22)), + ) + with p1, p2: + await ezsp_f.write_config({"CONFIG_END_DEVICE_POLL_TIMEOUT": None}) + + # Config not set when it is explicitly disabled + with pytest.raises(AssertionError): + ezsp_f.setConfigurationValue.assert_called_with( + v4_t.EzspConfigId.CONFIG_END_DEVICE_POLL_TIMEOUT, ANY + ) + + with p1, p2: + await ezsp_f.write_config({"CONFIG_MULTICAST_TABLE_SIZE": 123}) + + # Config is overridden + ezsp_f.setConfigurationValue.assert_any_call( + v4_t.EzspConfigId.CONFIG_MULTICAST_TABLE_SIZE, 123 + ) + + with p1, p2: + await ezsp_f.write_config({}) + + # Config is set by default + ezsp_f.setConfigurationValue.assert_any_call( + v4_t.EzspConfigId.CONFIG_END_DEVICE_POLL_TIMEOUT, ANY + ) + + +async def test_reset_custom_eui64(ezsp_f): + """Test resetting custom EUI64.""" + # No NV3 interface + ezsp_f.getTokenData = AsyncMock(side_effect=InvalidCommandError) + ezsp_f.setTokenData = AsyncMock(return_value=[t.EmberStatus.SUCCESS]) + await ezsp_f.reset_custom_eui64() + + assert len(ezsp_f.setTokenData.mock_calls) == 0 + + # With NV3 interface + ezsp_f.getTokenData = AsyncMock(return_value=[t.EmberStatus.SUCCESS, b"\xAB" * 8]) + await ezsp_f.reset_custom_eui64() + assert ezsp_f.setTokenData.mock_calls == [ + call(t.NV3KeyId.CREATOR_STACK_RESTORED_EUI64, 0, t.LVBytes32(b"\xFF" * 8)) + ] diff --git a/tests/test_ezsp_protocol.py b/tests/test_ezsp_protocol.py index b5f58648..f9dd25d2 100644 --- a/tests/test_ezsp_protocol.py +++ b/tests/test_ezsp_protocol.py @@ -62,159 +62,17 @@ def test_receive_reply_invalid_command(prot_hndl): assert prot_hndl._handle_callback.call_count == 0 -async def test_config_initialize_husbzb1(prot_hndl): - """Test timeouts are properly set for HUSBZB-1.""" - - prot_hndl.getConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, 0)) - prot_hndl.setConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS,)) - - await prot_hndl.initialize({"ezsp_config": {}}) - prot_hndl.setConfigurationValue.assert_has_calls( - [ - call(t.EzspConfigId.CONFIG_SOURCE_ROUTE_TABLE_SIZE, 16), - call(t.EzspConfigId.CONFIG_END_DEVICE_POLL_TIMEOUT, 60), - call(t.EzspConfigId.CONFIG_END_DEVICE_POLL_TIMEOUT_SHIFT, 8), - call(t.EzspConfigId.CONFIG_INDIRECT_TRANSMISSION_TIMEOUT, 7680), - call(t.EzspConfigId.CONFIG_STACK_PROFILE, 2), - call(t.EzspConfigId.CONFIG_SUPPORTED_NETWORKS, 1), - call(t.EzspConfigId.CONFIG_MULTICAST_TABLE_SIZE, 16), - call(t.EzspConfigId.CONFIG_TRUST_CENTER_ADDRESS_CACHE_SIZE, 2), - call(t.EzspConfigId.CONFIG_SECURITY_LEVEL, 5), - call(t.EzspConfigId.CONFIG_ADDRESS_TABLE_SIZE, 16), - call(t.EzspConfigId.CONFIG_PAN_ID_CONFLICT_REPORT_THRESHOLD, 2), - call(t.EzspConfigId.CONFIG_KEY_TABLE_SIZE, 4), - call(t.EzspConfigId.CONFIG_MAX_END_DEVICE_CHILDREN, 32), - call( - t.EzspConfigId.CONFIG_APPLICATION_ZDO_FLAGS, - ( - t.EmberZdoConfigurationFlags.APP_HANDLES_UNSUPPORTED_ZDO_REQUESTS - | t.EmberZdoConfigurationFlags.APP_RECEIVES_SUPPORTED_ZDO_REQUESTS - ), - ), - call(t.EzspConfigId.CONFIG_PACKET_BUFFER_COUNT, 255), - ] - ) - - -@pytest.mark.parametrize("prot_hndl_cls", EZSP._BY_VERSION.values()) -async def test_config_initialize(prot_hndl_cls, caplog): - """Test config initialization for all protocol versions.""" - - prot_hndl = prot_hndl_cls(MagicMock(), MagicMock()) - prot_hndl.getConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, 0)) - prot_hndl.setConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS,)) - - prot_hndl.setValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS,)) - prot_hndl.getValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, b"\xFF")) - - await prot_hndl.initialize({"ezsp_config": {}}) - - with caplog.at_level(logging.DEBUG): - prot_hndl.setConfigurationValue.return_value = ( - t.EzspStatus.ERROR_OUT_OF_MEMORY, - ) - await prot_hndl.initialize({"ezsp_config": {}}) - - assert "Could not set config" in caplog.text - prot_hndl.setConfigurationValue.return_value = (t.EzspStatus.SUCCESS,) - caplog.clear() - - # EZSPv6 does not set any values on startup - if prot_hndl_cls.VERSION < 7: - return - - prot_hndl.setValue.reset_mock() - prot_hndl.getValue.return_value = (t.EzspStatus.ERROR_INVALID_ID, b"") - await prot_hndl.initialize({"ezsp_config": {}}) - assert len(prot_hndl.setValue.mock_calls) == 1 - - prot_hndl.getValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, b"\xFF")) - caplog.clear() - - with caplog.at_level(logging.DEBUG): - prot_hndl.setValue.return_value = (t.EzspStatus.ERROR_INVALID_ID,) - await prot_hndl.initialize({"ezsp_config": {}}) - - assert "Could not set value" in caplog.text - prot_hndl.setValue.return_value = (t.EzspStatus.SUCCESS,) - caplog.clear() - - -async def test_cfg_initialize_skip(prot_hndl): - """Test initialization.""" - - p1 = patch.object( - prot_hndl, - "setConfigurationValue", - new=AsyncMock(return_value=(t.EzspStatus.SUCCESS,)), - ) - p2 = patch.object( - prot_hndl, - "getConfigurationValue", - new=AsyncMock(return_value=(t.EzspStatus.SUCCESS, 22)), - ) - p3 = patch.object(prot_hndl, "get_free_buffers", new=AsyncMock(22)) - with p1, p2, p3: - await prot_hndl.initialize( - {"ezsp_config": {"CONFIG_END_DEVICE_POLL_TIMEOUT": None}} - ) - - # Config not set when it is explicitly disabled - with pytest.raises(AssertionError): - prot_hndl.setConfigurationValue.assert_called_with( - t.EzspConfigId.CONFIG_END_DEVICE_POLL_TIMEOUT, ANY - ) - - with p1, p2, p3: - await prot_hndl.initialize( - {"ezsp_config": {"CONFIG_MULTICAST_TABLE_SIZE": 123}} - ) - - # Config is overridden - prot_hndl.setConfigurationValue.assert_any_call( - t.EzspConfigId.CONFIG_MULTICAST_TABLE_SIZE, 123 - ) - - with p1, p2, p3: - await prot_hndl.initialize({"ezsp_config": {}}) - - # Config is set by default - prot_hndl.setConfigurationValue.assert_any_call( - t.EzspConfigId.CONFIG_END_DEVICE_POLL_TIMEOUT, ANY - ) - - async def test_update_policies(prot_hndl): """Test update_policies.""" with patch.object(prot_hndl, "setPolicy", new=AsyncMock()) as pol_mock: pol_mock.return_value = (t.EzspStatus.SUCCESS,) - await prot_hndl.update_policies({"ezsp_policies": {}}) + await prot_hndl.update_policies({}) + with patch.object(prot_hndl, "setPolicy", new=AsyncMock()) as pol_mock: pol_mock.return_value = (t.EzspStatus.ERROR_OUT_OF_MEMORY,) with pytest.raises(AssertionError): - await prot_hndl.update_policies({"ezsp_policies": {}}) - - -@pytest.mark.parametrize( - "status, raw, expected_value", - ( - (t.EzspStatus.ERROR_OUT_OF_MEMORY, b"", None), - (t.EzspStatus.ERROR_OUT_OF_MEMORY, b"\x02\x02", None), - (t.EzspStatus.SUCCESS, b"\x02\x02", 514), - ), -) -async def test_get_free_buffers(prot_hndl, status, raw, expected_value): - """Test getting free buffers.""" - - p1 = patch.object(prot_hndl, "getValue", new=AsyncMock()) - with p1 as value_mock: - value_mock.return_value = (status, raw) - free_buffers = await prot_hndl.get_free_buffers() - if expected_value is None: - assert free_buffers is expected_value - else: - assert free_buffers == expected_value + await prot_hndl.update_policies({}) async def test_unknown_command(prot_hndl, caplog): @@ -226,3 +84,13 @@ async def test_unknown_command(prot_hndl, caplog): prot_hndl(bytes([0x00, 0x00, unregistered_command, 0xAB, 0xCD])) assert "0x0004 received: b'abcd' (b'000004abcd')" in caplog.text + + +async def test_logging_frame_parsing_failure(prot_hndl, caplog) -> None: + """Test logging when frame parsing fails.""" + + with caplog.at_level(logging.WARNING): + with pytest.raises(ValueError): + prot_hndl(b"\xAA\xAA\x71\x22") + + assert "Failed to parse frame getKeyTableEntry: b'22'" in caplog.text diff --git a/tests/test_ezsp_v10.py b/tests/test_ezsp_v10.py index 91c5eafb..bcf6a983 100644 --- a/tests/test_ezsp_v10.py +++ b/tests/test_ezsp_v10.py @@ -25,15 +25,6 @@ def test_ezsp_frame_rx(ezsp_f): assert ezsp_f._handle_callback.call_args[0][1] == [0x01, 0x02, 0x1234] -async def test_set_source_routing(ezsp_f): - """Test setting source routing.""" - with patch.object( - ezsp_f, "setSourceRouteDiscoveryMode", new=AsyncMock() - ) as src_mock: - await ezsp_f.set_source_routing() - assert src_mock.await_count == 1 - - async def test_pre_permit(ezsp_f): """Test pre permit.""" p1 = patch.object(ezsp_f, "setPolicy", new=AsyncMock()) diff --git a/tests/test_ezsp_v11.py b/tests/test_ezsp_v11.py index 97b81c5d..84f201c8 100644 --- a/tests/test_ezsp_v11.py +++ b/tests/test_ezsp_v11.py @@ -25,15 +25,6 @@ def test_ezsp_frame_rx(ezsp_f): assert ezsp_f._handle_callback.call_args[0][1] == [0x01, 0x02, 0x1234] -async def test_set_source_routing(ezsp_f): - """Test setting source routing.""" - with patch.object( - ezsp_f, "setSourceRouteDiscoveryMode", new=AsyncMock() - ) as src_mock: - await ezsp_f.set_source_routing() - assert src_mock.await_count == 1 - - async def test_pre_permit(ezsp_f): """Test pre permit.""" p1 = patch.object(ezsp_f, "setPolicy", new=AsyncMock()) diff --git a/tests/test_ezsp_v12.py b/tests/test_ezsp_v12.py index 322d86b1..871f2184 100644 --- a/tests/test_ezsp_v12.py +++ b/tests/test_ezsp_v12.py @@ -25,15 +25,6 @@ def test_ezsp_frame_rx(ezsp_f): assert ezsp_f._handle_callback.call_args[0][1] == [0x01, 0x02, 0x1234] -async def test_set_source_routing(ezsp_f): - """Test setting source routing.""" - with patch.object( - ezsp_f, "setSourceRouteDiscoveryMode", new=AsyncMock() - ) as src_mock: - await ezsp_f.set_source_routing() - assert src_mock.await_count == 1 - - async def test_pre_permit(ezsp_f): """Test pre permit.""" p1 = patch.object(ezsp_f, "setPolicy", new=AsyncMock()) diff --git a/tests/test_ezsp_v4.py b/tests/test_ezsp_v4.py index 54ba08a5..0fe5b237 100644 --- a/tests/test_ezsp_v4.py +++ b/tests/test_ezsp_v4.py @@ -25,6 +25,11 @@ def test_ezsp_frame_rx(ezsp_f): assert ezsp_f._handle_callback.call_args[0][1] == [0x01, 0x02, 0x1234] +async def test_pre_permit(ezsp_f): + """Test pre permit.""" + await ezsp_f.pre_permit(1.9) + + command_frames = { "addEndpoint": 0x02, "addOrUpdateKeyTableEntry": 0x66, diff --git a/tests/test_ezsp_v8.py b/tests/test_ezsp_v8.py index a39f9384..8d6c5c93 100644 --- a/tests/test_ezsp_v8.py +++ b/tests/test_ezsp_v8.py @@ -25,15 +25,6 @@ def test_ezsp_frame_rx(ezsp_f): assert ezsp_f._handle_callback.call_args[0][1] == [0x01, 0x02, 0x1234] -async def test_set_source_routing(ezsp_f): - """Test setting source routing.""" - with patch.object( - ezsp_f, "setSourceRouteDiscoveryMode", new=AsyncMock() - ) as src_mock: - await ezsp_f.set_source_routing() - assert src_mock.await_count == 1 - - async def test_pre_permit(ezsp_f): """Test pre permit.""" p1 = patch.object(ezsp_f, "setPolicy", new=AsyncMock()) diff --git a/tests/test_ezsp_v9.py b/tests/test_ezsp_v9.py index f4876b4c..b9b5806d 100644 --- a/tests/test_ezsp_v9.py +++ b/tests/test_ezsp_v9.py @@ -25,15 +25,6 @@ def test_ezsp_frame_rx(ezsp_f): assert ezsp_f._handle_callback.call_args[0][1] == [0x01, 0x02, 0x1234] -async def test_set_source_routing(ezsp_f): - """Test setting source routing.""" - with patch.object( - ezsp_f, "setSourceRouteDiscoveryMode", new=AsyncMock() - ) as src_mock: - await ezsp_f.set_source_routing() - assert src_mock.await_count == 1 - - async def test_pre_permit(ezsp_f): """Test pre permit.""" p1 = patch.object(ezsp_f, "setPolicy", new=AsyncMock()) diff --git a/tests/test_util.py b/tests/test_util.py index da76e26b..16880dcb 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -8,7 +8,7 @@ import bellows.types as bellows_t import bellows.zigbee.util as util -from tests.test_application import ezsp_mock +from tests.test_application import ezsp_mock, ieee from tests.test_application_network_state import network_info, node_info @@ -39,14 +39,12 @@ def ezsp_key(ezsp_mock, network_info, node_info, zigpy_key): def test_zha_security_normal(network_info, node_info): - security = util.zha_security( - network_info=network_info, node_info=node_info, use_hashed_tclk=True - ) + security = util.zha_security(network_info=network_info, use_hashed_tclk=True) - assert security.preconfiguredTrustCenterEui64 == bellows_t.EmberEUI64([0x00] * 8) + assert security.preconfiguredTrustCenterEui64 == node_info.ieee assert ( bellows_t.EmberInitialSecurityBitmask.HAVE_TRUST_CENTER_EUI64 - not in security.bitmask + in security.bitmask ) assert ( @@ -59,34 +57,11 @@ def test_zha_security_normal(network_info, node_info): ) -def test_zha_security_router(network_info, node_info): - security = util.zha_security( - network_info=network_info, - node_info=node_info.replace(logical_type=zdo_t.LogicalType.Router), - use_hashed_tclk=False, - ) - - assert security.preconfiguredTrustCenterEui64 == bellows_t.EmberEUI64( - network_info.tc_link_key.partner_ieee - ) - assert ( - bellows_t.EmberInitialSecurityBitmask.HAVE_TRUST_CENTER_EUI64 - in security.bitmask - ) - - assert security.preconfiguredKey == network_info.tc_link_key.key - assert ( - bellows_t.EmberInitialSecurityBitmask.TRUST_CENTER_USES_HASHED_LINK_KEY - not in security.bitmask - ) - - -def test_zha_security_router_unknown_tclk_partner_ieee(network_info, node_info): +def test_zha_security_router_unknown_tclk_partner_ieee(network_info): security = util.zha_security( network_info=network_info.replace( tc_link_key=network_info.tc_link_key.replace(partner_ieee=t.EUI64.UNKNOWN) ), - node_info=node_info.replace(logical_type=zdo_t.LogicalType.Router), use_hashed_tclk=False, ) @@ -98,25 +73,11 @@ def test_zha_security_router_unknown_tclk_partner_ieee(network_info, node_info): ) -def test_zha_security_replace_missing_tc_partner_addr(network_info, node_info): - security = util.zha_security( - network_info=network_info.replace( - tc_link_key=network_info.tc_link_key.replace(partner_ieee=t.EUI64.UNKNOWN) - ), - node_info=node_info, - use_hashed_tclk=True, - ) - - assert node_info.ieee != t.EUI64.UNKNOWN - assert security.preconfiguredTrustCenterEui64 == bellows_t.EmberEUI64([0x00] * 8) - - -def test_zha_security_hashed_nonstandard_tclk_warning(network_info, node_info, caplog): +def test_zha_security_hashed_nonstandard_tclk_warning(network_info, caplog): # Nothing should be logged normally with caplog.at_level(logging.WARNING): util.zha_security( network_info=network_info, - node_info=node_info, use_hashed_tclk=True, ) @@ -130,7 +91,6 @@ def test_zha_security_hashed_nonstandard_tclk_warning(network_info, node_info, c key=t.KeyData(b"ANonstandardTCLK") ) ), - node_info=node_info, use_hashed_tclk=True, ) diff --git a/tests/test_zigbee_repairs.py b/tests/test_zigbee_repairs.py new file mode 100644 index 00000000..59562db3 --- /dev/null +++ b/tests/test_zigbee_repairs.py @@ -0,0 +1,114 @@ +"""Test network state repairs.""" + +import logging +from unittest.mock import AsyncMock, call + +import pytest + +from bellows.exception import InvalidCommandError +from bellows.ezsp import EZSP +import bellows.types as t +from bellows.zigbee import repairs + +from tests.test_ezsp import ezsp_f + + +@pytest.fixture +def ezsp_tclk_f(ezsp_f: EZSP) -> EZSP: + """Mock an EZSP instance with a valid TCLK.""" + ezsp_f.getEui64 = AsyncMock( + return_value=[t.EmberEUI64.convert("AA:AA:AA:AA:AA:AA:AA:AA")] + ) + ezsp_f.getTokenData = AsyncMock(side_effect=InvalidCommandError()) + ezsp_f.getCurrentSecurityState = AsyncMock( + return_value=[ + t.EmberStatus.SUCCESS, + t.EmberCurrentSecurityState( + bitmask=( + t.EmberCurrentSecurityBitmask.GLOBAL_LINK_KEY + | t.EmberCurrentSecurityBitmask.HAVE_TRUST_CENTER_LINK_KEY + | 224 + ), + trustCenterLongAddress=t.EmberEUI64.convert("AA:AA:AA:AA:AA:AA:AA:AA"), + ), + ] + ) + return ezsp_f + + +async def test_fix_invalid_tclk_noop(ezsp_tclk_f: EZSP, caplog) -> None: + """Test that the TCLK is not rewritten unnecessarily.""" + + ezsp_tclk_f.getEui64.return_value[0] = t.EmberEUI64.convert( + "AA:AA:AA:AA:AA:AA:AA:AA" + ) + ezsp_tclk_f.getCurrentSecurityState.return_value[ + 1 + ].trustCenterLongAddress = t.EmberEUI64.convert("AA:AA:AA:AA:AA:AA:AA:AA") + + with caplog.at_level(logging.WARNING): + assert await repairs.fix_invalid_tclk_partner_ieee(ezsp_tclk_f) is False + + assert "Fixing invalid TCLK" not in caplog.text + + +async def test_fix_invalid_tclk_old_firmware(ezsp_tclk_f: EZSP, caplog) -> None: + """Test that the TCLK is not rewritten when the firmware is too old.""" + + ezsp_tclk_f.getTokenData = AsyncMock(side_effect=InvalidCommandError()) + ezsp_tclk_f.getEui64.return_value[0] = t.EmberEUI64.convert( + "AA:AA:AA:AA:AA:AA:AA:AA" + ) + ezsp_tclk_f.getCurrentSecurityState.return_value[ + 1 + ].trustCenterLongAddress = t.EmberEUI64.convert("BB:BB:BB:BB:BB:BB:BB:BB") + + with caplog.at_level(logging.WARNING): + assert await repairs.fix_invalid_tclk_partner_ieee(ezsp_tclk_f) is False + + assert "Fixing invalid TCLK" in caplog.text + assert "NV3 interface not available in this firmware" in caplog.text + + +async def test_fix_invalid_tclk(ezsp_tclk_f: EZSP, caplog) -> None: + """Test that the TCLK is not rewritten when the firmware is too old.""" + + ezsp_tclk_f.setTokenData = AsyncMock(return_value=[t.EmberStatus.SUCCESS]) + ezsp_tclk_f.getTokenData = AsyncMock( + return_value=[ + t.EmberStatus.SUCCESS, + t.NV3StackTrustCenterToken( + mode=228, + eui64=t.EmberEUI64.convert("BB:BB:BB:BB:BB:BB:BB:BB"), + key=t.EmberKeyData.convert( + "21:8e:df:b8:50:a0:4a:b6:8b:c6:10:25:bc:4e:93:6a" + ), + ).serialize(), + ] + ) + ezsp_tclk_f.getEui64.return_value[0] = t.EmberEUI64.convert( + "AA:AA:AA:AA:AA:AA:AA:AA" + ) + ezsp_tclk_f.getCurrentSecurityState.return_value[ + 1 + ].trustCenterLongAddress = t.EmberEUI64.convert("BB:BB:BB:BB:BB:BB:BB:BB") + + with caplog.at_level(logging.WARNING): + assert await repairs.fix_invalid_tclk_partner_ieee(ezsp_tclk_f) is True + + assert "Fixing invalid TCLK" in caplog.text + assert "NV3 interface not available in this firmware" not in caplog.text + + assert ezsp_tclk_f.setTokenData.mock_calls == [ + call( + t.NV3KeyId.NVM3KEY_STACK_TRUST_CENTER, + 0, + t.NV3StackTrustCenterToken( + mode=228, + eui64=t.EmberEUI64.convert("AA:AA:AA:AA:AA:AA:AA:AA"), + key=t.EmberKeyData.convert( + "21:8e:df:b8:50:a0:4a:b6:8b:c6:10:25:bc:4e:93:6a" + ), + ).serialize(), + ) + ]