diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 342528f6..6f80ec6e 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -1,17 +1,14 @@ -import datetime -import datajoint as dj -import pandas as pd import json + +import datajoint as dj import numpy as np +import pandas as pd +import plotly.express as px +import plotly.graph_objs as go from aeon.analysis import utils as analysis_utils - -from aeon.dj_pipeline import get_schema_name, fetch_stream -from aeon.dj_pipeline import acquisition, tracking, streams -from aeon.dj_pipeline.analysis.visit import ( - get_maintenance_periods, - filter_out_maintenance_periods, -) +from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, tracking +from aeon.dj_pipeline.analysis.visit import filter_out_maintenance_periods, get_maintenance_periods schema = dj.schema(get_schema_name("analysis")) @@ -47,7 +44,7 @@ class Patch(dj.Part): wheel_timestamps: longblob patch_threshold: longblob patch_threshold_timestamps: longblob - patch_rate: float + patch_rate: float """ class Subject(dj.Part): @@ -65,13 +62,11 @@ class Subject(dj.Part): """ def make(self, key): - """ - Restrict, fetch and aggregate data from different streams to produce intermediate data products - at a per-block level (for different patches and different subjects) - 1. Query data for all chunks within the block - 2. Fetch streams, filter by maintenance period - 3. Fetch subject position data (SLEAP) - 4. Aggregate and insert into the table + """Restrict, fetch and aggregate data from different streams to produce intermediate data products at a per-block level (for different patches and different subjects). + 1. Query data for all chunks within the block. + 2. Fetch streams, filter by maintenance period. + 3. Fetch subject position data (SLEAP). + 4. Aggregate and insert into the table. """ block_start, block_end = (Block & key).fetch1("block_start", "block_end") @@ -165,7 +160,7 @@ def make(self, key): pos_df = filter_out_maintenance_periods(pos_df, maintenance_period, block_end) position_diff = np.sqrt( - (np.square(np.diff(pos_df.x.astype(float))) + np.square(np.diff(pos_df.y.astype(float)))) + np.square(np.diff(pos_df.x.astype(float))) + np.square(np.diff(pos_df.y.astype(float))) ) cumsum_distance_travelled = np.concatenate([[0], np.cumsum(position_diff)]) @@ -207,7 +202,7 @@ class Patch(dj.Part): pellet_timestamps: longblob wheel_distance_travelled: longblob # wheel's cumulative distance travelled wheel_timestamps: longblob - cumulative_sum_preference: longblob + cumulative_sum_preference: longblob windowed_sum_preference: longblob """ @@ -223,11 +218,11 @@ class BlockPlots(dj.Computed): subject_positions_plot: longblob subject_weights_plot: longblob patch_distance_travelled_plot: longblob + patch_rate_plot: longblob + cumulative_pellet_plot: longblob """ def make(self, key): - import plotly.graph_objs as go - # For position data , set confidence threshold to return position values and downsample by 5x conf_thresh = 0.9 downsampling_factor = 5 @@ -235,7 +230,12 @@ def make(self, key): # Make plotly plots weight_fig = go.Figure() pos_fig = go.Figure() - for subject_data in (BlockAnalysis.Subject & key).fetch(as_dict=True): + wheel_fig = go.Figure() + patch_rate_fig = go.Figure() + cumulative_pellet_fig = go.Figure() + + for subject_data in BlockAnalysis.Subject & key: + # Subject weight over time weight_fig.add_trace( go.Scatter( x=subject_data["weight_timestamps"], @@ -244,6 +244,7 @@ def make(self, key): name=subject_data["subject_name"], ) ) + # Subject position over time mask = subject_data["position_likelihood"] > conf_thresh pos_fig.add_trace( go.Scatter3d( @@ -255,24 +256,59 @@ def make(self, key): ) ) - wheel_fig = go.Figure() - for patch_data in (BlockAnalysis.Patch & key).fetch(as_dict=True): + # Cumulative wheel distance travelled over time + for patch_data in BlockAnalysis.Patch & key: wheel_fig.add_trace( go.Scatter( x=patch_data["wheel_timestamps"][::2], - y=patch_data["cumulative_distance_travelled"][::2], + y=patch_data["wheel_cumsum_distance_travelled"][::2], mode="lines", name=patch_data["patch_name"], ) ) - # insert figures as json-formatted plotly plots + # Create a bar chart for patch rates + patch_df = (BlockAnalysis.Patch & key).fetch(format="frame").reset_index() + patch_rate_fig = px.bar( + patch_df, + x="patch_name", + y="patch_rate", + color="patch_name", + title="Patch Stats: Patch Rate for Each Patch", + labels={"patch_name": "Patch Name", "patch_rate": "Patch Rate"}, + text="patch_rate", + ) + patch_rate_fig.update_layout(bargap=0.2, width=600, height=400, template="simple_white") + + # Cumulative pellets per patch over time + for _, row in patch_df.iterrows(): + timestamps = row["pellet_timestamps"] + total_pellet_count = list(range(1, row["pellet_count"] + 1)) + + cumulative_pellet_fig.add_trace( + go.Scatter(x=timestamps, y=total_pellet_count, mode="lines+markers", name=row["patch_name"]) + ) + + cumulative_pellet_fig.update_layout( + title="Cumulative Pellet Count Over Time", + xaxis_title="Time", + yaxis_title="Cumulative Pellet Count", + width=800, + height=500, + legend_title="Patch Name", + showlegend=True, + template="simple_white", + ) + + # Insert figures as json-formatted plotly plots self.insert1( { **key, "subject_positions_plot": json.loads(pos_fig.to_json()), "subject_weights_plot": json.loads(weight_fig.to_json()), "patch_distance_travelled_plot": json.loads(wheel_fig.to_json()), + "patch_rate_plot": json.loads(patch_rate_fig.to_json()), + "cumulative_pellet_plot": json.loads(cumulative_pellet_fig.to_json()), } ) @@ -284,9 +320,7 @@ class BlockDetection(dj.Computed): """ def make(self, key): - """ - On a per-chunk basis, check for the presence of new block, insert into Block table - """ + """On a per-chunk basis, check for the presence of new block, insert into Block table.""" # find the 0s # that would mark the start of a new block # if the 0 is the first index - look back at the previous chunk diff --git a/aeon/util.py b/aeon/util.py index aed34803..a9dc88bc 100644 --- a/aeon/util.py +++ b/aeon/util.py @@ -1,4 +1,5 @@ """Utility functions.""" +from __future__ import annotations from typing import Any