From a06f7bba6c09d7a108866d2416f0f357d68c8c9b Mon Sep 17 00:00:00 2001 From: Vincent Luczkow Date: Thu, 18 Aug 2022 00:15:45 -0700 Subject: [PATCH] Switched to dropping columns immediately after insertion (frees memory sooner). Closes #78 --- aukpy/create_tables.sql | 6 ++--- aukpy/db.py | 53 +++++++++++++++++++++++++++++++---------- tests/perf_test.py | 9 ++++--- 3 files changed, 50 insertions(+), 18 deletions(-) diff --git a/aukpy/create_tables.sql b/aukpy/create_tables.sql index 19ff685..89fb6d1 100644 --- a/aukpy/create_tables.sql +++ b/aukpy/create_tables.sql @@ -67,12 +67,13 @@ CREATE TABLE IF NOT EXISTS sampling_event ( trip_comments text, all_species_reported integer, number_observers integer, - UNIQUE(sampling_event_identifier) + location_data_id integer, + UNIQUE(sampling_event_identifier), + FOREIGN KEY (location_data_id) REFERENCES location_data(id) ); CREATE TABLE IF NOT EXISTS observation ( id integer PRIMARY KEY, - location_data_id integer NOT NULL, species_id integer NOT NULL, breeding_id integer, protocol_id integer, @@ -91,6 +92,5 @@ CREATE TABLE IF NOT EXISTS observation ( FOREIGN KEY (sampling_event_id) REFERENCES sampling_event(id), FOREIGN KEY (species_id) REFERENCES species(id), FOREIGN KEY (breeding_id) REFERENCES breeding(id), - FOREIGN KEY (location_data_id) REFERENCES location_data(id), FOREIGN KEY (protocol_id) REFERENCES protocol(id) ); diff --git a/aukpy/db.py b/aukpy/db.py index 9bb7859..c73342f 100644 --- a/aukpy/db.py +++ b/aukpy/db.py @@ -9,7 +9,8 @@ import sqlite3 from pathlib import Path -from typing import Dict, List, Optional, Tuple, Any +from time import time +from typing import Dict, Optional, Tuple, Any from aukpy import config @@ -217,6 +218,7 @@ def insert( assert len(idx_id_map) == len(df) as_series = pd.Series(idx_id_map) df[f"{cls.table_name}_id"] = as_series + df.drop(list(cls.columns), axis=1, inplace=True) return df, cache @@ -312,10 +314,11 @@ class SamplingWrapper(TableWrapper): "trip_comments", "all_species_reported", "number_observers", + "location_data_id", ) insert_query = """INSERT OR IGNORE INTO sampling_event - (sampling_event_identifier, observation_date, time_observations_started, observer_id, effort_distance_km, effort_area_ha, duration_minutes, trip_comments, all_species_reported, number_observers) - VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + (sampling_event_identifier, observation_date, time_observations_started, observer_id, effort_distance_km, effort_area_ha, duration_minutes, trip_comments, all_species_reported, number_observers, location_data_id) + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """ unique_columns = ("sampling_event_identifier",) @@ -341,11 +344,22 @@ def df_processing(cls, df: pd.DataFrame) -> pd.DataFrame: return df + @classmethod + def insert( + cls, + df: pd.DataFrame, + db: sqlite3.Connection, + cache: Optional[Dict[Any, int]] = None, + ) -> Tuple[pd.DataFrame, Dict[Any, int]]: + + df, cache = LocationWrapper.insert(df, db, cache=cache) + return super().insert(df, db, cache=cache) + class ObservationWrapper(TableWrapper): table_name = "observation" columns = ( - "location_data_id", + # "location_data_id", "species_id", "breeding_id", "protocol_id", @@ -383,7 +397,6 @@ class ObservationWrapper(TableWrapper): ?, ?, ?, - ?, ? );""".format( ",\n".join(columns) @@ -408,9 +421,29 @@ def df_processing(cls, df: pd.DataFrame) -> pd.DataFrame: return df + # @classmethod + # def insert( + # cls, + # df: pd.DataFrame, + # db: sqlite3.Connection, + # cache: Optional[Dict[Any, int]] = None, ) -> Tuple[pd.DataFrame, Dict[Any, int]]: + # # Table specific preprocessing + # if cache is None: + # cache = {} + # # sub_frame = cls.df_processing(df.loc[:, list(cls.columns)]) + # # max_id = max_id if max_id is not None else 0 + # # TODO: Optimization: Sort and drop_duplicates is probably faster. + # # groups_to_idx = sub_frame.fillna("").groupby(list(cls.unique_columns)).groups + # # new_idx = {g: idx[0] for g, idx in groups_to_idx.items() if g not in cache} + # # new_values = [sub_frame.loc[idx].tolist() for idx in new_idx.values()] + + # import pdb + # pdb.set_trace() + # # db.executemany(cls.insert_query, new_values) + # return df, cache + WRAPPERS = ( - LocationWrapper, SpeciesWrapper, BreedingWrapper, ProtocolWrapper, @@ -467,9 +500,7 @@ def build_db_pandas( df, _ = wrapper.insert(df, conn) # Store main observations table - used_columns = [y for x in WRAPPERS for y in x.columns] - just_obs = df.drop(used_columns, axis=1) - ObservationWrapper.insert(just_obs, conn) + ObservationWrapper.insert(df, conn) conn.commit() return conn @@ -507,8 +538,6 @@ def build_db_incremental( subtable_cache[wrapper.__name__] = cache # Store main observations table - used_columns = [y for x in WRAPPERS for y in x.columns] - just_obs = df.drop(used_columns, axis=1) - ObservationWrapper.insert(just_obs, conn) + ObservationWrapper.insert(df, conn) conn.commit() return conn diff --git a/tests/perf_test.py b/tests/perf_test.py index e56453e..70c5120 100644 --- a/tests/perf_test.py +++ b/tests/perf_test.py @@ -4,7 +4,7 @@ import pytest from aukpy import db -from tests.db_test import SMALL, MEDIUM +from tests.db_test import SMALL, MEDIUM, SKIP_NON_MOCKED def check_usage(csv_file: Path): @@ -28,7 +28,10 @@ def check_usage(csv_file: Path): print(f"Size of observation: {res}") -@pytest.mark.skip +@pytest.mark.skipif(**SKIP_NON_MOCKED) # type: ignore def test_data_usage(): - for csv_file in (SMALL, MEDIUM): + for csv_file in ( + SMALL, + MEDIUM, + ): check_usage(csv_file)