From 26fe961f1b07da9b9fa6e3547e6f0e8144612df6 Mon Sep 17 00:00:00 2001 From: arenekosreal <17194552+arenekosreal@users.noreply.github.com> Date: Fri, 20 Sep 2024 21:36:57 +0800 Subject: [PATCH] Make ruff happy --- src/crx_repo/client.py | 26 +++--- src/crx_repo/server.py | 177 ++++++++++++++++++++++++++--------------- 2 files changed, 127 insertions(+), 76 deletions(-) diff --git a/src/crx_repo/client.py b/src/crx_repo/client.py index ca34702..932a5e3 100644 --- a/src/crx_repo/client.py +++ b/src/crx_repo/client.py @@ -46,18 +46,18 @@ async def download_forever(self): try: await self._do_download() await asyncio.sleep(self.interval) - except asyncio.CancelledError: - _logger.debug("Cleaning old extensions...") - for p in sorted( - self.cache_path.rglob("*.crx"), - key=lambda p: p.stat().st_mtime, - )[:-1]: - p.unlink() - _logger.debug( - "Stopping downloader for extension %s", - self.extension_id, - ) + except (asyncio.CancelledError, KeyboardInterrupt): break + _logger.debug("Cleaning old extensions...") + for p in sorted( + self.cache_path.rglob("*.crx"), + key=lambda p: p.stat().st_mtime, + )[:-1]: + p.unlink() + _logger.debug( + "Stopping downloader for extension %s", + self.extension_id, + ) async def _do_download(self): async with ClientSession() as session: @@ -72,7 +72,9 @@ async def _do_download(self): if response.status != HTTPStatus.OK: _logger.debug("Failed to download extension.") return - if response.content_length != int(size): + if response.content_length is None: + _logger.warning("No Content-Length header found.") + elif response.content_length != int(size): _logger.warning("Content-Length is not equals to size returned by API.") hash_calculator = hashlib.sha256() extension_path = self.cache_path / (version + ".crx.part") diff --git a/src/crx_repo/server.py b/src/crx_repo/server.py index f44b9d9..b456746 100644 --- a/src/crx_repo/server.py +++ b/src/crx_repo/server.py @@ -11,6 +11,8 @@ from ssl import SSLContext from ssl import create_default_context from typing import Any +from typing import Callable +from typing import NamedTuple from aiohttp import web from asyncio import Task from asyncio import CancelledError @@ -19,7 +21,9 @@ from watchfiles import Change from watchfiles import awatch from urllib.parse import unquote -from collections.abc import Generator +from collections.abc import Iterator +from collections.abc import Coroutine +from collections.abc import AsyncIterator from crx_repo.client import ExtensionDownloader from xml.etree.ElementTree import Element from xml.etree.ElementTree import indent @@ -28,8 +32,35 @@ from crx_repo.config.config import TlsHttpListenConfig +class ExtensionInfo(NamedTuple): + """A named tuple stores metainfo for an .crx file.""" + extension_id: str + version: str + size: int + hash_sha256: str + + _logger = logging.getLogger(__name__) -_cache: dict[str, set[str]] = {} +_cache: set[ExtensionInfo] = set() + + +def _iter_extension_info( + target_extension_id: str | None = None, + target_extension_version: str | None = None, +) -> Iterator[ExtensionInfo]: + for info in _cache: + if target_extension_id is not None: + extension_id_match = info.extension_id == target_extension_id + else: + extension_id_match = True + + if target_extension_version is not None: + extension_version_match = info.version == target_extension_version + else: + extension_version_match = True + + if extension_id_match and extension_version_match: + yield info def _get_ssl_context(tls: TlsHttpListenConfig | None) -> SSLContext | None: @@ -56,7 +87,7 @@ def _parse_params(params: str) -> dict[str, str | None]: def _get_filters(xs: list[str]) -> list[tuple[str, str]]: filters: list[tuple[str, str]] = [] - _logger.debug("Handling query param.") + _logger.debug("Handling query param %s.", xs) for x in xs: x_unquoted = unquote(x) params = _parse_params(x_unquoted) @@ -81,13 +112,16 @@ async def _watch_cache(cache: Path): extension_id = p.parent.stem match change: case Change.added: - if extension_id in _cache: - _cache[extension_id].add(extension_version) - else: - _cache[extension_id] = {extension_version} + info = ExtensionInfo( + extension_id, + extension_version, + p.stat().st_size, + hashlib.sha256(p.read_bytes()).hexdigest() + ) + _cache.add(info) case Change.deleted: - if extension_id in _cache and extension_version in _cache[extension_id]: - _cache[extension_id].remove(extension_version) + for info in _iter_extension_info(extension_id, extension_version): + _cache.remove(info) case _: pass except CancelledError: @@ -104,53 +138,29 @@ async def _block(): return -def _get_crx_info( - cache_path: Path, - filters: list[tuple[str, str]], -) -> Generator[tuple[str, tuple[str, int, str]], Any, None]: - if len(filters) == 0: - for crx, versions in _cache.items(): - for version in versions: - filters.append((crx, version)) - - for crx, version in filters: - path = cache_path / crx / (version + ".crx") - if path.is_file(): - content = path.read_bytes() - info = (version, len(content), hashlib.sha256(content).hexdigest()) - yield crx, info - - def _gen_cache(cache: Path): for path in cache.glob("./*/*.crx"): extension_version = path.stem extension_id = path.parent.stem - if extension_id in _cache: - _cache[extension_id].add(extension_version) - else: - _cache[extension_id] = {extension_version} + extension_size = path.stat().st_size + extension_hash_sha256 = hashlib.sha256(path.read_bytes()).hexdigest() + _cache.add( + ExtensionInfo( + extension_id, + extension_version, + extension_size, + extension_hash_sha256, + ) + ) -def setup_server( +def _get_cleanup_ctx_callback( config: Config, - debug: bool = False, -) -> web.Application: - """Get WebApplication instance from config.""" - cache_path = Path(config.cache_dir) - - app = web.Application( - logger=_logger, - debug=debug, - ) - - extension_keys: list[web.AppKey[Task[None]]] = [] - for extension in config.extensions: - extension_key = web.AppKey(extension, Task[None]) - extension_keys.append(extension_key) - - watcher_key = web.AppKey("cache-watcher", Task[None]) - - async def register_services(app: web.Application): + cache_path: Path, + watcher_key: web.AppKey[Task[None]], + extension_keys: list[web.AppKey[Task[None]]] +) -> Callable[[web.Application], AsyncIterator[None]]: + async def callback(app: web.Application): _gen_cache(cache_path) app[watcher_key] = create_task(_watch_cache(cache_path)) @@ -176,33 +186,41 @@ async def register_services(app: web.Application): await app[watcher_key] _cache.clear() + return callback - app.cleanup_ctx.append(register_services) - - prefix = config.prefix if config.prefix.startswith("/") else "/" + config.prefix - manifest_path = config.manifest_path if config.manifest_path.startswith("/") else \ - "/" + config.manifest_path - async def _handle_manifest(request: web.Request) -> web.Response: +def _get_handler( + config: Config, + prefix: str +) -> Callable[[web.Request], Coroutine[Any, Any, web.Response]]: + async def handler(request: web.Request) -> web.Response: absolute_base = config.base + prefix + "/" root = Element("gupdate") root.attrib["xmlns"] = "http://www.google.com/update2/response" root.attrib["protocol"] = "2.0" xs = request.query.getall("x") if "x" in request.query else [] filters = _get_filters(xs) + infos: list[ExtensionInfo] = [] + if len(filters) > 0: + for extension_id, extension_version in filters: + for info in _iter_extension_info(extension_id, extension_version): + infos.append(info) + else: + for info in _iter_extension_info(): + infos.append(info) - for crx, info in _get_crx_info(cache_path, filters): - app = root.find("./app[@appid='{}']".format(crx)) + for info in infos: + extension_path = info.extension_id + "/" + info.version + ".crx" + app = root.find("./app[@appid='{}']".format(info.extension_id)) if app is None: app = Element("app") - app.attrib["appid"] = crx + app.attrib["appid"] = info.extension_id root.append(app) - version, size, sha256 = info update_check = Element("updatecheck") - update_check.attrib["codebase"] = absolute_base + crx + "/" + version + ".crx" - update_check.attrib["version"] = version - update_check.attrib["size"] = str(size) - update_check.attrib["hash_sha256"] = sha256 + update_check.attrib["codebase"] = absolute_base + extension_path + update_check.attrib["version"] = info.version + update_check.attrib["size"] = str(info.size) + update_check.attrib["hash_sha256"] = info.hash_sha256 app.append(update_check) indent(root) xml: bytes = tostring(root, encoding="utf-8", xml_declaration=True) @@ -211,6 +229,37 @@ async def _handle_manifest(request: web.Request) -> web.Response: content_type="application/xml", charset="utf-8" ) + return handler + + +def setup_server( + config: Config, + debug: bool = False, +) -> web.Application: + """Get WebApplication instance from config.""" + cache_path = Path(config.cache_dir) + + app = web.Application( + logger=_logger, + debug=debug, + ) + + extension_keys: list[web.AppKey[Task[None]]] = [] + for extension in config.extensions: + extension_key = web.AppKey(extension, Task[None]) + extension_keys.append(extension_key) + + watcher_key = web.AppKey("cache-watcher", Task[None]) + + callback = _get_cleanup_ctx_callback(config, cache_path, watcher_key, extension_keys) + + app.cleanup_ctx.append(callback) + + prefix = config.prefix if config.prefix.startswith("/") else "/" + config.prefix + manifest_path = config.manifest_path if config.manifest_path.startswith("/") else \ + "/" + config.manifest_path + + handler = _get_handler(config, prefix) if not cache_path.is_dir(): if cache_path.exists(): @@ -218,7 +267,7 @@ async def _handle_manifest(request: web.Request) -> web.Response: cache_path.unlink() cache_path.mkdir(parents=True) _ = app.router.add_static(prefix, cache_path, name="crx-handler") - _ = app.router.add_get(manifest_path, _handle_manifest, name="manifest-handler") + _ = app.router.add_get(manifest_path, handler, name="manifest-handler") return app