diff --git a/pyhilo/api.py b/pyhilo/api.py index 761e23f..167352f 100755 --- a/pyhilo/api.py +++ b/pyhilo/api.py @@ -27,7 +27,7 @@ API_NOTIFICATIONS_ENDPOINT, API_REGISTRATION_ENDPOINT, API_REGISTRATION_HEADERS, - AUTOMATION_DEVICEHUB_ENDPOINT, + AUTOMATION_CHALLENGE_ENDPOINT, DEFAULT_STATE_FILE, DEFAULT_USER_AGENT, FB_APP_ID, @@ -51,7 +51,7 @@ get_state, set_state, ) -from pyhilo.websocket import WebsocketClient +from pyhilo.websocket import WebsocketClient, WebsocketManager class API: @@ -81,9 +81,13 @@ def __init__( self.device_attributes = get_device_attributes() self.session: ClientSession = session self._oauth_session = oauth_session - self.websocket: WebsocketClient + self.websocket_devices: WebsocketClient + self.websocket_challenges: WebsocketClient self.log_traces = log_traces self._get_device_callbacks: list[Callable[..., Any]] = [] + self.ws_url: str = "" + self.ws_token: str = "" + self.endpoint: str = "" @classmethod async def async_create( @@ -216,6 +220,8 @@ async def _async_request( :rtype: dict[str, Any] """ kwargs.setdefault("headers", self.headers) + access_token = await self.async_get_access_token() + if endpoint.startswith(API_REGISTRATION_ENDPOINT): kwargs["headers"] = {**kwargs["headers"], **API_REGISTRATION_HEADERS} if endpoint.startswith(FB_INSTALL_ENDPOINT): @@ -223,10 +229,15 @@ async def _async_request( if endpoint.startswith(ANDROID_CLIENT_ENDPOINT): kwargs["headers"] = {**kwargs["headers"], **ANDROID_CLIENT_HEADERS} if host == API_HOSTNAME: - access_token = await self.async_get_access_token() kwargs["headers"]["authorization"] = f"Bearer {access_token}" kwargs["headers"]["Host"] = host + # ic-dev21 trying Leicas suggestion + if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT): + # remove Ocp-Apim-Subscription-Key header to avoid 401 error + kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None) + kwargs["headers"]["authorization"] = f"Bearer {access_token}" + data: dict[str, Any] = {} url = parse.urljoin(f"https://{host}", endpoint) if self.log_traces: @@ -245,6 +256,7 @@ async def _async_request( if self.log_traces: LOG.debug("[TRACE] Data received from /%s: %s", endpoint, data) resp.raise_for_status() + LOG.debug(f"ic-dev21 Data is {data}") return data def _get_url( @@ -303,8 +315,9 @@ async def _async_handle_on_backoff(self, _: dict[str, Any]) -> None: LOG.info( "401 detected on websocket, refreshing websocket token. Old url: {self.ws_url} Old Token: {self.ws_token}" ) + LOG.info(f"401 detected on {err.request_info.url}") async with self._backoff_refresh_lock_ws: - (self.ws_url, self.ws_token) = await self.post_devicehub_negociate() + await self.refresh_ws_token() await self.get_websocket_params() return @@ -354,30 +367,23 @@ async def _async_post_init(self) -> None: LOG.debug("Websocket postinit") await self._get_fid() await self._get_device_token() - await self.refresh_ws_token() - self.websocket = WebsocketClient(self) - async def refresh_ws_token(self) -> None: - (self.ws_url, self.ws_token) = await self.post_devicehub_negociate() - await self.get_websocket_params() - - async def post_devicehub_negociate(self) -> tuple[str, str]: - LOG.debug("Getting websocket url") - url = f"{AUTOMATION_DEVICEHUB_ENDPOINT}/negotiate" - LOG.debug(f"devicehub URL is {url}") - resp = await self.async_request("post", url) - ws_url = resp.get("url") - ws_token = resp.get("accessToken") - LOG.debug("Calling set_state devicehub_negotiate") - await set_state( - self._state_yaml, - "websocket", - { - "url": ws_url, - "token": ws_token, - }, + # Initialize WebsocketManager ic-dev21 + self.websocket_manager = WebsocketManager( + self.session, self.async_request, self._state_yaml, set_state ) - return (ws_url, ws_token) + await self.websocket_manager.initialize_websockets() + + # Create both websocket clients + # ic-dev21 need to work on this as it can't lint as is, may need to + # instantiate differently + self.websocket_devices = WebsocketClient(self.websocket_manager.devicehub) + self.websocket_challenges = WebsocketClient(self.websocket_manager.challengehub) + + async def refresh_ws_token(self) -> None: + """Refresh the websocket token.""" + await self.websocket_manager.refresh_token(self.websocket_manager.devicehub) + await self.websocket_manager.refresh_token(self.websocket_manager.challengehub) async def get_websocket_params(self) -> None: uri = parse.urlparse(self.ws_url) diff --git a/pyhilo/const.py b/pyhilo/const.py old mode 100644 new mode 100755 index 6441f6d..a53edb1 --- a/pyhilo/const.py +++ b/pyhilo/const.py @@ -42,6 +42,8 @@ # Automation server constant AUTOMATION_DEVICEHUB_ENDPOINT: Final = "/DeviceHub" +AUTOMATION_CHALLENGE_ENDPOINT: Final = "/ChallengeHub" + # Request constants DEFAULT_USER_AGENT: Final = f"PyHilo/{PYHILO_VERSION} HomeAssistant/{homeassistant.core.__version__} aiohttp/{aiohttp.__version__} Python/{platform.python_version()}" diff --git a/pyhilo/websocket.py b/pyhilo/websocket.py index ee2b6fc..dc494e4 100755 --- a/pyhilo/websocket.py +++ b/pyhilo/websocket.py @@ -7,9 +7,10 @@ from enum import IntEnum import json from os import environ -from typing import TYPE_CHECKING, Any, Callable, Dict +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple +from urllib import parse -from aiohttp import ClientWebSocketResponse, WSMsgType +from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType from aiohttp.client_exceptions import ( ClientError, ServerDisconnectedError, @@ -17,7 +18,12 @@ ) from yarl import URL -from pyhilo.const import DEFAULT_USER_AGENT, LOG +from pyhilo.const import ( + AUTOMATION_CHALLENGE_ENDPOINT, + AUTOMATION_DEVICEHUB_ENDPOINT, + DEFAULT_USER_AGENT, + LOG, +) from pyhilo.exceptions import ( CannotConnectError, ConnectionClosedError, @@ -208,7 +214,7 @@ async def _async_send_json(self, payload: dict[str, Any]) -> None: if self._api.log_traces: LOG.debug( - f"[TRACE] Sending data to websocket server: {json.dumps(payload)}" + f"[TRACE] Sending data to websocket {self._api.endpoint} : {json.dumps(payload)}" ) # Hilo added a control character (chr(30)) at the end of each payload they send. # They also expect this char to be there at the end of every payload we send them. @@ -217,7 +223,9 @@ async def _async_send_json(self, payload: dict[str, Any]) -> None: def _parse_message(self, msg: dict[str, Any]) -> None: """Parse an incoming message.""" if self._api.log_traces: - LOG.debug(f"[TRACE] Received message from websocket: {msg}") + LOG.debug( + f"[TRACE] Received message on websocket {self._api.endpoint}: {msg}" + ) if msg.get("type") == SignalRMsgType.PING: schedule_callback(self._async_pong) return @@ -247,7 +255,7 @@ def add_disconnect_callback( return self._add_callback(self._disconnect_callbacks, callback) def add_event_callback(self, callback: Callable[..., Any]) -> Callable[..., None]: - """Add a callback callback to be called upon receiving an event. + """Add a callback to be called upon receiving an event. Note that callbacks should expect to receive a WebsocketEvent object as a parameter. :param callback: The method to call after receiving an event. @@ -261,7 +269,7 @@ async def async_connect(self) -> None: LOG.debug("Websocket: async_connect() called but already connected") return - LOG.info("Websocket: Connecting to server") + LOG.info("Websocket: Connecting to server %s", self._api.endpoint) if self._api.log_traces: LOG.debug(f"[TRACE] Websocket URL: {self._api.full_ws_url}") headers = { @@ -296,7 +304,7 @@ async def async_connect(self) -> None: LOG.error(f"Unable to connect to WS server {err}") raise CannotConnectError(err) from err - LOG.info("Connected to websocket server") + LOG.info(f"Connected to websocket server {self._api.endpoint}") self._watchdog.trigger() for callback in self._connect_callbacks: schedule_callback(callback) @@ -376,3 +384,144 @@ async def async_invoke( "type": inv_type, } ) + + +@dataclass +class WebsocketConfig: + """Configuration for a websocket connection""" + + endpoint: str + url: Optional[str] = None + token: Optional[str] = None + connection_id: Optional[str] = None + full_ws_url: Optional[str] = None + log_traces: bool = True + session: ClientSession | None = None + + +class WebsocketManager: + """Manages multiple websocket connections for the Hilo API""" + + def __init__( + self, + session: ClientSession, + async_request: Callable[..., Any], + state_yaml: str, + set_state_callback: Callable[..., Any], + ) -> None: + """Initialize the websocket manager. + + Args: + session: The aiohttp client session + async_request: The async request method from the API class + state_yaml: Path to the state file + set_state_callback: Callback to save state + """ + self.session = session + self.async_request = async_request + self._state_yaml = state_yaml + self._set_state = set_state_callback + self._shared_token: Optional[str] = None + # Initialize websocket configurations, more can be added here + self.devicehub = WebsocketConfig( + endpoint=AUTOMATION_DEVICEHUB_ENDPOINT, session=session + ) + self.challengehub = WebsocketConfig( + endpoint=AUTOMATION_CHALLENGE_ENDPOINT, session=session + ) + + async def initialize_websockets(self) -> None: + """Initialize both websocket connections""" + # ic-dev21 get token from device hub + await self.refresh_token(self.devicehub, get_new_token=True) + # ic-dev21 get token from challenge hub + await self.refresh_token(self.challengehub, get_new_token=True) + + async def refresh_token( + self, config: WebsocketConfig, get_new_token: bool = True + ) -> None: + """Refresh token for a specific websocket configuration. + Args: + config: The websocket configuration to refresh + """ + if get_new_token: + config.url, self._shared_token = await self._negotiate(config) + config.token = self._shared_token + else: + config.url, _ = await self._negotiate(config) + config.token = self._shared_token + + await self._get_websocket_params(config) + + async def _negotiate(self, config: WebsocketConfig) -> Tuple[str, str]: + """Negotiate websocket connection and get URL and token. + Args: + config: The websocket configuration to negotiate + Returns: + Tuple containing the websocket URL and access token + """ + LOG.debug(f"Getting websocket url for {config.endpoint}") + url = f"{config.endpoint}/negotiate" + LOG.debug(f"Negotiate URL is {url}") + + resp = await self.async_request("post", url) + ws_url = resp.get("url") + ws_token = resp.get("accessToken") + + # Save state + state_key = ( + "websocketDevices" + if config.endpoint == AUTOMATION_DEVICEHUB_ENDPOINT + else "websocketChallenges" + ) + await self._set_state( + self._state_yaml, + state_key, + { + "url": ws_url, + "token": ws_token, + }, + ) + + return ws_url, ws_token + + async def _get_websocket_params(self, config: WebsocketConfig) -> None: + """Get websocket parameters including connection ID. + + Args: + config: The websocket configuration to get parameters for + """ + uri = parse.urlparse(config.url) + LOG.debug(f"Getting websocket params for {config.endpoint}") + LOG.debug(f"Getting uri {uri}") + + resp = await self.async_request( + "post", + f"{uri.path}negotiate?{uri.query}", # type: ignore + host=uri.netloc, + headers={ + "authorization": f"Bearer {config.token}", + }, + ) + + config.connection_id = resp.get("connectionId", "") + config.full_ws_url = ( + f"{config.url}&id={config.connection_id}&access_token={config.token}" + ) + LOG.debug(f"Getting full ws URL {config.full_ws_url}") + + transport_dict = resp.get("availableTransports", []) + websocket_dict = { + "connection_id": config.connection_id, + "available_transports": transport_dict, + "full_url": config.full_ws_url, + } + + # Save state + state_key = ( + "websocketDevices" + if config.endpoint == AUTOMATION_DEVICEHUB_ENDPOINT + else "websocketChallenges" + ) + LOG.debug(f"Calling set_state {state_key}_params") + await self._set_state(self._state_yaml, state_key, websocket_dict)