From 054242ff0fe419ba2ccf7ef2b58481297bc25b1a Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 17 Jul 2024 09:04:54 +0200 Subject: [PATCH] Require specifying extended address in otbr WS API calls (#108282) Co-authored-by: Stefan Agner --- .../components/otbr/websocket_api.py | 95 +++-- tests/components/otbr/test_websocket_api.py | 369 ++++++++++++++++-- 2 files changed, 395 insertions(+), 69 deletions(-) diff --git a/homeassistant/components/otbr/websocket_api.py b/homeassistant/components/otbr/websocket_api.py index 163152a4bffdc9..9b7e46bc362b54 100644 --- a/homeassistant/components/otbr/websocket_api.py +++ b/homeassistant/components/otbr/websocket_api.py @@ -1,6 +1,8 @@ """Websocket API for OTBR.""" -from typing import cast +from collections.abc import Callable, Coroutine +from functools import wraps +from typing import Any, cast import python_otbr_api from python_otbr_api import PENDING_DATASET_DELAY_TIMER, tlv_parser @@ -55,7 +57,7 @@ async def websocket_info( border_agent_id = await data.get_border_agent_id() dataset = await data.get_active_dataset() dataset_tlvs = await data.get_active_dataset_tlvs() - extended_address = await data.get_extended_address() + extended_address = (await data.get_extended_address()).hex() except HomeAssistantError as exc: connection.send_error(msg["id"], "otbr_info_failed", str(exc)) return @@ -63,34 +65,76 @@ async def websocket_info( # The border agent ID is checked when the OTBR config entry is setup, # we can assert it's not None assert border_agent_id is not None + + extended_pan_id = ( + dataset.extended_pan_id.lower() if dataset and dataset.extended_pan_id else None + ) connection.send_result( msg["id"], { - "active_dataset_tlvs": dataset_tlvs.hex() if dataset_tlvs else None, - "border_agent_id": border_agent_id.hex(), - "channel": dataset.channel if dataset else None, - "extended_address": extended_address.hex(), - "url": data.url, + extended_address: { + "active_dataset_tlvs": dataset_tlvs.hex() if dataset_tlvs else None, + "border_agent_id": border_agent_id.hex(), + "channel": dataset.channel if dataset else None, + "extended_address": extended_address, + "extended_pan_id": extended_pan_id, + "url": data.url, + } }, ) +def async_get_otbr_data( + orig_func: Callable[ + [HomeAssistant, websocket_api.ActiveConnection, dict, OTBRData], + Coroutine[Any, Any, None], + ], +) -> Callable[ + [HomeAssistant, websocket_api.ActiveConnection, dict], Coroutine[Any, Any, None] +]: + """Decorate function to get OTBR data.""" + + @wraps(orig_func) + async def async_check_extended_address_func( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict + ) -> None: + """Fetch OTBR data and pass to orig_func.""" + if DOMAIN not in hass.data: + connection.send_error(msg["id"], "not_loaded", "No OTBR API loaded") + return + + data: OTBRData = hass.data[DOMAIN] + + try: + extended_address = await data.get_extended_address() + except HomeAssistantError as exc: + connection.send_error(msg["id"], "get_extended_address_failed", str(exc)) + return + if extended_address.hex() != msg["extended_address"]: + connection.send_error(msg["id"], "unknown_router", "") + return + + await orig_func(hass, connection, msg, data) + + return async_check_extended_address_func + + @websocket_api.websocket_command( { "type": "otbr/create_network", + vol.Required("extended_address"): str, } ) @websocket_api.require_admin @websocket_api.async_response +@async_get_otbr_data async def websocket_create_network( - hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict, + data: OTBRData, ) -> None: """Create a new Thread network.""" - if DOMAIN not in hass.data: - connection.send_error(msg["id"], "not_loaded", "No OTBR API loaded") - return - - data: OTBRData = hass.data[DOMAIN] channel = await get_allowed_channel(hass, data.url) or DEFAULT_CHANNEL try: @@ -144,19 +188,20 @@ async def websocket_create_network( @websocket_api.websocket_command( { "type": "otbr/set_network", + vol.Required("extended_address"): str, vol.Required("dataset_id"): str, } ) @websocket_api.require_admin @websocket_api.async_response +@async_get_otbr_data async def websocket_set_network( - hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict, + data: OTBRData, ) -> None: """Set the Thread network to be used by the OTBR.""" - if DOMAIN not in hass.data: - connection.send_error(msg["id"], "not_loaded", "No OTBR API loaded") - return - dataset_tlv = await async_get_dataset(hass, msg["dataset_id"]) if not dataset_tlv: @@ -166,7 +211,6 @@ async def websocket_set_network( if channel := dataset.get(MeshcopTLVType.CHANNEL): thread_dataset_channel = cast(tlv_parser.Channel, channel).channel - data: OTBRData = hass.data[DOMAIN] allowed_channel = await get_allowed_channel(hass, data.url) if allowed_channel and thread_dataset_channel != allowed_channel: @@ -205,21 +249,20 @@ async def websocket_set_network( @websocket_api.websocket_command( { "type": "otbr/set_channel", + vol.Required("extended_address"): str, vol.Required("channel"): int, } ) @websocket_api.require_admin @websocket_api.async_response +@async_get_otbr_data async def websocket_set_channel( - hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + msg: dict, + data: OTBRData, ) -> None: """Set current channel.""" - if DOMAIN not in hass.data: - connection.send_error(msg["id"], "not_loaded", "No OTBR API loaded") - return - - data: OTBRData = hass.data[DOMAIN] - if is_multiprotocol_url(data.url): connection.send_error( msg["id"], diff --git a/tests/components/otbr/test_websocket_api.py b/tests/components/otbr/test_websocket_api.py index df55d38d3b73ff..5361b56c688103 100644 --- a/tests/components/otbr/test_websocket_api.py +++ b/tests/components/otbr/test_websocket_api.py @@ -36,11 +36,14 @@ async def test_get_info( websocket_client, ) -> None: """Test async_get_info.""" + extended_pan_id = "ABCD1234" with ( patch( "python_otbr_api.OTBR.get_active_dataset", - return_value=python_otbr_api.ActiveDataSet(channel=16), + return_value=python_otbr_api.ActiveDataSet( + channel=16, extended_pan_id=extended_pan_id + ), ), patch( "python_otbr_api.OTBR.get_active_dataset_tlvs", return_value=DATASET_CH16 @@ -58,12 +61,16 @@ async def test_get_info( msg = await websocket_client.receive_json() assert msg["success"] + extended_address = TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex() assert msg["result"] == { - "url": BASE_URL, - "active_dataset_tlvs": DATASET_CH16.hex().lower(), - "channel": 16, - "border_agent_id": TEST_BORDER_AGENT_ID.hex(), - "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + extended_address: { + "url": BASE_URL, + "active_dataset_tlvs": DATASET_CH16.hex().lower(), + "channel": 16, + "border_agent_id": TEST_BORDER_AGENT_ID.hex(), + "extended_address": extended_address, + "extended_pan_id": extended_pan_id.lower(), + } } @@ -121,6 +128,10 @@ async def test_create_network( patch( "python_otbr_api.OTBR.get_active_dataset_tlvs", return_value=DATASET_CH16 ) as get_active_dataset_tlvs_mock, + patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ), patch( "homeassistant.components.thread.dataset_store.DatasetStore.async_add" ) as mock_add, @@ -129,7 +140,12 @@ async def test_create_network( return_value=0x1234, ), ): - await websocket_client.send_json_auto_id({"type": "otbr/create_network"}) + await websocket_client.send_json_auto_id( + { + "type": "otbr/create_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + } + ) msg = await websocket_client.receive_json() assert msg["success"] @@ -156,7 +172,9 @@ async def test_create_network_no_entry( """Test create network.""" await async_setup_component(hass, "otbr", {}) websocket_client = await hass_ws_client(hass) - await websocket_client.send_json_auto_id({"type": "otbr/create_network"}) + await websocket_client.send_json_auto_id( + {"type": "otbr/create_network", "extended_address": "blah"} + ) msg = await websocket_client.receive_json() assert not msg["success"] @@ -170,11 +188,22 @@ async def test_create_network_fails_1( websocket_client, ) -> None: """Test create network.""" - with patch( - "python_otbr_api.OTBR.set_enabled", - side_effect=python_otbr_api.OTBRError, + with ( + patch( + "python_otbr_api.OTBR.set_enabled", + side_effect=python_otbr_api.OTBRError, + ), + patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ), ): - await websocket_client.send_json_auto_id({"type": "otbr/create_network"}) + await websocket_client.send_json_auto_id( + { + "type": "otbr/create_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + } + ) msg = await websocket_client.receive_json() assert not msg["success"] @@ -197,8 +226,17 @@ async def test_create_network_fails_2( side_effect=python_otbr_api.OTBRError, ), patch("python_otbr_api.OTBR.factory_reset"), + patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ), ): - await websocket_client.send_json_auto_id({"type": "otbr/create_network"}) + await websocket_client.send_json_auto_id( + { + "type": "otbr/create_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + } + ) msg = await websocket_client.receive_json() assert not msg["success"] @@ -223,8 +261,17 @@ async def test_create_network_fails_3( patch( "python_otbr_api.OTBR.factory_reset", ), + patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ), ): - await websocket_client.send_json_auto_id({"type": "otbr/create_network"}) + await websocket_client.send_json_auto_id( + { + "type": "otbr/create_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + } + ) msg = await websocket_client.receive_json() assert not msg["success"] @@ -248,8 +295,17 @@ async def test_create_network_fails_4( patch( "python_otbr_api.OTBR.factory_reset", ), + patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ), ): - await websocket_client.send_json_auto_id({"type": "otbr/create_network"}) + await websocket_client.send_json_auto_id( + { + "type": "otbr/create_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + } + ) msg = await websocket_client.receive_json() assert not msg["success"] @@ -268,8 +324,17 @@ async def test_create_network_fails_5( patch("python_otbr_api.OTBR.create_active_dataset"), patch("python_otbr_api.OTBR.get_active_dataset_tlvs", return_value=None), patch("python_otbr_api.OTBR.factory_reset"), + patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ), ): - await websocket_client.send_json_auto_id({"type": "otbr/create_network"}) + await websocket_client.send_json_auto_id( + { + "type": "otbr/create_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + } + ) msg = await websocket_client.receive_json() assert not msg["success"] @@ -291,14 +356,69 @@ async def test_create_network_fails_6( "python_otbr_api.OTBR.factory_reset", side_effect=python_otbr_api.OTBRError, ), + patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ), ): - await websocket_client.send_json_auto_id({"type": "otbr/create_network"}) + await websocket_client.send_json_auto_id( + { + "type": "otbr/create_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + } + ) msg = await websocket_client.receive_json() assert not msg["success"] assert msg["error"]["code"] == "factory_reset_failed" +async def test_create_network_fails_7( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + otbr_config_entry_multipan, + websocket_client, +) -> None: + """Test create network.""" + with patch( + "python_otbr_api.OTBR.get_extended_address", + side_effect=python_otbr_api.OTBRError, + ): + await websocket_client.send_json_auto_id( + { + "type": "otbr/create_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + } + ) + msg = await websocket_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == "get_extended_address_failed" + + +async def test_create_network_fails_8( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + otbr_config_entry_multipan, + websocket_client, +) -> None: + """Test create network.""" + with patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ): + await websocket_client.send_json_auto_id( + { + "type": "otbr/create_network", + "extended_address": "blah", + } + ) + msg = await websocket_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == "unknown_router" + + async def test_set_network( hass: HomeAssistant, aioclient_mock: AiohttpClientMocker, @@ -312,6 +432,10 @@ async def test_set_network( dataset_id = list(dataset_store.datasets)[1] with ( + patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ), patch( "python_otbr_api.OTBR.set_active_dataset_tlvs" ) as set_active_dataset_tlvs_mock, @@ -320,6 +444,7 @@ async def test_set_network( await websocket_client.send_json_auto_id( { "type": "otbr/set_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), "dataset_id": dataset_id, } ) @@ -345,6 +470,7 @@ async def test_set_network_no_entry( await websocket_client.send_json_auto_id( { "type": "otbr/set_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), "dataset_id": "abc", } ) @@ -368,14 +494,19 @@ async def test_set_network_channel_conflict( multiprotocol_addon_manager_mock.async_get_channel.return_value = 15 - await websocket_client.send_json_auto_id( - { - "type": "otbr/set_network", - "dataset_id": dataset_id, - } - ) + with patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ): + await websocket_client.send_json_auto_id( + { + "type": "otbr/set_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + "dataset_id": dataset_id, + } + ) - msg = await websocket_client.receive_json() + msg = await websocket_client.receive_json() assert not msg["success"] assert msg["error"]["code"] == "channel_conflict" @@ -389,14 +520,19 @@ async def test_set_network_unknown_dataset( ) -> None: """Test set network.""" - await websocket_client.send_json_auto_id( - { - "type": "otbr/set_network", - "dataset_id": "abc", - } - ) + with patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ): + await websocket_client.send_json_auto_id( + { + "type": "otbr/set_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + "dataset_id": "abc", + } + ) - msg = await websocket_client.receive_json() + msg = await websocket_client.receive_json() assert not msg["success"] assert msg["error"]["code"] == "unknown_dataset" @@ -413,13 +549,20 @@ async def test_set_network_fails_1( dataset_store = await thread.dataset_store.async_get_store(hass) dataset_id = list(dataset_store.datasets)[1] - with patch( - "python_otbr_api.OTBR.set_enabled", - side_effect=python_otbr_api.OTBRError, + with ( + patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ), + patch( + "python_otbr_api.OTBR.set_enabled", + side_effect=python_otbr_api.OTBRError, + ), ): await websocket_client.send_json_auto_id( { "type": "otbr/set_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), "dataset_id": dataset_id, } ) @@ -441,6 +584,10 @@ async def test_set_network_fails_2( dataset_id = list(dataset_store.datasets)[1] with ( + patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ), patch( "python_otbr_api.OTBR.set_enabled", ), @@ -452,6 +599,7 @@ async def test_set_network_fails_2( await websocket_client.send_json_auto_id( { "type": "otbr/set_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), "dataset_id": dataset_id, } ) @@ -473,6 +621,10 @@ async def test_set_network_fails_3( dataset_id = list(dataset_store.datasets)[1] with ( + patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ), patch( "python_otbr_api.OTBR.set_enabled", side_effect=[None, python_otbr_api.OTBRError], @@ -484,6 +636,7 @@ async def test_set_network_fails_3( await websocket_client.send_json_auto_id( { "type": "otbr/set_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), "dataset_id": dataset_id, } ) @@ -493,6 +646,54 @@ async def test_set_network_fails_3( assert msg["error"]["code"] == "set_enabled_failed" +async def test_set_network_fails_4( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + otbr_config_entry_multipan, + websocket_client, +) -> None: + """Test set network.""" + with patch( + "python_otbr_api.OTBR.get_extended_address", + side_effect=python_otbr_api.OTBRError, + ): + await websocket_client.send_json_auto_id( + { + "type": "otbr/set_network", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + "dataset_id": "abc", + } + ) + msg = await websocket_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == "get_extended_address_failed" + + +async def test_set_network_fails_5( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + otbr_config_entry_multipan, + websocket_client, +) -> None: + """Test set network.""" + with patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ): + await websocket_client.send_json_auto_id( + { + "type": "otbr/set_network", + "extended_address": "blah", + "dataset_id": "abc", + } + ) + msg = await websocket_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == "unknown_router" + + async def test_set_channel( hass: HomeAssistant, aioclient_mock: AiohttpClientMocker, @@ -501,9 +702,19 @@ async def test_set_channel( ) -> None: """Test set channel.""" - with patch("python_otbr_api.OTBR.set_channel"): + with ( + patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ), + patch("python_otbr_api.OTBR.set_channel"), + ): await websocket_client.send_json_auto_id( - {"type": "otbr/set_channel", "channel": 12} + { + "type": "otbr/set_channel", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + "channel": 12, + } ) msg = await websocket_client.receive_json() @@ -519,9 +730,19 @@ async def test_set_channel_multiprotocol( ) -> None: """Test set channel.""" - with patch("python_otbr_api.OTBR.set_channel"): + with ( + patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ), + patch("python_otbr_api.OTBR.set_channel"), + ): await websocket_client.send_json_auto_id( - {"type": "otbr/set_channel", "channel": 12} + { + "type": "otbr/set_channel", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + "channel": 12, + } ) msg = await websocket_client.receive_json() @@ -538,7 +759,11 @@ async def test_set_channel_no_entry( await async_setup_component(hass, "otbr", {}) websocket_client = await hass_ws_client(hass) await websocket_client.send_json_auto_id( - {"type": "otbr/set_channel", "channel": 12} + { + "type": "otbr/set_channel", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + "channel": 12, + } ) msg = await websocket_client.receive_json() @@ -546,21 +771,79 @@ async def test_set_channel_no_entry( assert msg["error"]["code"] == "not_loaded" -async def test_set_channel_fails( +async def test_set_channel_fails_1( hass: HomeAssistant, aioclient_mock: AiohttpClientMocker, otbr_config_entry_thread, websocket_client, +) -> None: + """Test set channel.""" + with ( + patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ), + patch( + "python_otbr_api.OTBR.set_channel", + side_effect=python_otbr_api.OTBRError, + ), + ): + await websocket_client.send_json_auto_id( + { + "type": "otbr/set_channel", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + "channel": 12, + } + ) + msg = await websocket_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == "set_channel_failed" + + +async def test_set_channel_fails_2( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + otbr_config_entry_multipan, + websocket_client, ) -> None: """Test set channel.""" with patch( - "python_otbr_api.OTBR.set_channel", + "python_otbr_api.OTBR.get_extended_address", side_effect=python_otbr_api.OTBRError, ): await websocket_client.send_json_auto_id( - {"type": "otbr/set_channel", "channel": 12} + { + "type": "otbr/set_channel", + "extended_address": TEST_BORDER_AGENT_EXTENDED_ADDRESS.hex(), + "channel": 12, + } ) msg = await websocket_client.receive_json() assert not msg["success"] - assert msg["error"]["code"] == "set_channel_failed" + assert msg["error"]["code"] == "get_extended_address_failed" + + +async def test_set_channel_fails_3( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + otbr_config_entry_multipan, + websocket_client, +) -> None: + """Test set channel.""" + with patch( + "python_otbr_api.OTBR.get_extended_address", + return_value=TEST_BORDER_AGENT_EXTENDED_ADDRESS, + ): + await websocket_client.send_json_auto_id( + { + "type": "otbr/set_channel", + "extended_address": "blah", + "channel": 12, + } + ) + msg = await websocket_client.receive_json() + + assert not msg["success"] + assert msg["error"]["code"] == "unknown_router"