diff --git a/aukpy/create_tables.sql b/aukpy/create_tables.sql
index 19ff685..89fb6d1 100644
--- a/aukpy/create_tables.sql
+++ b/aukpy/create_tables.sql
@@ -67,12 +67,13 @@ CREATE TABLE IF NOT EXISTS sampling_event (
     trip_comments text,
     all_species_reported integer,
     number_observers integer,
-    UNIQUE(sampling_event_identifier)
+    location_data_id integer,
+    UNIQUE(sampling_event_identifier),
+    FOREIGN KEY (location_data_id) REFERENCES location_data(id)
 );
 
 CREATE TABLE IF NOT EXISTS observation (
     id integer PRIMARY KEY,
-    location_data_id integer NOT NULL,
     species_id integer NOT NULL,
     breeding_id integer,
     protocol_id integer,
@@ -91,6 +92,5 @@ CREATE TABLE IF NOT EXISTS observation (
     FOREIGN KEY (sampling_event_id) REFERENCES sampling_event(id),
     FOREIGN KEY (species_id) REFERENCES species(id),
     FOREIGN KEY (breeding_id) REFERENCES breeding(id),
-    FOREIGN KEY (location_data_id) REFERENCES location_data(id),
     FOREIGN KEY (protocol_id) REFERENCES protocol(id)
 );
diff --git a/aukpy/db.py b/aukpy/db.py
index 9bb7859..b9f2c64 100644
--- a/aukpy/db.py
+++ b/aukpy/db.py
@@ -9,7 +9,8 @@
 import sqlite3
 
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Any
+from time import time
+from typing import Dict, Optional, Tuple, Any
 
 from aukpy import config
 
@@ -217,6 +218,7 @@ def insert(
         assert len(idx_id_map) == len(df)
         as_series = pd.Series(idx_id_map)
         df[f"{cls.table_name}_id"] = as_series
+        df.drop(list(cls.columns), axis=1, inplace=True)
         return df, cache
 
 
@@ -312,10 +314,11 @@ class SamplingWrapper(TableWrapper):
         "trip_comments",
         "all_species_reported",
         "number_observers",
+        "location_data_id",
     )
     insert_query = """INSERT OR IGNORE INTO sampling_event
-    (sampling_event_identifier, observation_date, time_observations_started, observer_id, effort_distance_km, effort_area_ha, duration_minutes, trip_comments, all_species_reported, number_observers)
-    VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+    (sampling_event_identifier, observation_date, time_observations_started, observer_id, effort_distance_km, effort_area_ha, duration_minutes, trip_comments, all_species_reported, number_observers, location_data_id)
+    VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
     """
 
     unique_columns = ("sampling_event_identifier",)
@@ -341,11 +344,22 @@ def df_processing(cls, df: pd.DataFrame) -> pd.DataFrame:
 
         return df
 
+    @classmethod
+    def insert(
+        cls,
+        df: pd.DataFrame,
+        db: sqlite3.Connection,
+        cache: Optional[Dict[Any, int]] = None,
+    ) -> Tuple[pd.DataFrame, Dict[Any, int]]:
+
+        df, cache = LocationWrapper.insert(df, db, cache=cache)
+        return super().insert(df, db, cache=cache)
+
 
 class ObservationWrapper(TableWrapper):
     table_name = "observation"
     columns = (
-        "location_data_id",
+        # "location_data_id",
         "species_id",
         "breeding_id",
         "protocol_id",
@@ -383,7 +397,6 @@ class ObservationWrapper(TableWrapper):
         ?,
         ?,
         ?,
-        ?,
         ?
     );""".format(
         ",\n".join(columns)
     )
@@ -408,9 +421,22 @@ def df_processing(cls, df: pd.DataFrame) -> pd.DataFrame:
 
         return df
 
+    @classmethod
+    def insert(
+        cls,
+        df: pd.DataFrame,
+        db: sqlite3.Connection,
+        cache: Optional[Dict[Any, int]] = None,
+    ) -> Tuple[pd.DataFrame, Dict[Any, int]]:
+        # Table specific preprocessing
+        if cache is None:
+            cache = {}
+        sub_frame = cls.df_processing(df.loc[:, list(cls.columns)])
+        sub_frame.to_sql("observation", con=db, if_exists="append", index=False)
+        return df, cache
+
 
 WRAPPERS = (
-    LocationWrapper,
     SpeciesWrapper,
     BreedingWrapper,
     ProtocolWrapper,
@@ -467,9 +493,7 @@ def build_db_pandas(
         df, _ = wrapper.insert(df, conn)
 
     # Store main observations table
-    used_columns = [y for x in WRAPPERS for y in x.columns]
-    just_obs = df.drop(used_columns, axis=1)
-    ObservationWrapper.insert(just_obs, conn)
+    ObservationWrapper.insert(df, conn)
     conn.commit()
     return conn
 
@@ -507,8 +531,6 @@ def build_db_incremental(
             subtable_cache[wrapper.__name__] = cache
 
     # Store main observations table
-    used_columns = [y for x in WRAPPERS for y in x.columns]
-    just_obs = df.drop(used_columns, axis=1)
-    ObservationWrapper.insert(just_obs, conn)
+    ObservationWrapper.insert(df, conn)
     conn.commit()
     return conn
diff --git a/tests/perf_test.py b/tests/perf_test.py
index e56453e..70c5120 100644
--- a/tests/perf_test.py
+++ b/tests/perf_test.py
@@ -4,7 +4,7 @@
 import pytest
 
 from aukpy import db
-from tests.db_test import SMALL, MEDIUM
+from tests.db_test import SMALL, MEDIUM, SKIP_NON_MOCKED
 
 
 def check_usage(csv_file: Path):
@@ -28,7 +28,10 @@ def check_usage(csv_file: Path):
     print(f"Size of observation: {res}")
 
 
-@pytest.mark.skip
+@pytest.mark.skipif(**SKIP_NON_MOCKED)  # type: ignore
 def test_data_usage():
-    for csv_file in (SMALL, MEDIUM):
+    for csv_file in (
+        SMALL,
+        MEDIUM,
+    ):
         check_usage(csv_file)