Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Support firmware update from HA #249

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 96 additions & 4 deletions custom_components/aquarea/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import logging
import json
import aiohttp
from typing import Optional
import asyncio
from typing import Optional, Any
from io import BufferedReader, BytesIO

from homeassistant.components import mqtt
from homeassistant.components.mqtt.client import async_publish
Expand Down Expand Up @@ -87,14 +89,25 @@ def __init__(
self.stats_firmware_contain_version: Optional[bool] = None

self._attr_supported_features = (
UpdateEntityFeature.RELEASE_NOTES | UpdateEntityFeature.INSTALL
UpdateEntityFeature.RELEASE_NOTES | UpdateEntityFeature.INSTALL | UpdateEntityFeature.PROGRESS | UpdateEntityFeature.SPECIFIC_VERSION
)
self._attr_release_url = f"https://github.com/{HEISHAMON_REPOSITORY}/releases"
# FIXME: for now we assume board is using the "model-type-small"
self._model_type = "model-type-small"
self._release_notes = None
self._attr_progress = False

self._ip_topic = f"{self.discovery_prefix}ip"
self._heishamon_ip = None

async def async_added_to_hass(self) -> None:
"""Subscribe to MQTT events."""

@callback
def ip_received(message):
self._heishamon_ip = message.payload
await mqtt.async_subscribe(self.hass, self._ip_topic, ip_received, 1)

@callback
def message_received(message):
"""Handle new MQTT messages."""
Expand Down Expand Up @@ -166,5 +179,84 @@ async def _update_latest_release(self):
self.async_write_ha_state()

def release_notes(self) -> str | None:
header = f"⚠ Update is not supported via HA. Update is done via heishamon webui\n\n\n"
return header + str(self._release_notes)
return f"⚠️ Automated upgrades only supports {self._model_type}.\n\nBeware!\n\n" + str(self._release_notes)

async def async_install(self, version: str | None, backup: bool, **kwargs: Any) -> None:
if version is None:
version = self._attr_latest_version
_LOGGER.info(f"Will install latest version ({version}) of the firmware")
else:
_LOGGER.info(f"Will install version {version} of the firmware")
self._attr_progress = 0
async with aiohttp.ClientSession() as session:
resp = await session.get(
f"https://github.com/Egyras/HeishaMon/raw/master/binaries/{self._model_type}/HeishaMon.ino.d1-v{version}.bin"
)

if resp.status != 200:
_LOGGER.warn(
f"Impossible to download version {version} from heishamon repository {HEISHAMON_REPOSITORY}"
)
return

firmware_binary = await resp.read()
_LOGGER.info(f"Firmware is {len(firmware_binary)} bytes long")
self._attr_progress = 10
resp = await session.get(
f"https://github.com/Egyras/HeishaMon/raw/master/binaries/{self._model_type}/HeishaMon.ino.d1-v{version}.md5"
)

if resp.status != 200:
_LOGGER.warn(
f"Impossible to fetch checksum of version #{version} from heishamon repository {HEISHAMON_REPOSITORY}"
)
return
checksum = await resp.text()
self._attr_progress = 20
_LOGGER.info(f"Downloaded binary and checksum {checksum} of version {version}")

while self._heishamon_ip is None:
_LOGGER.warn("Waiting for an mqtt message to get the ip address of heishamon")
await asyncio.sleep(1)

def track_progress(current, total):
self._attr_progress = int(current / total * 100)
_LOGGER.info(f"Currently read {current} out of {total}: {self._attr_progress}%")


async with aiohttp.ClientSession() as session:
_LOGGER.info(f"Starting upgrade of firmware to version {version} on {self._heishamon_ip}")
to = aiohttp.ClientTimeout(total=300, connect=10)
try:
with ProgressReader(firmware_binary, track_progress) as reader:
resp = await session.post(
f"http://{self._heishamon_ip}/firmware",
data={
'md5': checksum,
# 'firmware': ('firmware.bin', firmware_binary, 'application/octet-stream')
'firmware': reader

},
timeout=to
)
except TimeoutError as e:
_LOGGER.error(f"Timeout while uploading new firmware")
raise e
if resp.status != 200:
_LOGGER.warn(f"Impossible to perform firmware update to version {version}")
return
_LOGGER.info(f"Finished uploading firmware. Heishamon should now be rebooting")

class ProgressReader(BufferedReader):
def __init__(self, binary_data, read_callback=None):
self._read_callback = read_callback
super().__init__(raw=BytesIO(binary_data))
self.length = len(binary_data)

def read(self, size=None):
computed_size = size
if not computed_size:
computed_size = self.length - self.tell()
if self._read_callback:
self._read_callback(self.tell(), self.length)
return super(ProgressReader, self).read(size)
Loading