Skip to content

Commit

Permalink
Require specifying extended address in otbr WS API calls (#108282)
Browse files Browse the repository at this point in the history
Co-authored-by: Stefan Agner <[email protected]>
  • Loading branch information
emontnemery and agners authored Jul 17, 2024
1 parent 14ec7e5 commit 054242f
Show file tree
Hide file tree
Showing 2 changed files with 395 additions and 69 deletions.
95 changes: 69 additions & 26 deletions homeassistant/components/otbr/websocket_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Websocket API for OTBR."""

from typing import cast
from collections.abc import Callable, Coroutine
from functools import wraps
from typing import Any, cast

import python_otbr_api
from python_otbr_api import PENDING_DATASET_DELAY_TIMER, tlv_parser
Expand Down Expand Up @@ -55,42 +57,84 @@ async def websocket_info(
border_agent_id = await data.get_border_agent_id()
dataset = await data.get_active_dataset()
dataset_tlvs = await data.get_active_dataset_tlvs()
extended_address = await data.get_extended_address()
extended_address = (await data.get_extended_address()).hex()
except HomeAssistantError as exc:
connection.send_error(msg["id"], "otbr_info_failed", str(exc))
return

# The border agent ID is checked when the OTBR config entry is setup,
# we can assert it's not None
assert border_agent_id is not None

extended_pan_id = (
dataset.extended_pan_id.lower() if dataset and dataset.extended_pan_id else None
)
connection.send_result(
msg["id"],
{
"active_dataset_tlvs": dataset_tlvs.hex() if dataset_tlvs else None,
"border_agent_id": border_agent_id.hex(),
"channel": dataset.channel if dataset else None,
"extended_address": extended_address.hex(),
"url": data.url,
extended_address: {
"active_dataset_tlvs": dataset_tlvs.hex() if dataset_tlvs else None,
"border_agent_id": border_agent_id.hex(),
"channel": dataset.channel if dataset else None,
"extended_address": extended_address,
"extended_pan_id": extended_pan_id,
"url": data.url,
}
},
)


def async_get_otbr_data(
orig_func: Callable[
[HomeAssistant, websocket_api.ActiveConnection, dict, OTBRData],
Coroutine[Any, Any, None],
],
) -> Callable[
[HomeAssistant, websocket_api.ActiveConnection, dict], Coroutine[Any, Any, None]
]:
"""Decorate function to get OTBR data."""

@wraps(orig_func)
async def async_check_extended_address_func(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Fetch OTBR data and pass to orig_func."""
if DOMAIN not in hass.data:
connection.send_error(msg["id"], "not_loaded", "No OTBR API loaded")
return

data: OTBRData = hass.data[DOMAIN]

try:
extended_address = await data.get_extended_address()
except HomeAssistantError as exc:
connection.send_error(msg["id"], "get_extended_address_failed", str(exc))
return
if extended_address.hex() != msg["extended_address"]:
connection.send_error(msg["id"], "unknown_router", "")
return

await orig_func(hass, connection, msg, data)

return async_check_extended_address_func


@websocket_api.websocket_command(
{
"type": "otbr/create_network",
vol.Required("extended_address"): str,
}
)
@websocket_api.require_admin
@websocket_api.async_response
@async_get_otbr_data
async def websocket_create_network(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict,
data: OTBRData,
) -> None:
"""Create a new Thread network."""
if DOMAIN not in hass.data:
connection.send_error(msg["id"], "not_loaded", "No OTBR API loaded")
return

data: OTBRData = hass.data[DOMAIN]
channel = await get_allowed_channel(hass, data.url) or DEFAULT_CHANNEL

try:
Expand Down Expand Up @@ -144,19 +188,20 @@ async def websocket_create_network(
@websocket_api.websocket_command(
{
"type": "otbr/set_network",
vol.Required("extended_address"): str,
vol.Required("dataset_id"): str,
}
)
@websocket_api.require_admin
@websocket_api.async_response
@async_get_otbr_data
async def websocket_set_network(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict,
data: OTBRData,
) -> None:
"""Set the Thread network to be used by the OTBR."""
if DOMAIN not in hass.data:
connection.send_error(msg["id"], "not_loaded", "No OTBR API loaded")
return

dataset_tlv = await async_get_dataset(hass, msg["dataset_id"])

if not dataset_tlv:
Expand All @@ -166,7 +211,6 @@ async def websocket_set_network(
if channel := dataset.get(MeshcopTLVType.CHANNEL):
thread_dataset_channel = cast(tlv_parser.Channel, channel).channel

data: OTBRData = hass.data[DOMAIN]
allowed_channel = await get_allowed_channel(hass, data.url)

if allowed_channel and thread_dataset_channel != allowed_channel:
Expand Down Expand Up @@ -205,21 +249,20 @@ async def websocket_set_network(
@websocket_api.websocket_command(
{
"type": "otbr/set_channel",
vol.Required("extended_address"): str,
vol.Required("channel"): int,
}
)
@websocket_api.require_admin
@websocket_api.async_response
@async_get_otbr_data
async def websocket_set_channel(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict,
data: OTBRData,
) -> None:
"""Set current channel."""
if DOMAIN not in hass.data:
connection.send_error(msg["id"], "not_loaded", "No OTBR API loaded")
return

data: OTBRData = hass.data[DOMAIN]

if is_multiprotocol_url(data.url):
connection.send_error(
msg["id"],
Expand Down
Loading

0 comments on commit 054242f

Please sign in to comment.