Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Premier essai (NE PAS MERGER!!!) #224

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions pyhilo/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
API_NOTIFICATIONS_ENDPOINT,
API_REGISTRATION_ENDPOINT,
API_REGISTRATION_HEADERS,
AUTOMATION_CHALLENGE_ENDPOINT,
AUTOMATION_DEVICEHUB_ENDPOINT,
DEFAULT_STATE_FILE,
DEFAULT_USER_AGENT,
Expand All @@ -51,7 +52,7 @@
get_state,
set_state,
)
from pyhilo.websocket import WebsocketClient
from pyhilo.websocket import WebsocketClient, WebsocketManager


class API:
Expand Down Expand Up @@ -216,17 +217,24 @@ async def _async_request(
:rtype: dict[str, Any]
"""
kwargs.setdefault("headers", self.headers)
access_token = await self.async_get_access_token()

if endpoint.startswith(API_REGISTRATION_ENDPOINT):
kwargs["headers"] = {**kwargs["headers"], **API_REGISTRATION_HEADERS}
if endpoint.startswith(FB_INSTALL_ENDPOINT):
kwargs["headers"] = {**kwargs["headers"], **FB_INSTALL_HEADERS}
if endpoint.startswith(ANDROID_CLIENT_ENDPOINT):
kwargs["headers"] = {**kwargs["headers"], **ANDROID_CLIENT_HEADERS}
if host == API_HOSTNAME:
access_token = await self.async_get_access_token()
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
kwargs["headers"]["Host"] = host

# ic-dev21 trying Leicas suggestion
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
# remove Ocp-Apim-Subscription-Key header to avoid 401 error
kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None)
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
Comment on lines +235 to +239
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the 'Ocp-Apim-Subscription-Key' header when the endpoint is 'AUTOMATION_CHALLENGE_ENDPOINT' might introduce issues if this header is required elsewhere in the system. It's crucial to document this change within the function's docstring or as comments explaining the broader context and its necessity.

Suggested change
# ic-dev21 trying Leicas suggestion
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
# remove Ocp-Apim-Subscription-Key header to avoid 401 error
kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None)
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
// Consider adding more comments or documentation here to explain why removing the header is necessary

Comment on lines +236 to +239
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handling different authorization methods for various endpoints in the same function increases complexity and potential for errors. Consider extracting this logic into a separate function to improve readability and maintainability.

Suggested change
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
# remove Ocp-Apim-Subscription-Key header to avoid 401 error
kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None)
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
# Define a new function to handle request headers
def add_authorization_headers(endpoint, headers):
if endpoint.startswith(API_REGISTRATION_ENDPOINT):
headers.update(API_REGISTRATION_HEADERS)
elif endpoint.startswith(FB_INSTALL_ENDPOINT):
headers.update(FB_INSTALL_HEADERS)
elif endpoint.startswith(ANDROID_CLIENT_ENDPOINT):
headers.update(ANDROID_CLIENT_HEADERS)
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
headers.pop("Ocp-Apim-Subscription-Key", None)

ic-dev21 marked this conversation as resolved.
Show resolved Hide resolved

data: dict[str, Any] = {}
url = parse.urljoin(f"https://{host}", endpoint)
if self.log_traces:
Expand Down Expand Up @@ -303,8 +311,9 @@ async def _async_handle_on_backoff(self, _: dict[str, Any]) -> None:
LOG.info(
"401 detected on websocket, refreshing websocket token. Old url: {self.ws_url} Old Token: {self.ws_token}"
)
LOG.info(f"401 detected on {err.request_info.url}")
async with self._backoff_refresh_lock_ws:
(self.ws_url, self.ws_token) = await self.post_devicehub_negociate()
await self.refresh_ws_token()
await self.get_websocket_params()
return

Expand Down Expand Up @@ -354,13 +363,26 @@ async def _async_post_init(self) -> None:
LOG.debug("Websocket postinit")
await self._get_fid()
await self._get_device_token()
await self.refresh_ws_token()
self.websocket = WebsocketClient(self)

# Initialize WebsocketManager ic-dev21
self.websocket_manager = WebsocketManager(
self.session, self.async_request, self._state_yaml, set_state
)
await self.websocket_manager.initialize_websockets()

# Create both websocket clients
# ic-dev21 need to work on this as it can't lint as is, may need to
# instanciate differently
self.websocket = WebsocketClient(self.websocket_manager.devicehub)
self.websocket2 = WebsocketClient(self.websocket_manager.challengehub)

async def refresh_ws_token(self) -> None:
(self.ws_url, self.ws_token) = await self.post_devicehub_negociate()
await self.get_websocket_params()
"""Refresh the websocket token."""
await self.websocket_manager.refresh_token(self.websocket_manager.devicehub)
await self.websocket_manager.refresh_token(self.websocket_manager.challengehub)


# ic-dev21 not sure this is still needed? See websocket.py _async_negotiate
async def post_devicehub_negociate(self) -> tuple[str, str]:
LOG.debug("Getting websocket url")
url = f"{AUTOMATION_DEVICEHUB_ENDPOINT}/negotiate"
Expand Down
2 changes: 2 additions & 0 deletions pyhilo/const.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@

# Automation server constant
AUTOMATION_DEVICEHUB_ENDPOINT: Final = "/DeviceHub"
AUTOMATION_CHALLENGE_ENDPOINT: Final = "/ChallengeHub"


# Request constants
DEFAULT_USER_AGENT: Final = f"PyHilo/{PYHILO_VERSION} HomeAssistant/{homeassistant.core.__version__} aiohttp/{aiohttp.__version__} Python/{platform.python_version()}"
Expand Down
163 changes: 156 additions & 7 deletions pyhilo/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,23 @@
from enum import IntEnum
import json
from os import environ
from typing import TYPE_CHECKING, Any, Callable, Dict
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
from urllib import parse
Comment on lines +10 to +11
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an unused import 'parse' which might have been intended for use in the module 'urllib'. Consider removing it to clean up unused imports and prevent confusion.

Suggested change
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
from urllib import parse
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple


from aiohttp import ClientWebSocketResponse, WSMsgType
from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType
from aiohttp.client_exceptions import (
ClientError,
ServerDisconnectedError,
WSServerHandshakeError,
)
from yarl import URL

from pyhilo.const import DEFAULT_USER_AGENT, LOG
from pyhilo.const import (
AUTOMATION_CHALLENGE_ENDPOINT,
AUTOMATION_DEVICEHUB_ENDPOINT,
DEFAULT_USER_AGENT,
LOG,
)
from pyhilo.exceptions import (
CannotConnectError,
ConnectionClosedError,
Expand Down Expand Up @@ -208,7 +214,7 @@ async def _async_send_json(self, payload: dict[str, Any]) -> None:

if self._api.log_traces:
LOG.debug(
f"[TRACE] Sending data to websocket server: {json.dumps(payload)}"
f"[TRACE] Sending data to websocket {self._api.endpoint} : {json.dumps(payload)}"
)
# Hilo added a control character (chr(30)) at the end of each payload they send.
# They also expect this char to be there at the end of every payload we send them.
Expand All @@ -217,7 +223,9 @@ async def _async_send_json(self, payload: dict[str, Any]) -> None:
def _parse_message(self, msg: dict[str, Any]) -> None:
"""Parse an incoming message."""
if self._api.log_traces:
LOG.debug(f"[TRACE] Received message from websocket: {msg}")
LOG.debug(
f"[TRACE] Received message on websocket {self._api.endpoint}: {msg}"
)
if msg.get("type") == SignalRMsgType.PING:
schedule_callback(self._async_pong)
return
Expand Down Expand Up @@ -261,7 +269,7 @@ async def async_connect(self) -> None:
LOG.debug("Websocket: async_connect() called but already connected")
return

LOG.info("Websocket: Connecting to server")
LOG.info("Websocket: Connecting to server %s", self._api.endpoint)
if self._api.log_traces:
LOG.debug(f"[TRACE] Websocket URL: {self._api.full_ws_url}")
headers = {
Expand Down Expand Up @@ -296,7 +304,7 @@ async def async_connect(self) -> None:
LOG.error(f"Unable to connect to WS server {err}")
raise CannotConnectError(err) from err

LOG.info("Connected to websocket server")
LOG.info(f"Connected to websocket server {self._api.endpoint}")
self._watchdog.trigger()
for callback in self._connect_callbacks:
schedule_callback(callback)
Expand Down Expand Up @@ -376,3 +384,144 @@ async def async_invoke(
"type": inv_type,
}
)


@dataclass
class WebsocketConfig:
"""Configuration for a websocket connection"""

endpoint: str
url: Optional[str] = None
token: Optional[str] = None
connection_id: Optional[str] = None
full_ws_url: Optional[str] = None
log_traces: bool = True
session: ClientSession | None = None


class WebsocketManager:
"""Manages multiple websocket connections for the Hilo API"""

def __init__(
self, session: ClientSession, async_request, state_yaml: str, set_state_callback
) -> None:
"""Initialize the websocket manager.

Args:
session: The aiohttp client session
async_request: The async request method from the API class
state_yaml: Path to the state file
set_state_callback: Callback to save state
"""
self.session = session
self.async_request = async_request
self._state_yaml = state_yaml
self._set_state = set_state_callback
self._shared_token = None

# Initialize websocket configurations, more can be added here
self.devicehub = WebsocketConfig(
endpoint=AUTOMATION_DEVICEHUB_ENDPOINT, session=session
)
self.challengehub = WebsocketConfig(
endpoint=AUTOMATION_CHALLENGE_ENDPOINT, session=session
)

async def initialize_websockets(self) -> None:
"""Initialize both websocket connections"""
# ic-dev21 get token from device hub
await self.refresh_token(self.devicehub, get_new_token=True)
# ic-dev21 get token from challenge hub
await self.refresh_token(self.challengehub, get_new_token=True)

async def refresh_token(
self, config: WebsocketConfig, get_new_token: bool = True
) -> None:
"""Refresh token for a specific websocket configuration.

Args:
config: The websocket configuration to refresh
"""
if get_new_token:
config.url, self._shared_token = await self._negotiate(config)
config.token = self._shared_token
else:
config.url, _ = await self._negotiate(config)
config.token = self._shared_token

await self._get_websocket_params(config)

async def _negotiate(self, config: WebsocketConfig) -> Tuple[str, str]:
"""Negotiate websocket connection and get URL and token.

Args:
config: The websocket configuration to negotiate

Returns:
Tuple containing the websocket URL and access token
"""
LOG.debug(f"Getting websocket url for {config.endpoint}")
url = f"{config.endpoint}/negotiate"
LOG.debug(f"Negotiate URL is {url}")
Comment on lines +463 to +465
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String interpolation with f-strings for logging is generally fine for lower volume messages, but be mindful if this was a high-volume statement as it could have a performance impact. Consider lazy logging if performance becomes a concern.


resp = await self.async_request("post", url)
ws_url = resp.get("url")
ws_token = resp.get("accessToken")

# Save state
state_key = (
"websocket"
if config.endpoint == AUTOMATION_DEVICEHUB_ENDPOINT
else "websocket2"
)
await self._set_state(
self._state_yaml,
state_key,
{
"url": ws_url,
"token": ws_token,
},
)

return ws_url, ws_token

async def _get_websocket_params(self, config: WebsocketConfig) -> None:
"""Get websocket parameters including connection ID.

Args:
config: The websocket configuration to get parameters for
"""
uri = parse.urlparse(config.url)
LOG.debug(f"Getting websocket params for {config.endpoint}")
LOG.debug(f"Getting uri {uri}")

resp = await self.async_request(
"post",
f"{uri.path}negotiate?{uri.query}",
host=uri.netloc,
headers={
"authorization": f"Bearer {config.token}",
},
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of 'resp.get("connectionId")' without a default could lead to potential KeyError. Consider using a default value to safeguard against missing dictionary keys.

Suggested change
resp = await self.async_request(
"post",
f"{uri.path}negotiate?{uri.query}",
host=uri.netloc,
headers={
"authorization": f"Bearer {config.token}",
},
)
config.connection_id = resp.get("connectionId", "")


config.connection_id = resp.get("connectionId", "")
config.full_ws_url = (
f"{config.url}&id={config.connection_id}&access_token={config.token}"
)
LOG.debug(f"Getting full ws URL {config.full_ws_url}")

transport_dict = resp.get("availableTransports", [])
websocket_dict = {
"connection_id": config.connection_id,
"available_transports": transport_dict,
"full_url": config.full_ws_url,
}

# Save state
state_key = (
"websocket"
if config.endpoint == AUTOMATION_DEVICEHUB_ENDPOINT
else "websocket2"
)
LOG.debug(f"Calling set_state {state_key}_params")
await self._set_state(self._state_yaml, state_key, websocket_dict)
Loading