diff --git a/app/controller/controller.py b/app/controller/controller.py index 168fa26..e453fd8 100644 --- a/app/controller/controller.py +++ b/app/controller/controller.py @@ -12,6 +12,7 @@ from model.threading import SearchThread, FilePreviewThread, FileProcessingThread from views.list import DataListItem from PyQt5.QtWidgets import QListWidgetItem + from model.database import TableDBManager import os import re @@ -104,19 +105,22 @@ def _connect_sigs( self.signal_connections.append((signal, slot)) def add_tag(self, tag: 'str'): - self.model.last_selected_table.add_tag(tag) + self.model.last_selected_table.add_tag( + tag, self.model.table_db_manager) - self.curr_elems.tags_display_widget.clear() + self.curr_elems.tags_display_widget.clear() for tag in self.model.last_selected_table.get_tags(): self.curr_elems.tags_display_widget.addTag(tag) def remove_tag(self, tag: 'str'): - self.model.last_selected_table.remove_tag(tag) + self.model.last_selected_table.remove_tag( + tag, self.model.table_db_manager + ) self.curr_elems.tags_display_widget.clear() for tag in self.model.last_selected_table.get_tags(): self.curr_elems.tags_display_widget.addTag(tag) - + def _disconnect_sigs(self): for signal, slot in self.signal_connections: if signal and slot: @@ -267,7 +271,7 @@ def preview_processed_table( self.curr_elems.tags_display_widget.clear() for tag in table.get_tags(): self.curr_elems.tags_display_widget.addTag(tag) - + table_data = { "sheet": self.model.table_db_manager.get_processed_table_data( table.id, context diff --git a/app/model/article_managers.py b/app/model/article_managers.py index a28b63e..b6392fc 100644 --- a/app/model/article_managers.py +++ b/app/model/article_managers.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from views.list import ListItem, DataListItem, ArticleListItem + from model.database import TableDBManager from utils.constants import PageIdentity from uuid import uuid4, UUID @@ -108,14 +109,15 @@ def set_checked_state(self, checked_state: 'bool', context: 'PageIdentity'): if context in self.observers: self.notify_observers(context) - def add_tag(self, tag: 'str'): + def add_tag(self, tag: 'str', db_manager: 'TableDBManager'): if tag not in self.tags: self.tags.append(tag) - # self.notify_observers(PageIdentity.PRUNED) + db_manager.update_table_tags(self.id, self.tags) - def remove_tag(self, tag: 'str'): + def remove_tag(self, tag: 'str', db_manager: 'TableDBManager'): if tag in self.tags: self.tags.remove(tag) + db_manager.update_table_tags(self.id, self.tags) def get_tags(self): return self.tags diff --git a/app/model/database.py b/app/model/database.py index 3a48672..b20c069 100644 --- a/app/model/database.py +++ b/app/model/database.py @@ -126,6 +126,23 @@ def update_table( ) session.commit() + def update_table_tags(self, table_id: 'str', tags: 'list[str]'): + self._update_table_tags(ProcessedTableDBEntry, table_id, tags) + self._update_table_tags(PostPruningTableDBEntry, table_id, tags) + + def _update_table_tags( + self, + table_class: 'TableDBEntry', + table_id: 'str', + tags: 'list[str]' + ): + _, Session = self._get_engine_and_session(table_class) + with Session() as session: + table_entry: 'TableDBEntry' = session.query(table_class).filter_by(table_id=table_id).first() + if table_entry: + table_entry.tags = tags + session.commit() + def get_processed_table_data( self, table_id: 'str',