From d83751596b6f774b50d601266d3eb0d5d590467b Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 16 Jul 2024 10:12:08 -0500
Subject: [PATCH 1/3] fix(Pose): remove `config_file` from class property

---
 aeon/io/reader.py | 36 +++++++++++++++++++-----------------
 1 file changed, 19 insertions(+), 17 deletions(-)

diff --git a/aeon/io/reader.py b/aeon/io/reader.py
index 25db82a8..e5c86d5d 100644
--- a/aeon/io/reader.py
+++ b/aeon/io/reader.py
@@ -286,7 +286,6 @@ def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/process
         # `pattern` for this reader should typically be '_*'
         super().__init__(pattern, columns=None)
         self._model_root = model_root
-        self.config_file = None  # requires reading the data file to be set

     def read(self, file: Path) -> pd.DataFrame:
         """Reads data from the Harp-binarized tracking file."""
@@ -295,9 +294,9 @@ def read(self, file: Path) -> pd.DataFrame:
         config_file_dir = Path(self._model_root) / model_dir
         if not config_file_dir.exists():
             raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
-        self.config_file = self.get_config_file(config_file_dir)
-        identities = self.get_class_names()
-        parts = self.get_bodyparts()
+        config_file = self.get_config_file(config_file_dir)
+        identities = self.get_class_names(config_file)
+        parts = self.get_bodyparts(config_file)

         # Using bodyparts, assign column names to Harp register values, and read data in default format.
         try:  # Bonsai.Sleap0.2
@@ -327,7 +326,7 @@ def read(self, file: Path) -> pd.DataFrame:
             parts = unique_parts

         # Set new columns, and reformat `data`.
-        data = self.class_int2str(data)
+        data = self.class_int2str(data, config_file)
         n_parts = len(parts)
         part_data_list = [pd.DataFrame()] * n_parts
         new_columns = ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]
@@ -350,45 +349,48 @@ def read(self, file: Path) -> pd.DataFrame:
         new_data = pd.concat(part_data_list)
         return new_data.sort_index()

-    def get_class_names(self) -> list[str]:
+    @staticmethod
+    def get_class_names(config_file: Path) -> list[str]:
         """Returns a list of classes from a model's config file."""
         classes = None
-        with open(self.config_file) as f:
+        with open(config_file) as f:
             config = json.load(f)
-        if self.config_file.stem == "confmap_config":  # SLEAP
+        if config_file.stem == "confmap_config":  # SLEAP
             try:
                 heads = config["model"]["heads"]
                 classes = util.find_nested_key(heads, "class_vectors")["classes"]
             except KeyError as err:
                 if not classes:
-                    raise KeyError(f"Cannot find class_vectors in {self.config_file}.") from err
+                    raise KeyError(f"Cannot find class_vectors in {config_file}.") from err
         return classes

-    def get_bodyparts(self) -> list[str]:
+    @staticmethod
+    def get_bodyparts(config_file: Path) -> list[str]:
         """Returns a list of bodyparts from a model's config file."""
         parts = []
-        with open(self.config_file) as f:
+        with open(config_file) as f:
             config = json.load(f)
-        if self.config_file.stem == "confmap_config":  # SLEAP
+        if config_file.stem == "confmap_config":  # SLEAP
             try:
                 heads = config["model"]["heads"]
                 parts = [util.find_nested_key(heads, "anchor_part")]
                 parts += util.find_nested_key(heads, "part_names")
             except KeyError as err:
                 if not parts:
-                    raise KeyError(f"Cannot find bodyparts in {self.config_file}.") from err
+                    raise KeyError(f"Cannot find bodyparts in {config_file}.") from err
         return parts

-    def class_int2str(self, data: pd.DataFrame) -> pd.DataFrame:
+    @staticmethod
+    def class_int2str(data: pd.DataFrame, config_file: Path) -> pd.DataFrame:
         """Converts a class integer in a tracking data dataframe to its associated string (subject id)."""
-        if self.config_file.stem == "confmap_config":  # SLEAP
-            with open(self.config_file) as f:
+        if config_file.stem == "confmap_config":  # SLEAP
+            with open(config_file) as f:
                 config = json.load(f)
             try:
                 heads = config["model"]["heads"]
                 classes = util.find_nested_key(heads, "classes")
             except KeyError as err:
-                raise KeyError(f"Cannot find classes in {self.config_file}.") from err
+                raise KeyError(f"Cannot find classes in {config_file}.") from err
             for i, subj in enumerate(classes):
                 data.loc[data["identity"] == i, "identity"] = subj
         return data
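With `config_file` no longer stored on the reader, the config-file helpers become plain static methods. A minimal usage sketch (illustrative only, not part of the patch; the model path is hypothetical, and the reader class is assumed to be the `Pose` class in `aeon/io/reader.py`):

    from pathlib import Path

    from aeon.io.reader import Pose

    # Hypothetical SLEAP model directory; the SLEAP branch of the reader keys off
    # a config file whose stem is "confmap_config".
    config_file = Path("/path/to/sleap_model/confmap_config.json")

    identities = Pose.get_class_names(config_file)  # class / subject names from the model config
    parts = Pose.get_bodyparts(config_file)         # anchor part first, then the remaining part names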
From 96e9e469a77a02306d153f9a89c976e313b17894 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 16 Jul 2024 10:22:36 -0500
Subject: [PATCH 2/3] feat: update SLEAP ingestion - no longer dependent on `config_file`

---
 aeon/dj_pipeline/tracking.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py
index 22ddf978..9ef747e2 100644
--- a/aeon/dj_pipeline/tracking.py
+++ b/aeon/dj_pipeline/tracking.py
@@ -174,10 +174,8 @@ def make(self, key):
         if not len(pose_data):
             raise ValueError(f"No SLEAP data found for {key['experiment_name']} - {device_name}")

-        # get bodyparts and classes
-        bodyparts = stream_reader.get_bodyparts()
-        anchor_part = bodyparts[0]  # anchor_part is always the first one
-        class_names = stream_reader.get_class_names()
+        # get identity names
+        class_names = np.unique(pose_data.identity)
         identity_mapping = {n: i for i, n in enumerate(class_names)}

         # ingest parts and classes
@@ -186,6 +184,10 @@ def make(self, key):
             identity_position = pose_data[pose_data["identity"] == identity]
             if identity_position.empty:
                 continue
+
+            # get anchor part - always the first one of all the body parts
+            anchor_part = np.unique(identity_position.part)[0]
+
             for part in set(identity_position.part.values):
                 part_position = identity_position[identity_position.part == part]
                 part_entries.append(
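A toy sketch of the reworked ingestion (not part of the patch; values are made up, column names follow the reader output): identity names and the per-identity anchor part are now derived from the pose data itself rather than from the model's config file.

    import numpy as np
    import pandas as pd

    # Stand-in for the long-format dataframe returned by the Pose reader.
    pose_data = pd.DataFrame(
        {
            "identity": ["subj_a", "subj_a", "subj_b", "subj_b"],
            "part": ["centroid", "head", "centroid", "head"],
            "x": [10.0, 12.0, 30.0, 31.0],
            "y": [5.0, 6.0, 7.0, 8.0],
        }
    )

    # Identity names come from the data instead of stream_reader.get_class_names().
    class_names = np.unique(pose_data.identity)
    identity_mapping = {n: i for i, n in enumerate(class_names)}

    # Per-identity anchor part: np.unique returns values in sorted order,
    # so the first element is the alphabetically first part name.
    identity_position = pose_data[pose_data["identity"] == class_names[0]]
    anchor_part = np.unique(identity_position.part)[0]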
From 3411fe86dfe9c7528951aee440204881496bf557 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Thu, 18 Jul 2024 12:41:14 -0500
Subject: [PATCH 3/3] fix: update logic to associate true pellet times with each threshold update time

---
 aeon/dj_pipeline/__init__.py                |  1 +
 aeon/dj_pipeline/analysis/block_analysis.py | 48 +++++++++++++++------
 2 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/aeon/dj_pipeline/__init__.py b/aeon/dj_pipeline/__init__.py
index 72e57718..d8f201c0 100644
--- a/aeon/dj_pipeline/__init__.py
+++ b/aeon/dj_pipeline/__init__.py
@@ -45,6 +45,7 @@ def fetch_stream(query, drop_pk=True):
     df.rename(columns={"timestamps": "time"}, inplace=True)
     df.set_index("time", inplace=True)
     df.sort_index(inplace=True)
+    df = df.convert_dtypes(convert_string=False, convert_integer=False, convert_boolean=False, convert_floating=False)
     return df

diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py
index 4d5b0943..27fe4812 100644
--- a/aeon/dj_pipeline/analysis/block_analysis.py
+++ b/aeon/dj_pipeline/analysis/block_analysis.py
@@ -189,18 +189,42 @@ def make(self, key):
         patch_keys, patch_names = patch_query.fetch("KEY", "underground_feeder_name")

         for patch_key, patch_name in zip(patch_keys, patch_names):
-            delivered_pellet_df = fetch_stream(
+            # pellet delivery and patch threshold data
+            beam_break_df = fetch_stream(
                 streams.UndergroundFeederBeamBreak & patch_key & chunk_restriction
             )[block_start:block_end]
             depletion_state_df = fetch_stream(
                 streams.UndergroundFeederDepletionState & patch_key & chunk_restriction
             )[block_start:block_end]
+            # remove NaNs from threshold column
+            depletion_state_df = depletion_state_df.dropna(subset=["threshold"])
+            # identify & remove invalid indices where the time difference is less than 1 second
+            invalid_indices = np.where(depletion_state_df.index.to_series().diff().dt.total_seconds() < 1)[0]
+            depletion_state_df = depletion_state_df.drop(depletion_state_df.index[invalid_indices])
+
+            # find pellet times associated with each threshold update
+            # for each threshold, find the time of the next threshold update,
+            # find the closest beam break after this update time,
+            # and use this beam break time as the delivery time for the initial threshold
+            pellet_ts_threshold_df = depletion_state_df.copy()
+            pellet_ts_threshold_df["pellet_timestamp"] = pd.NaT
+            for threshold_idx in range(len(pellet_ts_threshold_df) - 1):
+                if np.isnan(pellet_ts_threshold_df.threshold.iloc[threshold_idx]):
+                    continue
+                next_threshold_time = pellet_ts_threshold_df.index[threshold_idx + 1]
+                post_thresh_pellet_ts = beam_break_df.index[beam_break_df.index > next_threshold_time]
+                next_beam_break = post_thresh_pellet_ts[np.searchsorted(post_thresh_pellet_ts, next_threshold_time)]
+                pellet_ts_threshold_df.pellet_timestamp.iloc[threshold_idx] = next_beam_break
+            # remove NaNs from pellet_timestamp column (last row)
+            pellet_ts_threshold_df = pellet_ts_threshold_df.dropna(subset=["pellet_timestamp"])
+
+            # wheel encoder data
             encoder_df = fetch_stream(streams.UndergroundFeederEncoder & patch_key & chunk_restriction)[
                 block_start:block_end
             ]
             # filter out maintenance period based on logs
-            pellet_df = filter_out_maintenance_periods(
-                delivered_pellet_df,
+            pellet_ts_threshold_df = filter_out_maintenance_periods(
+                pellet_ts_threshold_df,
                 maintenance_period,
                 block_end,
                 dropna=True,
@@ -229,7 +253,6 @@
             patch_rate = depletion_state_df.rate.iloc[0]
             patch_offset = depletion_state_df.offset.iloc[0]
-
             # handles patch rate value being INF
             patch_rate = 999999999 if np.isinf(patch_rate) else patch_rate
@@ -237,14 +260,14 @@
                 {
                     **key,
                     "patch_name": patch_name,
-                    "pellet_count": len(pellet_df),
-                    "pellet_timestamps": pellet_df.index.values,
+                    "pellet_count": len(pellet_ts_threshold_df),
+                    "pellet_timestamps": pellet_ts_threshold_df.pellet_timestamp.values,
                     "wheel_cumsum_distance_travelled": encoder_df.distance_travelled.values[
                         ::wheel_downsampling_factor
                     ],
                     "wheel_timestamps": encoder_df.index.values[::wheel_downsampling_factor],
-                    "patch_threshold": depletion_state_df.threshold.values,
-                    "patch_threshold_timestamps": depletion_state_df.index.values,
+                    "patch_threshold": pellet_ts_threshold_df.threshold.values,
+                    "patch_threshold_timestamps": pellet_ts_threshold_df.index.values,
                     "patch_rate": patch_rate,
                     "patch_offset": patch_offset,
                 }
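A self-contained sketch of the matching rule introduced in the first hunk above (not part of the patch; timestamps are made up): the pellet time credited to a threshold is the first beam break recorded after the next threshold update is posted, and the final threshold, which has no later update, is dropped. The remaining hunks of this patch continue below.

    import pandas as pd

    thresholds = pd.DataFrame(
        {"threshold": [75.0, 100.0, 150.0]},
        index=pd.to_datetime(["2024-01-01 00:00:00", "2024-01-01 00:01:00", "2024-01-01 00:02:00"]),
    )
    beam_breaks = pd.to_datetime(["2024-01-01 00:01:02", "2024-01-01 00:02:03"])

    thresholds["pellet_timestamp"] = pd.NaT
    for i in range(len(thresholds) - 1):
        next_update = thresholds.index[i + 1]                  # when the next threshold is posted
        later_breaks = beam_breaks[beam_breaks > next_update]  # beam breaks after that update
        if len(later_breaks):
            thresholds.iloc[i, thresholds.columns.get_loc("pellet_timestamp")] = later_breaks[0]

    # The last threshold never receives a pellet timestamp and is removed, as in the patch.
    thresholds = thresholds.dropna(subset=["pellet_timestamp"])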
@@ -267,7 +290,7 @@
         subject_names = []
         for subject_name in set(subject_visits_df.id):
             _df = subject_visits_df[subject_visits_df.id == subject_name]
-            if _df.type[-1] != "Exit":
+            if _df.type.iloc[-1] != "Exit":
                 subject_names.append(subject_name)

         for subject_name in subject_names:
@@ -454,7 +477,7 @@
                 "dist_to_patch"
             ].values

-            # Get closest subject to patch at each pel del timestep
+            # Get closest subject to patch at each pellet timestep
            closest_subjects_pellet_ts = dist_to_patch_pel_ts_id_df.idxmin(axis=1)
             # Get closest subject to patch at each wheel timestep
             cum_wheel_dist_subj_df = pd.DataFrame(
@@ -481,9 +504,10 @@
                     all_subj_patch_pref_dict[patch["patch_name"]][subject_name][
                         "cum_time"
                     ] = subject_in_patch_cum_time
-                    subj_pellets = closest_subjects_pellet_ts[closest_subjects_pellet_ts == subject_name]
-                    subj_patch_thresh = patch["patch_threshold"][np.searchsorted(patch["patch_threshold_timestamps"], subj_pellets.index.values) - 1]
+                    closest_subj_mask = closest_subjects_pellet_ts == subject_name
+                    subj_pellets = closest_subjects_pellet_ts[closest_subj_mask]
+                    subj_patch_thresh = patch["patch_threshold"][closest_subj_mask]

                     self.Patch.insert1(
                         key
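A small sketch of why a plain boolean mask now suffices in the final hunk (not part of the patch; values are made up): `pellet_timestamps` and `patch_threshold` are built from the same `pellet_ts_threshold_df`, so they are aligned one-to-one per pellet event, and the mask that selects a subject's pellets also selects that subject's thresholds.

    import numpy as np
    import pandas as pd

    patch = {
        "pellet_timestamps": pd.to_datetime(
            ["2024-01-01 00:01:02", "2024-01-01 00:02:03", "2024-01-01 00:03:04"]
        ),
        "patch_threshold": np.array([75.0, 100.0, 150.0]),
    }
    # Closest subject to the patch at each pellet event.
    closest_subjects_pellet_ts = pd.Series(
        ["subj_a", "subj_b", "subj_a"], index=patch["pellet_timestamps"]
    )

    closest_subj_mask = closest_subjects_pellet_ts == "subj_a"
    subj_pellets = closest_subjects_pellet_ts[closest_subj_mask]     # the two subj_a pellet events
    subj_patch_thresh = patch["patch_threshold"][closest_subj_mask]  # array([ 75., 150.])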