
Merge pull request #80 from vluzko/location-in-sample
Location in sample
vluzko authored Aug 18, 2022
2 parents a028f80 + fe58471 commit ddef6fc
Showing 3 changed files with 43 additions and 18 deletions.
aukpy/create_tables.sql — 6 changes: 3 additions & 3 deletions
@@ -67,12 +67,13 @@ CREATE TABLE IF NOT EXISTS sampling_event (
     trip_comments text,
     all_species_reported integer,
     number_observers integer,
-    UNIQUE(sampling_event_identifier)
+    location_data_id integer,
+    UNIQUE(sampling_event_identifier),
+    FOREIGN KEY (location_data_id) REFERENCES location_data(id)
 );

 CREATE TABLE IF NOT EXISTS observation (
     id integer PRIMARY KEY,
-    location_data_id integer NOT NULL,
     species_id integer NOT NULL,
     breeding_id integer,
     protocol_id integer,
@@ -91,6 +92,5 @@ CREATE TABLE IF NOT EXISTS observation (
     FOREIGN KEY (sampling_event_id) REFERENCES sampling_event(id),
     FOREIGN KEY (species_id) REFERENCES species(id),
     FOREIGN KEY (breeding_id) REFERENCES breeding(id),
-    FOREIGN KEY (location_data_id) REFERENCES location_data(id),
     FOREIGN KEY (protocol_id) REFERENCES protocol(id)
 );
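
The upshot of the schema change: location data is now keyed from sampling_event rather than from each observation row, so per-checklist locations are stored once and observation reaches them through sampling_event. A minimal sketch of the new query path, assuming a database already built by aukpy — the file name is hypothetical; table and column names come from the diff above:

import sqlite3

# Hypothetical path; substitute a database built by aukpy.
conn = sqlite3.connect("ebird.db")

# After this commit, location joins go through sampling_event
# instead of directly from observation.
query = """
SELECT observation.id, sampling_event.sampling_event_identifier, location_data.id
FROM observation
JOIN sampling_event ON observation.sampling_event_id = sampling_event.id
JOIN location_data ON sampling_event.location_data_id = location_data.id
LIMIT 10;
"""
for row in conn.execute(query):
    print(row)
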
aukpy/db.py — 46 changes: 34 additions & 12 deletions
@@ -9,7 +9,8 @@
 import sqlite3

 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Any
+from time import time
+from typing import Dict, Optional, Tuple, Any

 from aukpy import config
@@ -217,6 +218,7 @@ def insert(
         assert len(idx_id_map) == len(df)
         as_series = pd.Series(idx_id_map)
         df[f"{cls.table_name}_id"] = as_series
+        df.drop(list(cls.columns), axis=1, inplace=True)
         return df, cache


@@ -312,10 +314,11 @@ class SamplingWrapper(TableWrapper):
         "trip_comments",
         "all_species_reported",
         "number_observers",
+        "location_data_id",
     )
     insert_query = """INSERT OR IGNORE INTO sampling_event
-    (sampling_event_identifier, observation_date, time_observations_started, observer_id, effort_distance_km, effort_area_ha, duration_minutes, trip_comments, all_species_reported, number_observers)
-    VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+    (sampling_event_identifier, observation_date, time_observations_started, observer_id, effort_distance_km, effort_area_ha, duration_minutes, trip_comments, all_species_reported, number_observers, location_data_id)
+    VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
     """
     unique_columns = ("sampling_event_identifier",)

@@ -341,11 +344,22 @@ def df_processing(cls, df: pd.DataFrame) -> pd.DataFrame:

         return df

+    @classmethod
+    def insert(
+        cls,
+        df: pd.DataFrame,
+        db: sqlite3.Connection,
+        cache: Optional[Dict[Any, int]] = None,
+    ) -> Tuple[pd.DataFrame, Dict[Any, int]]:
+
+        df, cache = LocationWrapper.insert(df, db, cache=cache)
+        return super().insert(df, db, cache=cache)
+

 class ObservationWrapper(TableWrapper):
     table_name = "observation"
     columns = (
-        "location_data_id",
+        # "location_data_id",
         "species_id",
         "breeding_id",
         "protocol_id",
@@ -383,7 +397,6 @@ class ObservationWrapper(TableWrapper):
         ?,
         ?,
         ?,
-        ?,
         ?
     );""".format(
         ",\n".join(columns)
@@ -408,9 +421,22 @@ def df_processing(cls, df: pd.DataFrame) -> pd.DataFrame:

         return df

+    @classmethod
+    def insert(
+        cls,
+        df: pd.DataFrame,
+        db: sqlite3.Connection,
+        cache: Optional[Dict[Any, int]] = None,
+    ) -> Tuple[pd.DataFrame, Dict[Any, int]]:
+        # Table specific preprocessing
+        if cache is None:
+            cache = {}
+        sub_frame = cls.df_processing(df.loc[:, list(cls.columns)])
+        sub_frame.to_sql("observation", con=db, if_exists="append", index=False)
+        return df, cache

 WRAPPERS = (
     LocationWrapper,
     SpeciesWrapper,
     BreedingWrapper,
     ProtocolWrapper,
@@ -467,9 +493,7 @@ def build_db_pandas(
         df, _ = wrapper.insert(df, conn)

     # Store main observations table
-    used_columns = [y for x in WRAPPERS for y in x.columns]
-    just_obs = df.drop(used_columns, axis=1)
-    ObservationWrapper.insert(just_obs, conn)
+    ObservationWrapper.insert(df, conn)
     conn.commit()
     return conn

@@ -507,8 +531,6 @@ def build_db_incremental(
         subtable_cache[wrapper.__name__] = cache

     # Store main observations table
-    used_columns = [y for x in WRAPPERS for y in x.columns]
-    just_obs = df.drop(used_columns, axis=1)
-    ObservationWrapper.insert(just_obs, conn)
+    ObservationWrapper.insert(df, conn)
     conn.commit()
     return conn
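
Taken together, the db.py changes shift column bookkeeping into the wrappers: TableWrapper.insert now drops the source columns it consumed, SamplingWrapper.insert folds the location insert into its own, and ObservationWrapper.insert writes the remaining id columns via pandas' to_sql — which is why build_db_pandas and build_db_incremental no longer compute used_columns themselves. A minimal sketch of the resulting pipeline shape, restating the call pattern visible in build_db_pandas; the wrapper function name is hypothetical:

import sqlite3
import pandas as pd

from aukpy.db import WRAPPERS, ObservationWrapper

def store_observations(df: pd.DataFrame, conn: sqlite3.Connection) -> None:
    # Each wrapper inserts its sub-table, appends a <table>_id column to df,
    # and (new in this commit) drops the source columns it consumed.
    for wrapper in WRAPPERS:
        df, _ = wrapper.insert(df, conn)
    # SamplingWrapper.insert has already chained LocationWrapper.insert, so the
    # explicit used_columns bookkeeping removed above is no longer needed here.
    ObservationWrapper.insert(df, conn)
    conn.commit()
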
tests/perf_test.py — 9 changes: 6 additions & 3 deletions
@@ -4,7 +4,7 @@
 import pytest
 from aukpy import db

-from tests.db_test import SMALL, MEDIUM
+from tests.db_test import SMALL, MEDIUM, SKIP_NON_MOCKED


 def check_usage(csv_file: Path):
@@ -28,7 +28,10 @@ def check_usage(csv_file: Path):
     print(f"Size of observation: {res}")


-@pytest.mark.skip
+@pytest.mark.skipif(**SKIP_NON_MOCKED)  # type: ignore
 def test_data_usage():
-    for csv_file in (SMALL, MEDIUM):
+    for csv_file in (
+        SMALL,
+        MEDIUM,
+    ):
         check_usage(csv_file)
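
SKIP_NON_MOCKED is imported from tests.db_test but not defined in this diff. Since pytest.mark.skipif accepts condition and reason keyword arguments (and the test unpacks it with **), it is presumably a dict along these lines — an assumption, not the repository's actual definition:

import os

# Hypothetical reconstruction of tests/db_test.py's SKIP_NON_MOCKED;
# skipif(**SKIP_NON_MOCKED) needs `condition` and `reason` keywords.
SKIP_NON_MOCKED = {
    "condition": os.environ.get("REAL_DATA") is None,
    "reason": "requires real (non-mocked) eBird data",
}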
