From bb5a676bc00ab36cab1a0712f436f9b6a48f6b86 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 30 Jan 2024 18:03:46 -0600 Subject: [PATCH 1/3] fix: :bug: fix type annotation error with __future__ --- aeon/util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aeon/util.py b/aeon/util.py index aed34803..a9dc88bc 100644 --- a/aeon/util.py +++ b/aeon/util.py @@ -1,4 +1,5 @@ """Utility functions.""" +from __future__ import annotations from typing import Any From 5766c0bcf9bdb6f7f40d420bcb4d5878a2882839 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 31 Jan 2024 15:12:56 +0000 Subject: [PATCH 2/3] fix: :bug: fix naming error & streamline BlockPlots make function --- aeon/dj_pipeline/analysis/block_analysis.py | 32 +++++++++++---------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 342528f6..ee0e7020 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -1,17 +1,15 @@ import datetime -import datajoint as dj -import pandas as pd import json + +import datajoint as dj import numpy as np +import pandas as pd from aeon.analysis import utils as analysis_utils - -from aeon.dj_pipeline import get_schema_name, fetch_stream -from aeon.dj_pipeline import acquisition, tracking, streams -from aeon.dj_pipeline.analysis.visit import ( - get_maintenance_periods, - filter_out_maintenance_periods, -) +from aeon.dj_pipeline import (acquisition, fetch_stream, get_schema_name, + streams, tracking) +from aeon.dj_pipeline.analysis.visit import (filter_out_maintenance_periods, + get_maintenance_periods) schema = dj.schema(get_schema_name("analysis")) @@ -47,7 +45,7 @@ class Patch(dj.Part): wheel_timestamps: longblob patch_threshold: longblob patch_threshold_timestamps: longblob - patch_rate: float + patch_rate: float """ class Subject(dj.Part): @@ -207,7 +205,7 @@ class Patch(dj.Part): pellet_timestamps: longblob wheel_distance_travelled: longblob # wheel's cumulative distance travelled wheel_timestamps: longblob - cumulative_sum_preference: longblob + cumulative_sum_preference: longblob windowed_sum_preference: longblob """ @@ -235,7 +233,10 @@ def make(self, key): # Make plotly plots weight_fig = go.Figure() pos_fig = go.Figure() - for subject_data in (BlockAnalysis.Subject & key).fetch(as_dict=True): + wheel_fig = go.Figure() + + for subject_data in (BlockAnalysis.Subject & key): + # Subject weight over time weight_fig.add_trace( go.Scatter( x=subject_data["weight_timestamps"], @@ -244,6 +245,7 @@ def make(self, key): name=subject_data["subject_name"], ) ) + # Subject position over time mask = subject_data["position_likelihood"] > conf_thresh pos_fig.add_trace( go.Scatter3d( @@ -255,12 +257,12 @@ def make(self, key): ) ) - wheel_fig = go.Figure() - for patch_data in (BlockAnalysis.Patch & key).fetch(as_dict=True): + # Cumulative wheel distance travelled over time + for patch_data in (BlockAnalysis.Patch & key): wheel_fig.add_trace( go.Scatter( x=patch_data["wheel_timestamps"][::2], - y=patch_data["cumulative_distance_travelled"][::2], + y=patch_data["wheel_cumsum_distance_travelled"][::2], mode="lines", name=patch_data["patch_name"], ) From 1271fa49d749a752baa1f175fcfa26a4dd84bdfd Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 31 Jan 2024 16:24:50 +0000 Subject: [PATCH 3/3] feat: :sparkles: add patch_rate_plot and cumulative_pellet_plot to BlockPlots --- aeon/dj_pipeline/analysis/block_analysis.py | 74 +++++++++++++++------ 1 file changed, 53 insertions(+), 21 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index ee0e7020..6f80ec6e 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -1,15 +1,14 @@ -import datetime import json import datajoint as dj import numpy as np import pandas as pd +import plotly.express as px +import plotly.graph_objs as go from aeon.analysis import utils as analysis_utils -from aeon.dj_pipeline import (acquisition, fetch_stream, get_schema_name, - streams, tracking) -from aeon.dj_pipeline.analysis.visit import (filter_out_maintenance_periods, - get_maintenance_periods) +from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, tracking +from aeon.dj_pipeline.analysis.visit import filter_out_maintenance_periods, get_maintenance_periods schema = dj.schema(get_schema_name("analysis")) @@ -63,13 +62,11 @@ class Subject(dj.Part): """ def make(self, key): - """ - Restrict, fetch and aggregate data from different streams to produce intermediate data products - at a per-block level (for different patches and different subjects) - 1. Query data for all chunks within the block - 2. Fetch streams, filter by maintenance period - 3. Fetch subject position data (SLEAP) - 4. Aggregate and insert into the table + """Restrict, fetch and aggregate data from different streams to produce intermediate data products at a per-block level (for different patches and different subjects). + 1. Query data for all chunks within the block. + 2. Fetch streams, filter by maintenance period. + 3. Fetch subject position data (SLEAP). + 4. Aggregate and insert into the table. """ block_start, block_end = (Block & key).fetch1("block_start", "block_end") @@ -163,7 +160,7 @@ def make(self, key): pos_df = filter_out_maintenance_periods(pos_df, maintenance_period, block_end) position_diff = np.sqrt( - (np.square(np.diff(pos_df.x.astype(float))) + np.square(np.diff(pos_df.y.astype(float)))) + np.square(np.diff(pos_df.x.astype(float))) + np.square(np.diff(pos_df.y.astype(float))) ) cumsum_distance_travelled = np.concatenate([[0], np.cumsum(position_diff)]) @@ -221,11 +218,11 @@ class BlockPlots(dj.Computed): subject_positions_plot: longblob subject_weights_plot: longblob patch_distance_travelled_plot: longblob + patch_rate_plot: longblob + cumulative_pellet_plot: longblob """ def make(self, key): - import plotly.graph_objs as go - # For position data , set confidence threshold to return position values and downsample by 5x conf_thresh = 0.9 downsampling_factor = 5 @@ -234,8 +231,10 @@ def make(self, key): weight_fig = go.Figure() pos_fig = go.Figure() wheel_fig = go.Figure() + patch_rate_fig = go.Figure() + cumulative_pellet_fig = go.Figure() - for subject_data in (BlockAnalysis.Subject & key): + for subject_data in BlockAnalysis.Subject & key: # Subject weight over time weight_fig.add_trace( go.Scatter( @@ -258,7 +257,7 @@ def make(self, key): ) # Cumulative wheel distance travelled over time - for patch_data in (BlockAnalysis.Patch & key): + for patch_data in BlockAnalysis.Patch & key: wheel_fig.add_trace( go.Scatter( x=patch_data["wheel_timestamps"][::2], @@ -268,13 +267,48 @@ def make(self, key): ) ) - # insert figures as json-formatted plotly plots + # Create a bar chart for patch rates + patch_df = (BlockAnalysis.Patch & key).fetch(format="frame").reset_index() + patch_rate_fig = px.bar( + patch_df, + x="patch_name", + y="patch_rate", + color="patch_name", + title="Patch Stats: Patch Rate for Each Patch", + labels={"patch_name": "Patch Name", "patch_rate": "Patch Rate"}, + text="patch_rate", + ) + patch_rate_fig.update_layout(bargap=0.2, width=600, height=400, template="simple_white") + + # Cumulative pellets per patch over time + for _, row in patch_df.iterrows(): + timestamps = row["pellet_timestamps"] + total_pellet_count = list(range(1, row["pellet_count"] + 1)) + + cumulative_pellet_fig.add_trace( + go.Scatter(x=timestamps, y=total_pellet_count, mode="lines+markers", name=row["patch_name"]) + ) + + cumulative_pellet_fig.update_layout( + title="Cumulative Pellet Count Over Time", + xaxis_title="Time", + yaxis_title="Cumulative Pellet Count", + width=800, + height=500, + legend_title="Patch Name", + showlegend=True, + template="simple_white", + ) + + # Insert figures as json-formatted plotly plots self.insert1( { **key, "subject_positions_plot": json.loads(pos_fig.to_json()), "subject_weights_plot": json.loads(weight_fig.to_json()), "patch_distance_travelled_plot": json.loads(wheel_fig.to_json()), + "patch_rate_plot": json.loads(patch_rate_fig.to_json()), + "cumulative_pellet_plot": json.loads(cumulative_pellet_fig.to_json()), } ) @@ -286,9 +320,7 @@ class BlockDetection(dj.Computed): """ def make(self, key): - """ - On a per-chunk basis, check for the presence of new block, insert into Block table - """ + """On a per-chunk basis, check for the presence of new block, insert into Block table.""" # find the 0s # that would mark the start of a new block # if the 0 is the first index - look back at the previous chunk