Skip to content

Commit

Permalink
Merge pull request #184 from dvd-dev/feature/new-auth
Browse files Browse the repository at this point in the history
Feature/new auth
  • Loading branch information
ic-dev21 authored Mar 2, 2024
2 parents 74a1bc7 + 713ca02 commit 714da6a
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 211 deletions.
238 changes: 44 additions & 194 deletions pyhilo/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import random
import string
import sys
from typing import TYPE_CHECKING, Any, Callable, Union, cast
from typing import Any, Callable, Union, cast
from urllib import parse

from aiohttp import ClientSession
from aiohttp.client_exceptions import ClientResponseError
import backoff
from homeassistant.helpers import config_entry_oauth2_flow

from pyhilo.const import (
ANDROID_CLIENT_ENDPOINT,
Expand All @@ -26,16 +27,7 @@
API_NOTIFICATIONS_ENDPOINT,
API_REGISTRATION_ENDPOINT,
API_REGISTRATION_HEADERS,
AUTH_CLIENT_ID,
AUTH_ENDPOINT,
AUTH_HOSTNAME,
AUTH_RESPONSE_TYPE,
AUTH_SCOPE,
AUTH_TYPE_PASSWORD,
AUTH_TYPE_REFRESH,
AUTOMATION_DEVICEHUB_ENDPOINT,
AUTOMATION_HOSTNAME,
CONTENT_TYPE_FORM,
DEFAULT_STATE_FILE,
DEFAULT_USER_AGENT,
FB_APP_ID,
Expand All @@ -52,9 +44,7 @@
)
from pyhilo.device import DeviceAttribute, HiloDevice, get_device_attributes
from pyhilo.exceptions import InvalidCredentialsError, RequestError
from pyhilo.util import schedule_callback
from pyhilo.util.state import (
TokenDict,
WebsocketDict,
WebsocketTransportsDict,
get_state,
Expand All @@ -76,104 +66,73 @@ def __init__(
self,
*,
session: ClientSession,
oauth_session: config_entry_oauth2_flow.OAuth2Session,
request_retries: int = REQUEST_RETRY,
log_traces: bool = False,
) -> None:
"""Initialize"""
self._access_token: str | None = None
self._backoff_refresh_lock_api = asyncio.Lock()
self._backoff_refresh_lock_ws = asyncio.Lock()
self._reg_id: str | None = None
self._request_retries = request_retries
self._state_yaml: str = DEFAULT_STATE_FILE
self._token_expiration: datetime | None = None
self.state = get_state(self._state_yaml)
self.async_request = self._wrap_request_method(self._request_retries)
self.device_attributes = get_device_attributes()
self.session: ClientSession = session
self._oauth_session = oauth_session
self.websocket: WebsocketClient
self._username: str
self._refresh_token_callbacks: list[Callable[..., Any]] = []
self.log_traces: bool = False
self.log_traces = log_traces
self._get_device_callbacks: list[Callable[..., Any]] = []

@property
def headers(self) -> dict[str, Any]:
headers = {
"User-Agent": DEFAULT_USER_AGENT,
}
if not self._access_token:
return headers
return {
**headers,
**{
"Content-Type": "application/json; charset=utf-8",
"Ocp-Apim-Subscription-Key": SUBSCRIPTION_KEY,
"authorization": f"Bearer {self._access_token}",
},
}

@classmethod
async def async_auth_refresh_token(
cls,
*,
session: ClientSession,
provided_refresh_token: Union[str, None] = None,
request_retries: int = REQUEST_RETRY,
state_yaml: str = DEFAULT_STATE_FILE,
log_traces: bool = False,
) -> API:
api = cls(session=session, request_retries=request_retries)
api.log_traces = log_traces
api._state_yaml = state_yaml
api.state = get_state(state_yaml)
if provided_refresh_token:
api._refresh_token = provided_refresh_token
else:
token_state = api.state.get("token", {})
api._refresh_token = token_state.get("refresh")
if not api._refresh_token:
raise InvalidCredentialsError

await api._async_refresh_access_token()
await api._async_post_init()
return api

@classmethod
async def async_auth_password(
async def async_create(
cls,
username: str,
password: str,
*,
session: ClientSession,
oauth_session: config_entry_oauth2_flow.OAuth2Session,
request_retries: int = REQUEST_RETRY,
state_yaml: str = DEFAULT_STATE_FILE,
log_traces: bool = False,
) -> API:
"""Get an authenticated API object from a username and password.
:param username: the username
:type username: ``str``
:param password: the password
:type the password: ``str``
"""Get an authenticated API object.
:param session: The ``aiohttp`` ``ClientSession`` session used for all HTTP requests
:type session: ``aiohttp.client.ClientSession``
:param oauth_session: The session to make requests authenticated with OAuth2.
:type oauth_session: ``config_entry_oauth2_flow.OAuth2Session``
:param request_retries: The default number of request retries to use
:type request_retries: ``int``
:param state_yaml: File where we store registration ID
:type state_yaml: ``str``
:rtype: :meth:`pyhilo.api.API`
"""
api = cls(session=session, request_retries=request_retries)
api.log_traces = log_traces
api._username = username
api._state_yaml = state_yaml
api.state = get_state(state_yaml)
password = parse.quote(password, safe="!@#$%^?&*()_+")
auth_body = api.auth_body(
AUTH_TYPE_PASSWORD, username=username, password=password
api = cls(
session=session,
oauth_session=oauth_session,
request_retries=request_retries,
log_traces=log_traces,
)
await api.async_auth_post(auth_body)
# Test token before post init
await api.async_get_access_token()
await api._async_post_init()
return api

@property
def headers(self) -> dict[str, Any]:
headers = {
"User-Agent": DEFAULT_USER_AGENT,
}
return {
**headers,
**{
"Content-Type": "application/json; charset=utf-8",
"Ocp-Apim-Subscription-Key": SUBSCRIPTION_KEY,
},
}

async def async_get_access_token(self) -> str:
"""Return a valid access token."""
if not self._oauth_session.valid_token:
await self._oauth_session.async_ensure_token_valid()

return str(self._oauth_session.token["access_token"])

def dev_atts(
self, attribute: str, value_type: Union[str, None] = None
) -> Union[DeviceAttribute, str]:
Expand Down Expand Up @@ -241,96 +200,6 @@ async def _get_fid(self) -> None:
await self.fb_install(self._fb_id)
self._get_fid_state()

async def _async_refresh_access_token(self) -> None:
"""Update access/refresh tokens from a refresh token
and schedule a callback for later to refresh it.
"""
auth_body = self.auth_body(
AUTH_TYPE_REFRESH,
refresh_token=self._refresh_token,
)
await self.async_auth_post(auth_body)
for callback in self._refresh_token_callbacks:
schedule_callback(callback, self._refresh_token)

async def async_auth_post(self, body: dict) -> None:
"""Prepares an authentication request for the Web API.
:param body: Contains the parameters passed to get tokens
:type body: dict
:raises InvalidCredentialsError: Invalid username/password
:raises RequestError: Other error
"""
try:
LOG.debug("Authentication intiated")
resp = await self._async_request(
"post",
AUTH_ENDPOINT,
host=AUTH_HOSTNAME,
headers={
"Content-Type": CONTENT_TYPE_FORM,
},
data=body,
)
except ClientResponseError as err:
LOG.error(f"ClientResponseError: {err}")
if err.status in (400, 401, 403):
LOG.error(f"Raising InvalidCredentialsError from {err}")
raise InvalidCredentialsError("Invalid credentials") from err
raise RequestError(err) from err
self._access_token = resp.get("access_token")
self._access_token_expire_dt = datetime.now() + timedelta(
seconds=int(str(resp.get("expires_in")))
)
self._refresh_token = resp.get(AUTH_TYPE_REFRESH, "")
token_dict: TokenDict = {
"access": self._access_token,
"refresh": self._refresh_token,
"expires_at": self._access_token_expire_dt,
}
set_state(self._state_yaml, "token", token_dict)

def auth_body(
self,
grant_type: str,
*,
username: str = "",
password: str = "",
refresh_token: str = "",
) -> dict[Any, Any]:
"""Generates a dict to pass to the authentication endpoint for
the Web API.
:param grant_type: either password or refresh_token
:type grant_type: str
:param username: defaults to ""
:type username: str, optional
:param password: defaults to ""
:type password: str, optional
:param refresh_token: Refresh token received from a previous password auth, defaults to ""
:type refresh_token: str, optional
:return: Dict structured for authentication
:rtype: dict[Any, Any]
"""
LOG.debug(f"Auth body for grant {grant_type}")
body = {
"grant_type": grant_type,
"client_id": AUTH_CLIENT_ID,
"scope": AUTH_SCOPE,
}
if grant_type == AUTH_TYPE_PASSWORD:
body = {
**body,
**{
"response_type": AUTH_RESPONSE_TYPE,
"username": username,
"password": password,
},
}
elif grant_type == AUTH_TYPE_REFRESH:
body[AUTH_TYPE_REFRESH] = refresh_token
return body

async def _async_request(
self, method: str, endpoint: str, host: str = API_HOSTNAME, **kwargs: Any
) -> dict[str, Any]:
Expand All @@ -352,7 +221,11 @@ async def _async_request(
kwargs["headers"] = {**kwargs["headers"], **FB_INSTALL_HEADERS}
if endpoint.startswith(ANDROID_CLIENT_ENDPOINT):
kwargs["headers"] = {**kwargs["headers"], **ANDROID_CLIENT_HEADERS}
if host == API_HOSTNAME:
access_token = await self.async_get_access_token()
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
kwargs["headers"]["Host"] = host

data: dict[str, Any] = {}
url = parse.urljoin(f"https://{host}", endpoint)
if self.log_traces:
Expand Down Expand Up @@ -433,13 +306,6 @@ async def _async_handle_on_backoff(self, _: dict[str, Any]) -> None:
(self.ws_url, self.ws_token) = await self.post_devicehub_negociate()
await self.get_websocket_params()
return
if TYPE_CHECKING:
assert self._access_token_expire_dt
async with self._backoff_refresh_lock_api:
if datetime.now() <= self._access_token_expire_dt:
return
LOG.info("401 detected on api; refreshing api token")
await self._async_refresh_access_token()

@staticmethod
def _handle_on_giveup(_: dict[str, Any]) -> None:
Expand Down Expand Up @@ -482,22 +348,6 @@ def enable_request_retries(self) -> None:
"""Enable the request retry mechanism."""
self.async_request = self._wrap_request_method(self._request_retries)

def add_refresh_token_callback(
self, callback: Callable[..., None]
) -> Callable[..., None]:
"""Add a callback that should be triggered when tokens are refreshed.
Note that callbacks should expect to receive a refresh token as a parameter.
:param callback: The method to call after receiving an event.
:type callback: ``Callable[..., None]``
"""
self._refresh_token_callbacks.append(callback)

def remove() -> None:
"""Remove the callback."""
self._refresh_token_callbacks.remove(callback)

return remove

async def _async_post_init(self) -> None:
"""Perform some post-init actions."""
LOG.debug("Websocket postinit")
Expand All @@ -513,7 +363,7 @@ async def refresh_ws_token(self) -> None:
async def post_devicehub_negociate(self) -> tuple[str, str]:
LOG.debug("Getting websocket url")
url = f"{AUTOMATION_DEVICEHUB_ENDPOINT}/negotiate"
resp = await self.async_request("post", url, host=AUTOMATION_HOSTNAME)
resp = await self.async_request("post", url)
ws_url = resp.get("url")
ws_token = resp.get("accessToken")
set_state(
Expand Down
25 changes: 9 additions & 16 deletions pyhilo/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,20 @@
LOG: Final = logging.getLogger(__package__)
DEFAULT_STATE_FILE: Final = "hilo_state.yaml"
REQUEST_RETRY: Final = 9
TIMEOUT: Final = 10
TOKEN_EXPIRATION_PADDING: Final = 300
VERIFY: Final = True
DEVICE_REFRESH_TIME: Final = 1800
PYHILO_VERSION: Final = "2023.12.01"
PYHILO_VERSION: Final = "2023.12.02"
# TODO: Find a way to keep previous line in sync with pyproject.toml automatically

CONTENT_TYPE_FORM: Final = "application/x-www-form-urlencoded"
ANDROID_PKG_NAME: Final = "com.hiloenergie.hilo"
DOMAIN: Final = "hilo"
# Auth constants
AUTH_HOSTNAME: Final = "hilodirectoryb2c.b2clogin.com"
AUTH_ENDPOINT: Final = (
"/hilodirectoryb2c.onmicrosoft.com/oauth2/v2.0/token?p=B2C_1A_B2C_1_PasswordFlow"
)
AUTH_CLIENT_ID: Final = "9870f087-25f8-43b6-9cad-d4b74ce512e1"
AUTH_TYPE_PASSWORD: Final = "password"
AUTH_TYPE_REFRESH: Final = "refresh_token"
AUTH_RESPONSE_TYPE: Final = "token id_token"
AUTH_SCOPE: Final = "openid 9870f087-25f8-43b6-9cad-d4b74ce512e1 offline_access"
AUTH_HOSTNAME: Final = "connexion.hiloenergie.com"
AUTH_ENDPOINT: Final = "/HiloDirectoryB2C.onmicrosoft.com/B2C_1A_SIGN_IN/oauth2/v2.0/"
AUTH_AUTHORIZE: Final = f"https://{AUTH_HOSTNAME}{AUTH_ENDPOINT}authorize"
AUTH_TOKEN: Final = f"https://{AUTH_HOSTNAME}{AUTH_ENDPOINT}token"
AUTH_CHALLENGE_METHOD: Final = "S256"
AUTH_CLIENT_ID: Final = "1ca9f585-4a55-4085-8e30-9746a65fa561"
AUTH_SCOPE: Final = "openid https://HiloDirectoryB2C.onmicrosoft.com/hiloapis/user_impersonation offline_access"
SUBSCRIPTION_KEY: Final = "20eeaedcb86945afa3fe792cea89b8bf"

# API constants
Expand All @@ -46,8 +40,7 @@
"Hilo-Tenant": DOMAIN,
}

# Automation server constants
AUTOMATION_HOSTNAME: Final = "automation.hiloenergie.com"
# Automation server constant
AUTOMATION_DEVICEHUB_ENDPOINT: Final = "/DeviceHub"

# Request constants
Expand Down
Loading

0 comments on commit 714da6a

Please sign in to comment.