Skip to content

Commit

Permalink
Move from pickling to all-SQL storage
Browse files Browse the repository at this point in the history
  • Loading branch information
jordantgh committed Oct 10, 2023
1 parent 39faa82 commit 4b63954
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 28 deletions.
44 changes: 23 additions & 21 deletions app/model/database.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from sqlalchemy import create_engine, Column, String, BLOB
from sqlalchemy import create_engine, Column, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from uuid import uuid4
import pickle
import pandas as pd

Base = declarative_base()

Expand All @@ -12,15 +12,15 @@ class ProcessedTableDBEntry(Base):

table_id = Column(String, primary_key=True)
original_file_id = Column(String)
table_data = Column(BLOB)
sql_table_name = Column(String)


class PostPruningTableDBEntry(Base):
    """Lookup row for a post-pruning table.

    The DataFrame contents are stored in a dedicated SQL table (see
    ``TableDBManager.save_table``); this ORM row only records where to
    find them via ``sql_table_name``.
    """
    __tablename__ = 'post_pruning_tables'

    table_id = Column(String, primary_key=True)  # unique id of the pruned table
    original_file_id = Column(String)            # id of the source file
    sql_table_name = Column(String)              # name of the SQL table holding the data


class TableDBManager:
Expand All @@ -29,22 +29,25 @@ def __init__(self, db_url=f"sqlite:///tables-{str(uuid4())}.db"):
Base.metadata.create_all(self.engine)
self.Session = sessionmaker(bind=self.engine)

def save_table(self, table_class, table_id, original_file_id, table_data):
def save_table(self, table_class, table_id, original_file_id, df):
with self.Session() as session:
sql_table_name = f"{table_class.__tablename__}_{table_id}"
df.to_sql(sql_table_name, self.engine, index=False)
new_table = table_class(
table_id=table_id,
original_file_id=str(original_file_id),
table_data=table_data
sql_table_name=sql_table_name
)
session.add(new_table)
session.commit()

def update_table(self, table_class, table_id, table_data):
def update_table(self, table_class, table_id, df):
with self.Session() as session:
existing_table = session.query(
table_class).filter_by(table_id=table_id).first()
if existing_table:
existing_table.table_data = table_data
df.to_sql(existing_table.sql_table_name,
self.engine, if_exists='replace', index=False)
session.commit()

def get_processed_table_data(self, table_id):
Expand All @@ -55,12 +58,12 @@ def get_post_pruning_table_data(self, table_id):

def get_table_data(self, table_class, table_id):
with self.Session() as session:
table = session.query(table_class).filter_by(
table_id=table_id).first()
if table:
table_data = pickle.loads(table.table_data)
table_data.reset_index(drop=True, inplace=True)
return table_data
table_entry = session.query(
table_class).filter_by(table_id=table_id).first()
if table_entry:
df = pd.read_sql_table(
table_entry.sql_table_name, self.engine)
return df.reset_index(drop=True)
return None

def get_table_object(self, table_class, table_id):
Expand All @@ -71,10 +74,11 @@ def get_table_object(self, table_class, table_id):

def delete_table(self, table_class, table_id):
with self.Session() as session:
table = session.query(table_class).filter_by(
table_id=table_id).first()
if table:
session.delete(table)
table_entry = session.query(
table_class).filter_by(table_id=table_id).first()
if table_entry:
self.engine.execute(f"DROP TABLE IF EXISTS {table_entry.sql_table_name}")
session.delete(table_entry)
session.commit()

def reset(self):
Expand All @@ -85,6 +89,4 @@ def reset(self):


def processed_df_to_db(db_manager, table_id, original_file_id, df):
    """Store ``df`` as a processed table via ``db_manager``.

    Thin convenience wrapper that fixes the ORM entry class to
    ``ProcessedTableDBEntry``.
    """
    db_manager.save_table(ProcessedTableDBEntry, table_id, original_file_id, df)
12 changes: 5 additions & 7 deletions app/model/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from uuid import uuid4
import pickle

from model.article_managers import Bibliography, Article, SuppFile, SuppFileManager, ProcessedTable, ProcessedTableManager
from model.database import TableDBManager, PostPruningTableDBEntry
Expand Down Expand Up @@ -78,7 +77,7 @@ def prune_tables_and_columns(self, context):
in article.processed_tables if table.checked]

for table in tables_to_prune:
serialized_df = None
pruned_df = None

if context == 'parsed':
columns_vector = table.checked_columns
Expand All @@ -90,10 +89,9 @@ def prune_tables_and_columns(self, context):
table.id)

if data is not None and columns_vector is not None:
data = data.iloc[:, columns_vector]
serialized_df = pickle.dumps(data)
pruned_df = data.iloc[:, columns_vector]

if serialized_df is not None:
if pruned_df is not None:
existing_table = self.table_db_manager.get_table_object(
PostPruningTableDBEntry,
table.id)
Expand All @@ -102,13 +100,13 @@ def prune_tables_and_columns(self, context):
self.table_db_manager.update_table(
PostPruningTableDBEntry,
table.id,
serialized_df)
pruned_df)
else:
self.table_db_manager.save_table(
PostPruningTableDBEntry,
table.id,
table.file_id,
serialized_df)
pruned_df)

article.pruned_tables = tables_to_prune

Expand Down

0 comments on commit 4b63954

Please sign in to comment.