diff --git a/bin/recategorize-messages.py b/bin/recategorize-messages.py index 911c13062..6e7f9aa2f 100755 --- a/bin/recategorize-messages.py +++ b/bin/recategorize-messages.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import datetime +from collections.abc import Iterable import click from sqlalchemy.orm import Query @@ -13,64 +14,84 @@ from inbox.models.session import global_session_scope, session_scope -def fetch_message_ids( +def yield_account_id_and_message_ids( *, - account_id: int | None, + only_account_id: int | None, date_start: datetime.date | None, date_end: datetime.date | None, only_inbox: bool, -) -> list[int]: - query = Query([Message.id]) - if account_id: - query = query.filter(Message.namespace.has(Namespace.account_id == account_id)) - if only_inbox: - inbox_folder = ImapUid.folder.has(Folder._canonical_name == "INBOX") - query = query.filter(Message.imapuids.any(inbox_folder)) - if date_start: - query = query.filter(Message.created_at >= date_start) - if date_end: - query = query.filter(Message.created_at < date_end) +) -> Iterable[int, list[int]]: + namespace_query = Query([Namespace.account_id, Namespace.id]) + if only_account_id: + namespace_query = namespace_query.filter( + Namespace.account_id == only_account_id + ) with global_session_scope() as session: - message_ids = [message_id for message_id, in query.with_session(session)] + account_id_to_namespace_id = { + account_id: namespace_id + for account_id, namespace_id in namespace_query.with_session(session) + } - return message_ids + for account_id, namespace_id in account_id_to_namespace_id.items(): + query = Query([Message.id]).filter(Message.namespace_id == namespace_id) + + if only_inbox: + inbox_folder = ImapUid.folder.has(Folder._canonical_name == "INBOX") + query = query.filter(Message.imapuids.any(inbox_folder)) + if date_start: + query = query.filter(Message.created_at >= date_start) + if date_end: + query = query.filter(Message.created_at < date_end) + + with global_session_scope() as session: + message_ids = [message_id for message_id, in query.with_session(session)] + + yield account_id, message_ids @click.command() @click.option("--date-start", type=click.DateTime(formats=["%Y-%m-%d"]), default=None) @click.option("--date-end", type=click.DateTime(formats=["%Y-%m-%d"]), default=None) -@click.option("--account-id", type=int, default=None) +@click.option("--only-account-id", type=int, default=None) @click.option("--only-inbox", is_flag=True, default=False) +@click.option("--dry-run/--no-dry-run", default=True) def main( - account_id: int | None, + only_account_id: int | None, only_inbox: bool, date_start: datetime.date | None, date_end: datetime.date | None, + dry_run: bool, ) -> None: - message_ids = fetch_message_ids( - account_id=account_id, + print( + f"Settings: {only_account_id=},{only_inbox=},{date_start=},{date_end=},{dry_run=}" + ) + + def session_factory(): + return global_session_scope() if dry_run else session_scope(None) + + for account_id, message_ids in yield_account_id_and_message_ids( + only_account_id=only_account_id, date_start=date_start, date_end=date_end, only_inbox=only_inbox, - ) + ): + print(f"{account_id=},{len(message_ids)=}") - print(f"Found {len(message_ids)}") - - for message_id in message_ids: - with session_scope(None) as session: - message = session.query(Message).get(message_id) - old_categories = set( - category.display_name for category in message.categories - ) - update_message_metadata(session, message.account, message, message.is_draft) - new_categories = set( - category.display_name for category in message.categories - ) - if old_categories != new_categories: - print( - f"Message {message_id} categories changed from {old_categories} to {new_categories}" + for message_id in message_ids: + with session_factory() as session: + message = session.query(Message).get(message_id) + old_categories = set( + category.display_name for category in message.categories + ) + update_message_metadata( + session, message.account, message, message.is_draft + ) + new_categories = set( + category.display_name for category in message.categories ) + if old_categories != new_categories: + print(f"\t{message_id=},{old_categories=} to {new_categories}=") if __name__ == "__main__":