Switched to dropping columns immediately after insertion (frees memory sooner). Closes #78
vluzko committed Aug 18, 2022
1 parent a028f80 commit a06f7bb
Showing 3 changed files with 50 additions and 18 deletions.
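
Before the per-file diffs, a minimal, self-contained sketch of the pattern the commit message describes: each wrapper's insert now drops the columns it has consumed from the shared DataFrame right away, instead of collecting every wrapper's columns at the end of the build. Class and attribute names mirror the diff below, but the body is simplified and the id mapping is a stand-in, not the project's real grouping logic.

```python
# Minimal sketch of "drop columns immediately after insertion"
# (not aukpy's real TableWrapper; the id mapping below is a placeholder).
import sqlite3
import pandas as pd


class ExampleWrapper:
    table_name = "example"
    columns = ("col_a", "col_b")
    insert_query = "INSERT OR IGNORE INTO example (col_a, col_b) VALUES (?, ?)"

    @classmethod
    def insert(cls, df: pd.DataFrame, db: sqlite3.Connection) -> pd.DataFrame:
        sub = df.loc[:, list(cls.columns)].drop_duplicates()
        db.executemany(cls.insert_query, sub.itertuples(index=False, name=None))
        # Stand-in for mapping each row back to its sub-table id.
        df[f"{cls.table_name}_id"] = range(1, len(df) + 1)
        # The change this commit makes: free the consumed columns right away.
        df.drop(list(cls.columns), axis=1, inplace=True)
        return df


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE example (id integer PRIMARY KEY, col_a text, col_b text)")
frame = pd.DataFrame({"col_a": ["x"], "col_b": ["y"], "kept": [1]})
frame = ExampleWrapper.insert(frame, conn)
print(frame.columns.tolist())  # ['kept', 'example_id'] — col_a/col_b already released
```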
6 changes: 3 additions & 3 deletions aukpy/create_tables.sql
@@ -67,12 +67,13 @@ CREATE TABLE IF NOT EXISTS sampling_event (
trip_comments text,
all_species_reported integer,
number_observers integer,
UNIQUE(sampling_event_identifier)
location_data_id integer,
UNIQUE(sampling_event_identifier),
FOREIGN KEY (location_data_id) REFERENCES location_data(id)
);

CREATE TABLE IF NOT EXISTS observation (
id integer PRIMARY KEY,
location_data_id integer NOT NULL,
species_id integer NOT NULL,
breeding_id integer,
protocol_id integer,
@@ -91,6 +92,5 @@ CREATE TABLE IF NOT EXISTS observation (
FOREIGN KEY (sampling_event_id) REFERENCES sampling_event(id),
FOREIGN KEY (species_id) REFERENCES species(id),
FOREIGN KEY (breeding_id) REFERENCES breeding(id),
FOREIGN KEY (location_data_id) REFERENCES location_data(id),
FOREIGN KEY (protocol_id) REFERENCES protocol(id)
);
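
One practical consequence of moving location_data_id from observation onto sampling_event: location attributes are now reached through the sampling event rather than directly from the observation row. A hedged sketch of the new join path; the latitude/longitude columns on location_data are assumptions for the example, only the foreign keys come from this diff.

```python
# Self-contained illustration of the post-change join path
# (location_data column names are assumed, not taken from this commit).
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE location_data (id integer PRIMARY KEY, latitude real, longitude real);
    CREATE TABLE sampling_event (id integer PRIMARY KEY, location_data_id integer,
        FOREIGN KEY (location_data_id) REFERENCES location_data(id));
    CREATE TABLE observation (id integer PRIMARY KEY, sampling_event_id integer,
        FOREIGN KEY (sampling_event_id) REFERENCES sampling_event(id));
    INSERT INTO location_data VALUES (1, 42.0, -71.0);
    INSERT INTO sampling_event VALUES (1, 1);
    INSERT INTO observation VALUES (1, 1);
    """
)
rows = conn.execute(
    """
    SELECT o.id, l.latitude, l.longitude
    FROM observation o
    JOIN sampling_event s ON o.sampling_event_id = s.id
    JOIN location_data  l ON s.location_data_id = l.id
    """
).fetchall()
print(rows)  # [(1, 42.0, -71.0)]
```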
53 changes: 41 additions & 12 deletions aukpy/db.py
@@ -9,7 +9,8 @@
import sqlite3

from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from time import time
from typing import Dict, Optional, Tuple, Any

from aukpy import config

@@ -217,6 +218,7 @@ def insert(
assert len(idx_id_map) == len(df)
as_series = pd.Series(idx_id_map)
df[f"{cls.table_name}_id"] = as_series
df.drop(list(cls.columns), axis=1, inplace=True)
return df, cache


@@ -312,10 +314,11 @@ class SamplingWrapper(TableWrapper):
"trip_comments",
"all_species_reported",
"number_observers",
"location_data_id",
)
insert_query = """INSERT OR IGNORE INTO sampling_event
(sampling_event_identifier, observation_date, time_observations_started, observer_id, effort_distance_km, effort_area_ha, duration_minutes, trip_comments, all_species_reported, number_observers)
VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
(sampling_event_identifier, observation_date, time_observations_started, observer_id, effort_distance_km, effort_area_ha, duration_minutes, trip_comments, all_species_reported, number_observers, location_data_id)
VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
unique_columns = ("sampling_event_identifier",)

@@ -341,11 +344,22 @@ def df_processing(cls, df: pd.DataFrame) -> pd.DataFrame:

return df

@classmethod
def insert(
cls,
df: pd.DataFrame,
db: sqlite3.Connection,
cache: Optional[Dict[Any, int]] = None,
) -> Tuple[pd.DataFrame, Dict[Any, int]]:

df, cache = LocationWrapper.insert(df, db, cache=cache)
return super().insert(df, db, cache=cache)


class ObservationWrapper(TableWrapper):
table_name = "observation"
columns = (
"location_data_id",
# "location_data_id",
"species_id",
"breeding_id",
"protocol_id",
@@ -383,7 +397,6 @@ class ObservationWrapper(TableWrapper):
?,
?,
?,
?,
?
);""".format(
",\n".join(columns)
@@ -408,9 +421,29 @@ def df_processing(cls, df: pd.DataFrame) -> pd.DataFrame:

return df

# @classmethod
# def insert(
# cls,
# df: pd.DataFrame,
# db: sqlite3.Connection,
# cache: Optional[Dict[Any, int]] = None, ) -> Tuple[pd.DataFrame, Dict[Any, int]]:
# # Table specific preprocessing
# if cache is None:
# cache = {}
# # sub_frame = cls.df_processing(df.loc[:, list(cls.columns)])
# # max_id = max_id if max_id is not None else 0
# # TODO: Optimization: Sort and drop_duplicates is probably faster.
# # groups_to_idx = sub_frame.fillna("").groupby(list(cls.unique_columns)).groups
# # new_idx = {g: idx[0] for g, idx in groups_to_idx.items() if g not in cache}
# # new_values = [sub_frame.loc[idx].tolist() for idx in new_idx.values()]

# import pdb
# pdb.set_trace()
# # db.executemany(cls.insert_query, new_values)
# return df, cache


WRAPPERS = (
LocationWrapper,
SpeciesWrapper,
BreedingWrapper,
ProtocolWrapper,
@@ -467,9 +500,7 @@ def build_db_pandas(
df, _ = wrapper.insert(df, conn)

# Store main observations table
used_columns = [y for x in WRAPPERS for y in x.columns]
just_obs = df.drop(used_columns, axis=1)
ObservationWrapper.insert(just_obs, conn)
ObservationWrapper.insert(df, conn)
conn.commit()
return conn

@@ -507,8 +538,6 @@ def build_db_incremental(
subtable_cache[wrapper.__name__] = cache

# Store main observations table
used_columns = [y for x in WRAPPERS for y in x.columns]
just_obs = df.drop(used_columns, axis=1)
ObservationWrapper.insert(just_obs, conn)
ObservationWrapper.insert(df, conn)
conn.commit()
return conn
9 changes: 6 additions & 3 deletions tests/perf_test.py
@@ -4,7 +4,7 @@
import pytest
from aukpy import db

from tests.db_test import SMALL, MEDIUM
from tests.db_test import SMALL, MEDIUM, SKIP_NON_MOCKED


def check_usage(csv_file: Path):
@@ -28,7 +28,10 @@ def check_usage(csv_file: Path):
print(f"Size of observation: {res}")


@pytest.mark.skip
@pytest.mark.skipif(**SKIP_NON_MOCKED) # type: ignore
def test_data_usage():
for csv_file in (SMALL, MEDIUM):
for csv_file in (
SMALL,
MEDIUM,
):
check_usage(csv_file)
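
The test now relies on SKIP_NON_MOCKED from tests/db_test.py, unpacked into pytest.mark.skipif. Its definition is not part of this diff; the sketch below is only a plausible shape (the data-directory path and reason string are assumptions).

```python
# Hypothetical shape of SKIP_NON_MOCKED (defined in tests/db_test.py, not shown
# here): a dict of skipif keyword arguments, so the decorator in the diff
# expands to pytest.mark.skipif(condition=..., reason=...).
from pathlib import Path

DATA_DIR = Path("tests/data")  # illustrative location, an assumption

SKIP_NON_MOCKED = {
    "condition": not DATA_DIR.exists(),
    "reason": "real (non-mocked) eBird sample data is not available",
}
```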
