From 8c74d51333d478691119ffc50003be8e43b8ad22 Mon Sep 17 00:00:00 2001 From: Vadim Aleksandrov Date: Fri, 8 Sep 2023 00:24:56 +0300 Subject: [PATCH] chore: add close method to db class --- transmission_telegram_bot/bot.py | 109 +++++++++++++++++-------------- transmission_telegram_bot/db.py | 6 ++ 2 files changed, 67 insertions(+), 48 deletions(-) diff --git a/transmission_telegram_bot/bot.py b/transmission_telegram_bot/bot.py index 9817651..ceb397f 100755 --- a/transmission_telegram_bot/bot.py +++ b/transmission_telegram_bot/bot.py @@ -8,6 +8,7 @@ from functools import wraps from pathlib import Path from textwrap import dedent +from typing import Coroutine from emoji import emojize from telegram import ( @@ -406,66 +407,67 @@ async def error_action(update, context): async def check_torrent_download_status(context): # noqa: C901 global transmission + global db + + if isinstance(db, Coroutine): + db = await db + try: - db = await DB.create(cfg["db"]["path"]) + torrents = await db.list_uncomplete_torrents() except Exception as exc: logger.error(f"{type(exc).__name__}({exc})") else: - try: - torrents = await db.list_uncomplete_torrents() - except Exception as exc: - logger.error(f"{type(exc).__name__}({exc})") - else: - if torrents: - for torrent in torrents: + if torrents: + for torrent in torrents: + try: + task = transmission.get_torrent(torrent[1]) + except Exception: + await db.remove_torrent_by_id(torrent[1]) + else: try: - task = transmission.get_torrent(torrent[1]) - except Exception: - await db.remove_torrent_by_id(torrent[1]) + if task.doneDate: + await db.complete_torrent(torrent[1]) + except Exception as exc: + logger.error(f"{type(exc).__name__}({exc})") else: - try: - if task.doneDate: - await db.complete_torrent(torrent[1]) - except Exception as exc: - logger.error(f"{type(exc).__name__}({exc})") - else: - if task.doneDate: - response = f'Torrent "*{task.name}*" was successfully downloaded' + if task.doneDate: + response = f'Torrent "*{task.name}*" was successfully downloaded' + try: + notify_flag = False try: + chat = cfg["telegram"]["allow_chat"].get(torrent[0]) + except Exception: notify_flag = False - try: - chat = cfg["telegram"]["allow_chat"].get(torrent[0]) - except Exception: - notify_flag = False - else: - if chat["notify"] == "personal": - notify_flag = True - if notify_flag: - context.bot.send_message( - chat_id=torrent[0], - text=response, - parse_mode="Markdown", - ) - - notify_about_all = [ - chat["telegram_id"] - for chat in cfg["telegram"]["allow_chat"] - if chat["notify"] == "all" - ] - if notify_about_all: - for telegram_id in notify_about_all: - await context.bot.send_message( - chat_id=telegram_id, - text=response, - parse_mode="Markdown", - ) - except Exception as exc: - logger.error(f"{type(exc).__name__}({exc})") + else: + if chat["notify"] == "personal": + notify_flag = True + if notify_flag: + context.bot.send_message( + chat_id=torrent[0], + text=response, + parse_mode="Markdown", + ) + + notify_about_all = [ + chat["telegram_id"] + for chat in cfg["telegram"]["allow_chat"] + if chat["notify"] == "all" + ] + if notify_about_all: + for telegram_id in notify_about_all: + await context.bot.send_message( + chat_id=telegram_id, + text=response, + parse_mode="Markdown", + ) + except Exception as exc: + logger.error(f"{type(exc).__name__}({exc})") def main(): global cfg global transmission + global db global logger parser = argparse.ArgumentParser() @@ -496,6 +498,13 @@ def main(): ) except Exception as exc: logger.error(f"Transmission connection error: {exc}") + sys.exit(1) + + try: + db = DB.create(cfg["db"]["path"]) + except Exception as exc: + logger.error(f"{type(exc).__name__}({exc})") + sys.exit(1) application = ApplicationBuilder().token(cfg["telegram"]["token"]) @@ -542,9 +551,13 @@ def main(): if job_queue: job_queue.run_repeating( - check_torrent_download_status, interval=check_period, first=10, job_kwargs={"max_instances": max_instances} + check_torrent_download_status, + interval=check_period, + first=10, + job_kwargs={"max_instances": max_instances}, ) application.run_polling() + db.close() if __name__ == "__main__": diff --git a/transmission_telegram_bot/db.py b/transmission_telegram_bot/db.py index 614e676..469a887 100644 --- a/transmission_telegram_bot/db.py +++ b/transmission_telegram_bot/db.py @@ -90,3 +90,9 @@ async def vacuum_db(self) -> None: await self.conn.execute("VACUUM") except Exception as exc: raise DBExceptionError(exc) from exc + + async def close(self) -> None: + try: + await self.conn.close() + except Exception as exc: + raise DBExceptionError(exc) from exc