Skip to content

Commit

Permalink
add supervisor support
Browse files Browse the repository at this point in the history
  • Loading branch information
marcelveldt committed Sep 5, 2020
1 parent 43f06e2 commit 61c3fa9
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 60 deletions.
134 changes: 75 additions & 59 deletions hass_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,60 +7,61 @@
import asyncio
import functools
import logging
import os
from enum import Enum
from typing import Any, Awaitable, Callable, List, Optional, Union
from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union

import aiohttp

LOGGER = logging.getLogger("hass-client")


EVENT_CONNECTED = "connected"
EVENT_STATE_CHANGED = "state_changed"

IS_SUPERVISOR = os.environ.get("HASSIO_TOKEN") is not None


class HomeAssistant:
"""Connection to HomeAssistant (over websockets)."""

def __init__(self, url: str = None, token: str = None):
def __init__(self, url: str = None, token: str = None, loop=None):
"""
Initialize the connection to HomeAssistant.
:param url: full url to the HomeAssistant instance.
:param token: a long lived token.
Initialize the connection to HomeAssistant.
:param url: full url to the HomeAssistant instance.
:param token: a long lived token.
If url and token are omitted, assume supervisor install.
"""
self._loop = asyncio.get_event_loop()
self._loop = loop
self._states = {}
self.__async_send_ws = None
self.__last_id = 10
self._ws_callbacks = {}
self._device_registry = {}
self._entity_registry = {}
self._area_registry = {}
if url.startswith("https://"):
if not url and not token and IS_SUPERVISOR:
self._host = "hassio/homeassistant"
self._use_ssl = False
elif url.startswith("https://"):
self._use_ssl = True
self._host = url.replace("https://", "")
else:
self._token = token
elif url and token:
self._use_ssl = False
self._host = url.replace("http://", "")
self._token = token
self._token = token
else:
raise RuntimeError("Please provide a valid url and token!")
self._http_session = None
self._initial_state_received = False
self._connected_callback = None
self._event_listeners = []
self._ws_task = None

def connect(self):
"""Start the connection."""
self._loop.create_task(self.async_connect())

def close(self):
"""Close the connection."""
self._loop.create_task(self.async_close())

async def async_connect(self):
"""Start the connection."""
if not self._loop.is_running:
raise RuntimeError("A running eventloop is required!")
"""Connect to HomeAssistant."""
if not self._loop:
self._loop = asyncio.get_running_loop()
if not self._token or not self._host:
raise RuntimeError("A valid url and token is required")
self._http_session = aiohttp.ClientSession(
Expand All @@ -83,11 +84,11 @@ def register_event_callback(
entity_filter: Union[None, str, List[str]] = None,
) -> Callable:
"""
Add callback for events.
Returns function to remove the listener.
:param cb_func: callback function or coroutine
:param event_filter: Optionally only listen for these events
:param event_filter: In case of state_changed event, only forward these entities
Add callback for events.
Returns function to remove the listener.
:param cb_func: callback function or coroutine
:param event_filter: Optionally only listen for these events
:param event_filter: In case of state_changed event, only forward these entities
"""
listener = (cb_func, event_filter, entity_filter)
self._event_listeners.append(listener)
Expand All @@ -100,11 +101,15 @@ def remove_listener():
@property
def device_registry(self) -> dict:
"""Return device registry."""
if not self._device_registry:
LOGGER.warning("Connection is not yet ready.")
return self._device_registry

@property
def entity_registry(self) -> dict:
"""Return device registry."""
if not self._entity_registry:
LOGGER.warning("Connection is not yet ready.")
return self._entity_registry

@property
Expand Down Expand Up @@ -154,19 +159,9 @@ def items_by_domain(self, domain: str) -> List[dict]:

def get_state(self, entity_id: str, attribute: str = "state") -> dict:
"""
Get state(obj) of a Home Assistant entity.
:param entity_id: The entity id for which the state must be returned.
:param attribute: The attribute to return from the state object.
"""
return asyncio.run_coroutine_threadsafe(
self.async_get_state(entity_id, attribute), self._loop
).result()

async def async_get_state(self, entity_id: str, attribute: str = "state") -> dict:
"""
Get state(obj) of a Home Assistant entity.
:param entity_id: The entity id for which the state must be returned.
:param attribute: The attribute to return from the state object.
Get state(obj) of a Home Assistant entity.
:param entity_id: The entity id for which the state must be returned.
:param attribute: The attribute to return from the state object.
"""
if not self._initial_state_received:
LOGGER.warning("Connection is not yet ready.")
Expand All @@ -179,12 +174,20 @@ async def async_get_state(self, entity_id: str, attribute: str = "state") -> dic
return state_obj
return None

async def async_get_state(self, entity_id: str, attribute: str = "state") -> dict:
"""
Get state(obj) of a Home Assistant entity.
:param entity_id: The entity id for which the state must be returned.
:param attribute: The attribute to return from the state object.
"""
return self.get_state(entity_id, attribute) # safe to call in loop

async def async_call_service(self, domain: str, service: str, service_data: dict = None):
"""
Call service on Home Assistant.
:param url: Domain of the service to call (e.g. light, switch).
:param service: The service to call (e.g. turn_on).
:param service_data: Optional dict with parameters (e.g. { brightness: 20 }).
Call service on Home Assistant.
:param url: Domain of the service to call (e.g. light, switch).
:param service: The service to call (e.g. turn_on).
:param service_data: Optional dict with parameters (e.g. { brightness: 20 }).
"""
if not self._initial_state_received:
LOGGER.warning("Connection is not yet ready.")
Expand All @@ -193,14 +196,12 @@ async def async_call_service(self, domain: str, service: str, service_data: dict
msg["service_data"] = service_data
return await self.__async_send_ws(msg)

async def async_set_state(
self, entity_id: str, new_state: str, state_attributes: dict = None
):
async def async_set_state(self, entity_id: str, new_state: str, state_attributes: dict = None):
"""
Set state on a homeassistant entity.
:param entity_id: Entity id to set state for.
:param new_state: The new state.
:param state_attributes: Optional dict with parameters (e.g. { name: 'Cool entity' }).
Set state on a homeassistant entity.
:param entity_id: Entity id to set state for.
:param new_state: The new state.
:param state_attributes: Optional dict with parameters (e.g. { name: 'Cool entity' }).
"""
if state_attributes is None:
state_attributes = {}
Expand Down Expand Up @@ -240,7 +241,10 @@ async def _send_msg(msg, callback=None):
async for msg in conn:
await self.__process_ws_message(conn, msg)

except (aiohttp.client_exceptions.ClientConnectorError, ConnectionRefusedError,) as exc:
except (
aiohttp.client_exceptions.ClientConnectorError,
ConnectionRefusedError,
) as exc:
LOGGER.error(exc)
await asyncio.sleep(10)

Expand All @@ -252,7 +256,7 @@ async def __process_ws_message(self, conn, msg):
# send auth token
auth_msg = {
"type": "auth",
"access_token": self._token,
"access_token": self.__get_token(),
}
await conn.send_json(auth_msg)
elif data["type"] == "auth_invalid":
Expand All @@ -270,23 +274,28 @@ async def __async_subscribe_events(self):
"""Subscribe to common events when the ws was (re)connected."""
# request all current states
await self.__async_send_ws(
{"type": "get_states"}, callback=self.__async_receive_all_states,
{"type": "get_states"},
callback=self.__async_receive_all_states,
)
# subscribe to all events
await self.__async_send_ws(
{"type": "subscribe_events"}, callback=self.__async_state_changed,
{"type": "subscribe_events"},
callback=self.__async_state_changed,
)
# request all area, device and entity registry
await self.__async_send_ws(
{"type": "config/area_registry/list"}, callback=self.__async_receive_area_registry,
{"type": "config/area_registry/list"},
callback=self.__async_receive_area_registry,
)
# request device registry
await self.__async_send_ws(
{"type": "config/device_registry/list"}, callback=self.__async_receive_device_registry,
{"type": "config/device_registry/list"},
callback=self.__async_receive_device_registry,
)
# request entity registry
await self.__async_send_ws(
{"type": "config/entity_registry/list"}, callback=self.__async_receive_entity_registry,
{"type": "config/entity_registry/list"},
callback=self.__async_receive_entity_registry,
)

async def __async_state_changed(self, msg: dict):
Expand Down Expand Up @@ -336,7 +345,7 @@ async def __async_get_data(self, endpoint: str):
if self._use_ssl:
url = f"https://{self._host}/api/{endpoint}"
headers = {
"Authorization": f"Bearer {self._token}",
"Authorization": f"Bearer {self.__get_token()}",
"Content-Type": "application/json",
}
async with self._http_session.get(url, headers=headers, verify_ssl=False) as response:
Expand All @@ -348,7 +357,7 @@ async def __async_post_data(self, endpoint: str, data: dict):
if self._use_ssl:
url = f"https://{self._host}/api/{endpoint}"
headers = {
"Authorization": "Bearer %s" % self._token,
"Authorization": "Bearer %s" % self.__get_token(),
"Content-Type": "application/json",
}
async with self._http_session.post(
Expand Down Expand Up @@ -376,3 +385,10 @@ async def __async_signal_event(self, event: str, event_details: Any = None):
self._loop.create_task(check_target(event, event_details))
else:
self._loop.run_in_executor(None, cb_func, event, event_details)

def __get_token(self) -> str:
""""Get auth token for Home Assistant."""
if IS_SUPERVISOR:
# On supervisor installs the token is provided by a environment variable
return os.environ["HASSIO_TOKEN"]
return self._token
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

PROJECT_DIR = Path(__file__).parent.resolve()
README_FILE = PROJECT_DIR / "README.md"
VERSION = "0.0.3"
VERSION = "0.0.4"

with open("requirements.txt") as f:
INSTALL_REQUIRES = f.read().splitlines()
Expand Down

0 comments on commit 61c3fa9

Please sign in to comment.