Skip to content

Commit

Permalink
Make ruff happy
Browse files Browse the repository at this point in the history
  • Loading branch information
arenekosreal committed Sep 20, 2024
1 parent a5202c1 commit 26fe961
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 76 deletions.
26 changes: 14 additions & 12 deletions src/crx_repo/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,18 @@ async def download_forever(self):
try:
await self._do_download()
await asyncio.sleep(self.interval)
except asyncio.CancelledError:
_logger.debug("Cleaning old extensions...")
for p in sorted(
self.cache_path.rglob("*.crx"),
key=lambda p: p.stat().st_mtime,
)[:-1]:
p.unlink()
_logger.debug(
"Stopping downloader for extension %s",
self.extension_id,
)
except (asyncio.CancelledError, KeyboardInterrupt):
break
_logger.debug("Cleaning old extensions...")
for p in sorted(
self.cache_path.rglob("*.crx"),
key=lambda p: p.stat().st_mtime,
)[:-1]:
p.unlink()
_logger.debug(
"Stopping downloader for extension %s",
self.extension_id,
)

async def _do_download(self):
async with ClientSession() as session:
Expand All @@ -72,7 +72,9 @@ async def _do_download(self):
if response.status != HTTPStatus.OK:
_logger.debug("Failed to download extension.")
return
if response.content_length != int(size):
if response.content_length is None:
_logger.warning("No Content-Length header found.")
elif response.content_length != int(size):
_logger.warning("Content-Length is not equals to size returned by API.")
hash_calculator = hashlib.sha256()
extension_path = self.cache_path / (version + ".crx.part")
Expand Down
177 changes: 113 additions & 64 deletions src/crx_repo/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from ssl import SSLContext
from ssl import create_default_context
from typing import Any
from typing import Callable
from typing import NamedTuple
from aiohttp import web
from asyncio import Task
from asyncio import CancelledError
Expand All @@ -19,7 +21,9 @@
from watchfiles import Change
from watchfiles import awatch
from urllib.parse import unquote
from collections.abc import Generator
from collections.abc import Iterator
from collections.abc import Coroutine
from collections.abc import AsyncIterator
from crx_repo.client import ExtensionDownloader
from xml.etree.ElementTree import Element
from xml.etree.ElementTree import indent
Expand All @@ -28,8 +32,35 @@
from crx_repo.config.config import TlsHttpListenConfig


class ExtensionInfo(NamedTuple):
"""A named tuple stores metainfo for an .crx file."""
extension_id: str
version: str
size: int
hash_sha256: str


_logger = logging.getLogger(__name__)
_cache: dict[str, set[str]] = {}
_cache: set[ExtensionInfo] = set()


def _iter_extension_info(
target_extension_id: str | None = None,
target_extension_version: str | None = None,
) -> Iterator[ExtensionInfo]:
for info in _cache:
if target_extension_id is not None:
extension_id_match = info.extension_id == target_extension_id
else:
extension_id_match = True

if target_extension_version is not None:
extension_version_match = info.version == target_extension_version
else:
extension_version_match = True

if extension_id_match and extension_version_match:
yield info


def _get_ssl_context(tls: TlsHttpListenConfig | None) -> SSLContext | None:
Expand All @@ -56,7 +87,7 @@ def _parse_params(params: str) -> dict[str, str | None]:

def _get_filters(xs: list[str]) -> list[tuple[str, str]]:
filters: list[tuple[str, str]] = []
_logger.debug("Handling query param.")
_logger.debug("Handling query param %s.", xs)
for x in xs:
x_unquoted = unquote(x)
params = _parse_params(x_unquoted)
Expand All @@ -81,13 +112,16 @@ async def _watch_cache(cache: Path):
extension_id = p.parent.stem
match change:
case Change.added:
if extension_id in _cache:
_cache[extension_id].add(extension_version)
else:
_cache[extension_id] = {extension_version}
info = ExtensionInfo(
extension_id,
extension_version,
p.stat().st_size,
hashlib.sha256(p.read_bytes()).hexdigest()
)
_cache.add(info)
case Change.deleted:
if extension_id in _cache and extension_version in _cache[extension_id]:
_cache[extension_id].remove(extension_version)
for info in _iter_extension_info(extension_id, extension_version):
_cache.remove(info)
case _:
pass
except CancelledError:
Expand All @@ -104,53 +138,29 @@ async def _block():
return


def _get_crx_info(
cache_path: Path,
filters: list[tuple[str, str]],
) -> Generator[tuple[str, tuple[str, int, str]], Any, None]:
if len(filters) == 0:
for crx, versions in _cache.items():
for version in versions:
filters.append((crx, version))

for crx, version in filters:
path = cache_path / crx / (version + ".crx")
if path.is_file():
content = path.read_bytes()
info = (version, len(content), hashlib.sha256(content).hexdigest())
yield crx, info


def _gen_cache(cache: Path):
for path in cache.glob("./*/*.crx"):
extension_version = path.stem
extension_id = path.parent.stem
if extension_id in _cache:
_cache[extension_id].add(extension_version)
else:
_cache[extension_id] = {extension_version}
extension_size = path.stat().st_size
extension_hash_sha256 = hashlib.sha256(path.read_bytes()).hexdigest()
_cache.add(
ExtensionInfo(
extension_id,
extension_version,
extension_size,
extension_hash_sha256,
)
)


def setup_server(
def _get_cleanup_ctx_callback(
config: Config,
debug: bool = False,
) -> web.Application:
"""Get WebApplication instance from config."""
cache_path = Path(config.cache_dir)

app = web.Application(
logger=_logger,
debug=debug,
)

extension_keys: list[web.AppKey[Task[None]]] = []
for extension in config.extensions:
extension_key = web.AppKey(extension, Task[None])
extension_keys.append(extension_key)

watcher_key = web.AppKey("cache-watcher", Task[None])

async def register_services(app: web.Application):
cache_path: Path,
watcher_key: web.AppKey[Task[None]],
extension_keys: list[web.AppKey[Task[None]]]
) -> Callable[[web.Application], AsyncIterator[None]]:
async def callback(app: web.Application):
_gen_cache(cache_path)

app[watcher_key] = create_task(_watch_cache(cache_path))
Expand All @@ -176,33 +186,41 @@ async def register_services(app: web.Application):
await app[watcher_key]

_cache.clear()
return callback

app.cleanup_ctx.append(register_services)

prefix = config.prefix if config.prefix.startswith("/") else "/" + config.prefix
manifest_path = config.manifest_path if config.manifest_path.startswith("/") else \
"/" + config.manifest_path

async def _handle_manifest(request: web.Request) -> web.Response:
def _get_handler(
config: Config,
prefix: str
) -> Callable[[web.Request], Coroutine[Any, Any, web.Response]]:
async def handler(request: web.Request) -> web.Response:
absolute_base = config.base + prefix + "/"
root = Element("gupdate")
root.attrib["xmlns"] = "http://www.google.com/update2/response"
root.attrib["protocol"] = "2.0"
xs = request.query.getall("x") if "x" in request.query else []
filters = _get_filters(xs)
infos: list[ExtensionInfo] = []
if len(filters) > 0:
for extension_id, extension_version in filters:
for info in _iter_extension_info(extension_id, extension_version):
infos.append(info)
else:
for info in _iter_extension_info():
infos.append(info)

for crx, info in _get_crx_info(cache_path, filters):
app = root.find("./app[@appid='{}']".format(crx))
for info in infos:
extension_path = info.extension_id + "/" + info.version + ".crx"
app = root.find("./app[@appid='{}']".format(info.extension_id))
if app is None:
app = Element("app")
app.attrib["appid"] = crx
app.attrib["appid"] = info.extension_id
root.append(app)
version, size, sha256 = info
update_check = Element("updatecheck")
update_check.attrib["codebase"] = absolute_base + crx + "/" + version + ".crx"
update_check.attrib["version"] = version
update_check.attrib["size"] = str(size)
update_check.attrib["hash_sha256"] = sha256
update_check.attrib["codebase"] = absolute_base + extension_path
update_check.attrib["version"] = info.version
update_check.attrib["size"] = str(info.size)
update_check.attrib["hash_sha256"] = info.hash_sha256
app.append(update_check)
indent(root)
xml: bytes = tostring(root, encoding="utf-8", xml_declaration=True)
Expand All @@ -211,14 +229,45 @@ async def _handle_manifest(request: web.Request) -> web.Response:
content_type="application/xml",
charset="utf-8"
)
return handler


def setup_server(
config: Config,
debug: bool = False,
) -> web.Application:
"""Get WebApplication instance from config."""
cache_path = Path(config.cache_dir)

app = web.Application(
logger=_logger,
debug=debug,
)

extension_keys: list[web.AppKey[Task[None]]] = []
for extension in config.extensions:
extension_key = web.AppKey(extension, Task[None])
extension_keys.append(extension_key)

watcher_key = web.AppKey("cache-watcher", Task[None])

callback = _get_cleanup_ctx_callback(config, cache_path, watcher_key, extension_keys)

app.cleanup_ctx.append(callback)

prefix = config.prefix if config.prefix.startswith("/") else "/" + config.prefix
manifest_path = config.manifest_path if config.manifest_path.startswith("/") else \
"/" + config.manifest_path

handler = _get_handler(config, prefix)

if not cache_path.is_dir():
if cache_path.exists():
_logger.debug("Removing %s to create directory.", cache_path)
cache_path.unlink()
cache_path.mkdir(parents=True)
_ = app.router.add_static(prefix, cache_path, name="crx-handler")
_ = app.router.add_get(manifest_path, _handle_manifest, name="manifest-handler")
_ = app.router.add_get(manifest_path, handler, name="manifest-handler")
return app


Expand Down

0 comments on commit 26fe961

Please sign in to comment.