Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require specifying extended address in otbr WS API calls #108282

Merged
merged 9 commits into from
Jul 17, 2024
95 changes: 69 additions & 26 deletions homeassistant/components/otbr/websocket_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Websocket API for OTBR."""

from typing import cast
from collections.abc import Callable, Coroutine
from functools import wraps
from typing import Any, cast

import python_otbr_api
from python_otbr_api import PENDING_DATASET_DELAY_TIMER, tlv_parser
Expand Down Expand Up @@ -55,42 +57,84 @@ async def websocket_info(
border_agent_id = await data.get_border_agent_id()
dataset = await data.get_active_dataset()
dataset_tlvs = await data.get_active_dataset_tlvs()
extended_address = await data.get_extended_address()
extended_address = (await data.get_extended_address()).hex()
except HomeAssistantError as exc:
connection.send_error(msg["id"], "otbr_info_failed", str(exc))
return

# The border agent ID is checked when the OTBR config entry is setup,
# we can assert it's not None
assert border_agent_id is not None

extended_pan_id = (
dataset.extended_pan_id.lower() if dataset and dataset.extended_pan_id else None
)
connection.send_result(
msg["id"],
{
"active_dataset_tlvs": dataset_tlvs.hex() if dataset_tlvs else None,
"border_agent_id": border_agent_id.hex(),
"channel": dataset.channel if dataset else None,
"extended_address": extended_address.hex(),
"url": data.url,
extended_address: {
"active_dataset_tlvs": dataset_tlvs.hex() if dataset_tlvs else None,
"border_agent_id": border_agent_id.hex(),
"channel": dataset.channel if dataset else None,
"extended_address": extended_address,
"extended_pan_id": extended_pan_id,
"url": data.url,
}
},
)


def async_get_otbr_data(
orig_func: Callable[
[HomeAssistant, websocket_api.ActiveConnection, dict, OTBRData],
Coroutine[Any, Any, None],
],
) -> Callable[
[HomeAssistant, websocket_api.ActiveConnection, dict], Coroutine[Any, Any, None]
]:
"""Decorate function to get OTBR data."""

@wraps(orig_func)
async def async_check_extended_address_func(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Fetch OTBR data and pass to orig_func."""
if DOMAIN not in hass.data:
connection.send_error(msg["id"], "not_loaded", "No OTBR API loaded")
return

data: OTBRData = hass.data[DOMAIN]

try:
extended_address = await data.get_extended_address()
except HomeAssistantError as exc:
connection.send_error(msg["id"], "get_extended_address_failed", str(exc))
return
if extended_address.hex() != msg["extended_address"]:
connection.send_error(msg["id"], "unknown_router", "")
return

await orig_func(hass, connection, msg, data)

return async_check_extended_address_func


@websocket_api.websocket_command(
{
"type": "otbr/create_network",
vol.Required("extended_address"): str,
}
)
@websocket_api.require_admin
@websocket_api.async_response
@async_get_otbr_data
async def websocket_create_network(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict,
data: OTBRData,
) -> None:
"""Create a new Thread network."""
if DOMAIN not in hass.data:
connection.send_error(msg["id"], "not_loaded", "No OTBR API loaded")
return

data: OTBRData = hass.data[DOMAIN]
channel = await get_allowed_channel(hass, data.url) or DEFAULT_CHANNEL

try:
Expand Down Expand Up @@ -144,19 +188,20 @@ async def websocket_create_network(
@websocket_api.websocket_command(
{
"type": "otbr/set_network",
vol.Required("extended_address"): str,
vol.Required("dataset_id"): str,
}
)
@websocket_api.require_admin
@websocket_api.async_response
@async_get_otbr_data
async def websocket_set_network(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict,
data: OTBRData,
) -> None:
"""Set the Thread network to be used by the OTBR."""
if DOMAIN not in hass.data:
connection.send_error(msg["id"], "not_loaded", "No OTBR API loaded")
return

dataset_tlv = await async_get_dataset(hass, msg["dataset_id"])

if not dataset_tlv:
Expand All @@ -166,7 +211,6 @@ async def websocket_set_network(
if channel := dataset.get(MeshcopTLVType.CHANNEL):
thread_dataset_channel = cast(tlv_parser.Channel, channel).channel

data: OTBRData = hass.data[DOMAIN]
allowed_channel = await get_allowed_channel(hass, data.url)

if allowed_channel and thread_dataset_channel != allowed_channel:
Expand Down Expand Up @@ -205,21 +249,20 @@ async def websocket_set_network(
@websocket_api.websocket_command(
{
"type": "otbr/set_channel",
vol.Required("extended_address"): str,
vol.Required("channel"): int,
}
)
@websocket_api.require_admin
@websocket_api.async_response
@async_get_otbr_data
async def websocket_set_channel(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict,
data: OTBRData,
) -> None:
"""Set current channel."""
if DOMAIN not in hass.data:
connection.send_error(msg["id"], "not_loaded", "No OTBR API loaded")
return

data: OTBRData = hass.data[DOMAIN]

if is_multiprotocol_url(data.url):
connection.send_error(
msg["id"],
Expand Down
Loading