Skip to content

Commit

Permalink
✨ Improve update progress tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
kamaradclimber committed Sep 13, 2024
1 parent 0868476 commit 4b9b416
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions custom_components/aquarea/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import aiohttp
import asyncio
from typing import Optional, Any
from io import BufferedReader, BytesIO

from homeassistant.components import mqtt
from homeassistant.components.mqtt.client import async_publish
Expand Down Expand Up @@ -220,16 +221,44 @@ async def async_install(self, version: str | None, backup: bool, **kwargs: Any)
_LOGGER.warn("Waiting for an mqtt message to get the ip address of heishamon")
await asyncio.sleep(1)

def track_progress(current, total):
self._attr_progress = int(current / total * 100)
_LOGGER.info(f"Currently read {current} out of {total}: {self._attr_progress}%")


async with aiohttp.ClientSession() as session:
_LOGGER.info(f"Starting upgrade of firmware to version {version} on {self._heishamon_ip}")
to = aiohttp.ClientTimeout(total=300, connect=10)
try:
resp = await session.post(
f"http://{self._heishamon_ip}/firmware", data={'md5': checksum, 'firmware': ('firmware.bin', firmware_binary, 'application/octet-stream')}, timeout=to)
with ProgressReader(firmware_binary, track_progress) as reader:
resp = await session.post(
f"http://{self._heishamon_ip}/firmware",
data={
'md5': checksum,
# 'firmware': ('firmware.bin', firmware_binary, 'application/octet-stream')
'firmware': reader

},
timeout=to
)
except TimeoutError as e:
_LOGGER.error(f"Timeout while uploading new firmware")
raise e
if resp.status != 200:
_LOGGER.warn(f"Impossible to perform firmware update to version {version}")
return
_LOGGER.info(f"Finished uploading firmware. Heishamon should now be rebooting")

class ProgressReader(BufferedReader):
def __init__(self, binary_data, read_callback=None):
self._read_callback = read_callback
super().__init__(raw=BytesIO(binary_data))
self.length = len(binary_data)

def read(self, size=None):
computed_size = size
if not computed_size:
computed_size = self.length - self.tell()
if self._read_callback:
self._read_callback(self.tell(), self.length)
return super(ProgressReader, self).read(size)

0 comments on commit 4b9b416

Please sign in to comment.