diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py
index b20c1a0c..eac5b0e2 100644
--- a/aeon/dj_pipeline/acquisition.py
+++ b/aeon/dj_pipeline/acquisition.py
@@ -563,6 +563,47 @@ def make(self, key):
         )
 
 
+@schema
+class EnvironmentActiveConfiguration(dj.Imported):
+    definition = """  # Environment Active Configuration
+    -> Chunk
+    """
+
+    class Name(dj.Part):
+        definition = """
+        -> master
+        time: datetime(6)  # time when the configuration is applied to the environment
+        ---
+        name: varchar(32)  # name of the environment configuration
+        value: longblob    # dictionary of the configuration
+        """
+
+    def make(self, key):
+        chunk_start, chunk_end = (Chunk & key).fetch1("chunk_start", "chunk_end")
+        data_dirs = Experiment.get_data_directories(key)
+        devices_schema = getattr(
+            aeon_schemas,
+            (Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1(
+                "devices_schema_name"
+            ),
+        )
+        device = devices_schema.Environment
+        stream_reader = device.EnvironmentActiveConfiguration  # expecting columns: time, name, value
+        stream_data = io_api.load(
+            root=data_dirs,
+            reader=stream_reader,
+            start=pd.Timestamp(chunk_start),
+            end=pd.Timestamp(chunk_end),
+        )
+
+        stream_data.reset_index(inplace=True)
+        for k, v in key.items():
+            stream_data[k] = v
+
+        self.insert1(key)
+        self.Name.insert(stream_data)
+
+
 # ---- HELPERS ----
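A minimal sketch of how the new table would typically be filled and queried, assuming a configured DataJoint connection and existing upstream `Chunk` entries; the experiment name below is purely illustrative, not taken from this diff.

```python
# Sketch only: assumes a configured DataJoint connection and populated Chunk
# entries. "social0.3-aeon3" is a hypothetical experiment name for illustration.
from aeon.dj_pipeline import acquisition

key = {"experiment_name": "social0.3-aeon3"}  # hypothetical restriction
acquisition.EnvironmentActiveConfiguration.populate(key, display_progress=True)

# Each chunk's configuration events land in the Name part table.
times, names = (acquisition.EnvironmentActiveConfiguration.Name & key).fetch(
    "time", "name"
)
```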
diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py
index e05c70d8..30e8b258 100644
--- a/aeon/dj_pipeline/analysis/block_analysis.py
+++ b/aeon/dj_pipeline/analysis/block_analysis.py
@@ -162,7 +162,7 @@ def make(self, key):
         chunk_keys = (acquisition.Chunk & key & chunk_restriction).fetch("KEY")
         streams_tables = (
             streams.UndergroundFeederDepletionState,
-            streams.UndergroundFeederBeamBreak,
+            streams.UndergroundFeederDeliverPellet,
             streams.UndergroundFeederEncoder,
             tracking.SLEAPTracking,
         )
@@ -190,35 +190,11 @@ def make(self, key):
 
         for patch_key, patch_name in zip(patch_keys, patch_names):
             # pellet delivery and patch threshold data
-            beam_break_df = fetch_stream(
-                streams.UndergroundFeederBeamBreak & patch_key & chunk_restriction
-            )[block_start:block_end]
             depletion_state_df = fetch_stream(
                 streams.UndergroundFeederDepletionState & patch_key & chunk_restriction
             )[block_start:block_end]
-            # remove NaNs from threshold column
-            depletion_state_df = depletion_state_df.dropna(subset=["threshold"])
-            # identify & remove invalid indices where the time difference is less than 1 second
-            invalid_indices = np.where(depletion_state_df.index.to_series().diff().dt.total_seconds() < 1)[0]
-            depletion_state_df = depletion_state_df.drop(depletion_state_df.index[invalid_indices])
-
-            # find pellet times associated with each threshold update
-            # for each threshold, find the time of the next threshold update,
-            # find the closest beam break after this update time,
-            # and use this beam break time as the delivery time for the initial threshold
-            pellet_ts_threshold_df = depletion_state_df.copy()
-            pellet_ts_threshold_df["pellet_timestamp"] = pd.NaT
-            for threshold_idx in range(len(pellet_ts_threshold_df) - 1):
-                if np.isnan(pellet_ts_threshold_df.threshold.iloc[threshold_idx]):
-                    continue
-                next_threshold_time = pellet_ts_threshold_df.index[threshold_idx + 1]
-                post_thresh_pellet_ts = beam_break_df.index[beam_break_df.index > next_threshold_time]
-                if post_thresh_pellet_ts.empty:
-                    break
-                next_beam_break = post_thresh_pellet_ts[np.searchsorted(post_thresh_pellet_ts, next_threshold_time)]
-                pellet_ts_threshold_df.pellet_timestamp.iloc[threshold_idx] = next_beam_break
-            # remove NaNs from pellet_timestamp column (last row)
-            pellet_ts_threshold_df = pellet_ts_threshold_df.dropna(subset=["pellet_timestamp"])
+
+            pellet_ts_threshold_df = get_threshold_associated_pellets(patch_key, block_start, block_end)
 
             # wheel encoder data
             encoder_df = fetch_stream(streams.UndergroundFeederEncoder & patch_key & chunk_restriction)[
@@ -798,3 +774,53 @@ class AnalysisNote(dj.Manual):
     note_type='': varchar(64)
     note: varchar(3000)
     """
+
+
+# ---- Helper Functions ----
+
+
+def get_threshold_associated_pellets(patch_key, start, end):
+    """
+    Retrieve the pellet delivery timestamps associated with each patch threshold update
+    within the specified start-end time.
+    1. Get all patch state update timestamps: call these events "A".
+    2. Remove all "A" events near manual pellet delivery events, so that manual
+       deliveries are excluded from downstream analysis.
+    3. For each remaining "A" event, find the nearest delivery event within 1 s; if that
+       delivery has repeat events within 0.5 s, take the last repeat as the pellet
+       delivery timestamp, and discard "A" events without such a corresponding delivery.
+    4. For these "clean" "A" events, step back in time to the second preceding pellet
+       threshold value: this is the threshold value for this pellet delivery.
+    """
+    chunk_restriction = acquisition.create_chunk_restriction(
+        patch_key["experiment_name"], start, end
+    )
+    # pellet delivery and patch threshold data
+    delivered_pellet_df = fetch_stream(
+        streams.UndergroundFeederDeliverPellet & patch_key & chunk_restriction
+    )[start:end]
+    depletion_state_df = fetch_stream(
+        streams.UndergroundFeederDepletionState & patch_key & chunk_restriction
+    )[start:end]
+    # remove NaNs from the threshold column
+    depletion_state_df = depletion_state_df.dropna(subset=["threshold"])
+    # identify & remove invalid indices where the time difference is less than 1 second
+    invalid_indices = np.where(depletion_state_df.index.to_series().diff().dt.total_seconds() < 1)[0]
+    depletion_state_df = depletion_state_df.drop(depletion_state_df.index[invalid_indices])
+
+    # find pellet times that approximately coincide with each threshold update,
+    # i.e. the nearest pellet delivery within 100 ms before or after the threshold update
+    delivered_pellet_ts = delivered_pellet_df.index
+    pellet_ts_threshold_df = depletion_state_df.copy()
+    pellet_ts_threshold_df["pellet_timestamp"] = pd.NaT
+    for threshold_idx in range(len(pellet_ts_threshold_df)):
+        threshold_time = pellet_ts_threshold_df.index[threshold_idx]
+        within_range_pellet_ts = np.logical_and(
+            delivered_pellet_ts >= threshold_time - pd.Timedelta(milliseconds=100),
+            delivered_pellet_ts <= threshold_time + pd.Timedelta(milliseconds=100),
+        )
+        if not within_range_pellet_ts.any():
+            continue
+        pellet_time = delivered_pellet_ts[within_range_pellet_ts][-1]
+        pellet_ts_threshold_df.pellet_timestamp.iloc[threshold_idx] = pellet_time
+
+    # remove rows of threshold updates without a corresponding pellet time, i.e. where pellet_timestamp is NaN
+    pellet_ts_threshold_df = pellet_ts_threshold_df.dropna(subset=["pellet_timestamp"])
+    # shift pellet_timestamp back by 1 to pair each pellet with the previous threshold update
+    pellet_ts_threshold_df.pellet_timestamp = pellet_ts_threshold_df.pellet_timestamp.shift(-1)
+    # remove NaNs from the pellet_timestamp column (last row)
+    pellet_ts_threshold_df = pellet_ts_threshold_df.dropna(subset=["pellet_timestamp"])
+
+    return pellet_ts_threshold_df
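To make the pairing rule in `get_threshold_associated_pellets` concrete, here is a self-contained toy run of the same ±100 ms window match and `shift(-1)` pairing, using only pandas/numpy on synthetic timestamps; the times and threshold values are made up for illustration and no pipeline tables are involved.

```python
# Toy reproduction of the pairing logic in get_threshold_associated_pellets.
import pandas as pd

thresholds = pd.DataFrame(
    {"threshold": [75.0, 80.0, 90.0]},
    index=pd.to_datetime(
        ["2024-01-01 00:00:00", "2024-01-01 00:01:00", "2024-01-01 00:02:00"]
    ),
)
# pellet deliveries ~50 ms after the first two threshold updates; none after the third
pellets = pd.to_datetime(["2024-01-01 00:00:00.050", "2024-01-01 00:01:00.050"])

matched = thresholds.copy()
matched["pellet_timestamp"] = pd.NaT
for t in matched.index:
    in_window = (pellets >= t - pd.Timedelta(milliseconds=100)) & (
        pellets <= t + pd.Timedelta(milliseconds=100)
    )
    if in_window.any():
        matched.loc[t, "pellet_timestamp"] = pellets[in_window][-1]

matched = matched.dropna(subset=["pellet_timestamp"])
# pair each delivery with the *previous* threshold update, as the helper does
matched["pellet_timestamp"] = matched["pellet_timestamp"].shift(-1)
matched = matched.dropna(subset=["pellet_timestamp"])
print(matched)  # one row: threshold 75.0 paired with the delivery after the 80.0 update
```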
diff --git a/aeon/io/reader.py b/aeon/io/reader.py
index e5c86d5d..4d455432 100644
--- a/aeon/io/reader.py
+++ b/aeon/io/reader.py
@@ -135,6 +135,26 @@ def read(self, file):
         )
 
 
+class JsonList(Reader):
+    """Extracts data from json list (.jsonl) files,
+    where the key "seconds" stores the Aeon timestamp, in seconds.
+    """
+
+    def __init__(self, pattern, columns=(), root_key="value", extension="jsonl"):
+        super().__init__(pattern, columns, extension)
+        self.columns = columns
+        self.root_key = root_key
+
+    def read(self, file):
+        """Reads data from the specified jsonl file."""
+        with open(file, "r") as f:
+            df = pd.read_json(f, lines=True)
+        df.set_index("seconds", inplace=True)
+        for column in self.columns:
+            df[column] = df[self.root_key].apply(lambda x: x[column])
+        return df
+
+
 class Subject(Csv):
     """Extracts metadata for subjects entering and exiting the environment.
diff --git a/aeon/schema/schemas.py b/aeon/schema/schemas.py
index 2738d522..a843b097 100644
--- a/aeon/schema/schemas.py
+++ b/aeon/schema/schemas.py
@@ -116,7 +116,7 @@
 social03 = DotMap(
     [
         Device("Metadata", stream.Metadata),
-        Device("Environment", social_02.Environment, social_02.SubjectData),
+        Device("Environment", social_02.Environment, social_02.SubjectData, social_03.EnvironmentActiveConfiguration),
         Device("CameraTop", stream.Video, social_03.Pose),
         Device("CameraNorth", stream.Video),
         Device("CameraSouth", stream.Video),
@@ -147,7 +147,7 @@
 social04 = DotMap(
     [
         Device("Metadata", stream.Metadata),
-        Device("Environment", social_02.Environment, social_02.SubjectData),
+        Device("Environment", social_02.Environment, social_02.SubjectData, social_03.EnvironmentActiveConfiguration),
         Device("CameraTop", stream.Video, social_03.Pose),
         Device("CameraNorth", stream.Video),
         Device("CameraSouth", stream.Video),
diff --git a/aeon/schema/social_03.py b/aeon/schema/social_03.py
index 558b39c9..a3bc2cdf 100644
--- a/aeon/schema/social_03.py
+++ b/aeon/schema/social_03.py
@@ -1,3 +1,5 @@
+import json
+import pandas as pd
 import aeon.io.reader as _reader
 from aeon.schema.streams import Stream
 
@@ -6,3 +8,9 @@ class Pose(Stream):
 
     def __init__(self, path):
         super().__init__(_reader.Pose(f"{path}_202_*"))
+
+
+class EnvironmentActiveConfiguration(Stream):
+
+    def __init__(self, path):
+        super().__init__(_reader.JsonList(f"{path}_ActiveConfiguration_*", columns=["name"]))
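A quick way to sanity-check the new `JsonList` reader in isolation is to point `read` at a hand-written .jsonl file. The file name and payloads below are invented for the example and only loosely mimic the Aeon chunk-file naming; all that matters to the reader is the "seconds"/"value" layout.

```python
# Illustrative check of JsonList.read on a hand-written .jsonl file.
import json
import tempfile
from pathlib import Path

from aeon.io.reader import JsonList

tmp_dir = Path(tempfile.mkdtemp())
records = [
    {"seconds": 3526563.0, "value": {"name": "SleepConfig"}},
    {"seconds": 3526570.0, "value": {"name": "LightsOnConfig"}},
]
file = tmp_dir / "Environment_ActiveConfiguration_2024-01-01T00-00-00.jsonl"
file.write_text("\n".join(json.dumps(r) for r in records))

reader = JsonList("Environment_ActiveConfiguration_*", columns=["name"])
df = reader.read(file)
print(df)  # indexed by "seconds", with a "name" column extracted from the "value" dict
```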