diff --git a/app/model/database.py b/app/model/database.py index 12ba95b..ef22a02 100644 --- a/app/model/database.py +++ b/app/model/database.py @@ -1,8 +1,8 @@ -from sqlalchemy import create_engine, Column, String, BLOB +from sqlalchemy import create_engine, Column, String from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from uuid import uuid4 -import pickle +import pandas as pd Base = declarative_base() @@ -12,7 +12,7 @@ class ProcessedTableDBEntry(Base): table_id = Column(String, primary_key=True) original_file_id = Column(String) - table_data = Column(BLOB) + sql_table_name = Column(String) class PostPruningTableDBEntry(Base): @@ -20,7 +20,7 @@ class PostPruningTableDBEntry(Base): table_id = Column(String, primary_key=True) original_file_id = Column(String) - table_data = Column(BLOB) + sql_table_name = Column(String) class TableDBManager: @@ -29,22 +29,25 @@ def __init__(self, db_url=f"sqlite:///tables-{str(uuid4())}.db"): Base.metadata.create_all(self.engine) self.Session = sessionmaker(bind=self.engine) - def save_table(self, table_class, table_id, original_file_id, table_data): + def save_table(self, table_class, table_id, original_file_id, df): with self.Session() as session: + sql_table_name = f"{table_class.__tablename__}_{table_id}" + df.to_sql(sql_table_name, self.engine, index=False) new_table = table_class( table_id=table_id, original_file_id=str(original_file_id), - table_data=table_data + sql_table_name=sql_table_name ) session.add(new_table) session.commit() - def update_table(self, table_class, table_id, table_data): + def update_table(self, table_class, table_id, df): with self.Session() as session: existing_table = session.query( table_class).filter_by(table_id=table_id).first() if existing_table: - existing_table.table_data = table_data + df.to_sql(existing_table.sql_table_name, + self.engine, if_exists='replace', index=False) session.commit() def get_processed_table_data(self, table_id): @@ -55,12 +58,12 @@ def get_post_pruning_table_data(self, table_id): def get_table_data(self, table_class, table_id): with self.Session() as session: - table = session.query(table_class).filter_by( - table_id=table_id).first() - if table: - table_data = pickle.loads(table.table_data) - table_data.reset_index(drop=True, inplace=True) - return table_data + table_entry = session.query( + table_class).filter_by(table_id=table_id).first() + if table_entry: + df = pd.read_sql_table( + table_entry.sql_table_name, self.engine) + return df.reset_index(drop=True) return None def get_table_object(self, table_class, table_id): @@ -71,10 +74,11 @@ def get_table_object(self, table_class, table_id): def delete_table(self, table_class, table_id): with self.Session() as session: - table = session.query(table_class).filter_by( - table_id=table_id).first() - if table: - session.delete(table) + table_entry = session.query( + table_class).filter_by(table_id=table_id).first() + if table_entry: + self.engine.execute(f"DROP TABLE IF EXISTS {table_entry.sql_table_name}") + session.delete(table_entry) session.commit() def reset(self): @@ -85,6 +89,4 @@ def reset(self): def processed_df_to_db(db_manager, table_id, original_file_id, df): - serialized_df = pickle.dumps(df) - db_manager.save_table(ProcessedTableDBEntry, table_id, - original_file_id, serialized_df) + db_manager.save_table(ProcessedTableDBEntry, table_id, original_file_id, df) diff --git a/app/model/model.py b/app/model/model.py index a0953ec..841f403 100644 --- a/app/model/model.py +++ b/app/model/model.py @@ -1,5 +1,4 @@ from uuid import uuid4 -import pickle from model.article_managers import Bibliography, Article, SuppFile, SuppFileManager, ProcessedTable, ProcessedTableManager from model.database import TableDBManager, PostPruningTableDBEntry @@ -78,7 +77,7 @@ def prune_tables_and_columns(self, context): in article.processed_tables if table.checked] for table in tables_to_prune: - serialized_df = None + pruned_df = None if context == 'parsed': columns_vector = table.checked_columns @@ -90,10 +89,9 @@ def prune_tables_and_columns(self, context): table.id) if data is not None and columns_vector is not None: - data = data.iloc[:, columns_vector] - serialized_df = pickle.dumps(data) + pruned_df = data.iloc[:, columns_vector] - if serialized_df is not None: + if pruned_df is not None: existing_table = self.table_db_manager.get_table_object( PostPruningTableDBEntry, table.id) @@ -102,13 +100,13 @@ def prune_tables_and_columns(self, context): self.table_db_manager.update_table( PostPruningTableDBEntry, table.id, - serialized_df) + pruned_df) else: self.table_db_manager.save_table( PostPruningTableDBEntry, table.id, table.file_id, - serialized_df) + pruned_df) article.pruned_tables = tables_to_prune