Merge pull request #313 from JaerongA/datajoint_pipeline
Bug fix & add block analysis plots
Thinh Nguyen authored Jan 31, 2024
2 parents c0afa90 + 1271fa4 commit 4ea9ceb
Showing 2 changed files with 65 additions and 30 deletions.
94 changes: 64 additions & 30 deletions aeon/dj_pipeline/analysis/block_analysis.py
@@ -1,17 +1,14 @@
-import datetime
-import datajoint as dj
-import pandas as pd
 import json
 
+import datajoint as dj
+import numpy as np
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objs as go
 
 from aeon.analysis import utils as analysis_utils
-
-from aeon.dj_pipeline import get_schema_name, fetch_stream
-from aeon.dj_pipeline import acquisition, tracking, streams
-from aeon.dj_pipeline.analysis.visit import (
-    get_maintenance_periods,
-    filter_out_maintenance_periods,
-)
+from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, tracking
+from aeon.dj_pipeline.analysis.visit import filter_out_maintenance_periods, get_maintenance_periods
 
 schema = dj.schema(get_schema_name("analysis"))
 
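For orientation, the `schema` object above is what the table classes in this diff are declared against. A minimal sketch of the declaration pattern, with a hypothetical `Block` definition — its `block_start`/`block_end` attributes are inferred from the `fetch1("block_start", "block_end")` call later in this diff, and the parent table and types are assumptions:

    @schema
    class Block(dj.Manual):
        definition = """
        -> acquisition.Experiment   # assumed parent table
        block_start: datetime(6)
        ---
        block_end=null: datetime(6)
        """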
@@ -47,7 +44,7 @@ class Patch(dj.Part):
         wheel_timestamps: longblob
         patch_threshold: longblob
         patch_threshold_timestamps: longblob
-        patch_rate: float
+        patch_rate: float
         """
 
     class Subject(dj.Part):
@@ -65,13 +62,11 @@ class Subject(dj.Part):
         """
 
     def make(self, key):
-        """
-        Restrict, fetch and aggregate data from different streams to produce intermediate data products
-        at a per-block level (for different patches and different subjects)
-        1. Query data for all chunks within the block
-        2. Fetch streams, filter by maintenance period
-        3. Fetch subject position data (SLEAP)
-        4. Aggregate and insert into the table
+        """Restrict, fetch and aggregate data from different streams to produce intermediate data products at a per-block level (for different patches and different subjects).
+        1. Query data for all chunks within the block.
+        2. Fetch streams, filter by maintenance period.
+        3. Fetch subject position data (SLEAP).
+        4. Aggregate and insert into the table.
         """
         block_start, block_end = (Block & key).fetch1("block_start", "block_end")
 
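The rewritten docstring lists the four stages of `make()`. As a sketch of stage 1, a block-to-chunk restriction in DataJoint typically looks like the following; `chunk_start`/`chunk_end` are assumed attribute names mirroring `block_start`/`block_end`, not confirmed by this diff:

    # Hypothetical sketch of step 1: query all chunks overlapping the block.
    chunk_keys = (
        acquisition.Chunk
        & key
        & f"chunk_start < '{block_end}'"
        & f"chunk_end > '{block_start}'"
    ).fetch("KEY")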
@@ -165,7 +160,7 @@ def make(self, key):
         pos_df = filter_out_maintenance_periods(pos_df, maintenance_period, block_end)
 
         position_diff = np.sqrt(
-            (np.square(np.diff(pos_df.x.astype(float))) + np.square(np.diff(pos_df.y.astype(float))))
+            np.square(np.diff(pos_df.x.astype(float))) + np.square(np.diff(pos_df.y.astype(float)))
         )
         cumsum_distance_travelled = np.concatenate([[0], np.cumsum(position_diff)])
 
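A worked example of the distance computation above (the removed outer parentheses were redundant): per-sample Euclidean step lengths, then a cumulative sum padded with a leading zero so it stays aligned with the position samples:

    import numpy as np
    import pandas as pd

    # Three positions: one 3-4-5 step, then no movement.
    pos_df = pd.DataFrame({"x": [0.0, 3.0, 3.0], "y": [0.0, 4.0, 4.0]})
    position_diff = np.sqrt(
        np.square(np.diff(pos_df.x.astype(float))) + np.square(np.diff(pos_df.y.astype(float)))
    )  # array([5., 0.])
    cumsum_distance_travelled = np.concatenate([[0], np.cumsum(position_diff)])
    # array([0., 5., 5.]) -- same length as pos_df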
@@ -207,7 +202,7 @@ class Patch(dj.Part):
         pellet_timestamps: longblob
         wheel_distance_travelled: longblob # wheel's cumulative distance travelled
         wheel_timestamps: longblob
-        cumulative_sum_preference: longblob
+        cumulative_sum_preference: longblob
         windowed_sum_preference: longblob
         """
 
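The `cumulative_sum_preference` and `windowed_sum_preference` columns are only declared in this hunk, not computed. One plausible reading — offered purely as an assumption, not the pipeline's actual method — is each patch's share of a running total, either cumulative or over a trailing window:

    import numpy as np

    def patch_preference(values, window=None):
        """Illustrative only: `values` is a (n_patches, n_timepoints) array.

        Returns each patch's share of the running total -- cumulative when
        `window` is None, otherwise over a trailing window of samples.
        """
        values = np.asarray(values, dtype=float)
        if window is None:
            totals = np.cumsum(values, axis=1)
        else:
            kernel = np.ones(window)
            totals = np.stack([np.convolve(v, kernel)[: values.shape[1]] for v in values])
        denom = totals.sum(axis=0, keepdims=True)
        # guard against division by zero before any patch has activity
        return np.divide(totals, denom, out=np.zeros_like(totals), where=denom > 0)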
@@ -223,19 +218,24 @@ class BlockPlots(dj.Computed):
     subject_positions_plot: longblob
     subject_weights_plot: longblob
     patch_distance_travelled_plot: longblob
+    patch_rate_plot: longblob
+    cumulative_pellet_plot: longblob
     """
 
     def make(self, key):
-        import plotly.graph_objs as go
-
         # For position data, set confidence threshold to return position values and downsample by 5x
         conf_thresh = 0.9
         downsampling_factor = 5
 
         # Make plotly plots
         weight_fig = go.Figure()
         pos_fig = go.Figure()
-        for subject_data in (BlockAnalysis.Subject & key).fetch(as_dict=True):
+        wheel_fig = go.Figure()
+        patch_rate_fig = go.Figure()
+        cumulative_pellet_fig = go.Figure()
+
+        for subject_data in BlockAnalysis.Subject & key:
+            # Subject weight over time
             weight_fig.add_trace(
                 go.Scatter(
                     x=subject_data["weight_timestamps"],
@@ -244,6 +244,7 @@ def make(self, key):
                     name=subject_data["subject_name"],
                 )
             )
+            # Subject position over time
             mask = subject_data["position_likelihood"] > conf_thresh
             pos_fig.add_trace(
                 go.Scatter3d(
@@ -255,24 +256,59 @@
                 )
             )
 
-        wheel_fig = go.Figure()
-        for patch_data in (BlockAnalysis.Patch & key).fetch(as_dict=True):
+        # Cumulative wheel distance travelled over time
+        for patch_data in BlockAnalysis.Patch & key:
             wheel_fig.add_trace(
                 go.Scatter(
                     x=patch_data["wheel_timestamps"][::2],
-                    y=patch_data["cumulative_distance_travelled"][::2],
+                    y=patch_data["wheel_cumsum_distance_travelled"][::2],
                     mode="lines",
                     name=patch_data["patch_name"],
                 )
             )
 
-        # insert figures as json-formatted plotly plots
+        # Create a bar chart for patch rates
+        patch_df = (BlockAnalysis.Patch & key).fetch(format="frame").reset_index()
+        patch_rate_fig = px.bar(
+            patch_df,
+            x="patch_name",
+            y="patch_rate",
+            color="patch_name",
+            title="Patch Stats: Patch Rate for Each Patch",
+            labels={"patch_name": "Patch Name", "patch_rate": "Patch Rate"},
+            text="patch_rate",
+        )
+        patch_rate_fig.update_layout(bargap=0.2, width=600, height=400, template="simple_white")
+
+        # Cumulative pellets per patch over time
+        for _, row in patch_df.iterrows():
+            timestamps = row["pellet_timestamps"]
+            total_pellet_count = list(range(1, row["pellet_count"] + 1))
+
+            cumulative_pellet_fig.add_trace(
+                go.Scatter(x=timestamps, y=total_pellet_count, mode="lines+markers", name=row["patch_name"])
+            )
+
+        cumulative_pellet_fig.update_layout(
+            title="Cumulative Pellet Count Over Time",
+            xaxis_title="Time",
+            yaxis_title="Cumulative Pellet Count",
+            width=800,
+            height=500,
+            legend_title="Patch Name",
+            showlegend=True,
+            template="simple_white",
+        )
+
+        # Insert figures as json-formatted plotly plots
         self.insert1(
             {
                 **key,
                 "subject_positions_plot": json.loads(pos_fig.to_json()),
                 "subject_weights_plot": json.loads(weight_fig.to_json()),
                 "patch_distance_travelled_plot": json.loads(wheel_fig.to_json()),
+                "patch_rate_plot": json.loads(patch_rate_fig.to_json()),
+                "cumulative_pellet_plot": json.loads(cumulative_pellet_fig.to_json()),
             }
         )
 
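Because the figures are stored as JSON-parsed dicts in longblob columns, reading one back into a live figure is a simple round-trip. A sketch — the restriction value below is hypothetical:

    import json
    import plotly.io as pio

    stored = (BlockPlots & "block_start = '2024-01-31 10:00:00'").fetch1("patch_rate_plot")
    fig = pio.from_json(json.dumps(stored))  # dict -> JSON string -> go.Figure
    fig.show()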
@@ -284,9 +320,7 @@ class BlockDetection(dj.Computed):
     """
 
     def make(self, key):
-        """
-        On a per-chunk basis, check for the presence of new block, insert into Block table
-        """
+        """On a per-chunk basis, check for the presence of a new block and insert it into the Block table."""
         # find the 0s
         # that would mark the start of a new block
         # if the 0 is the first index - look back at the previous chunk
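The trailing comments sketch the detection logic; a minimal illustration of the zero-finding step (the signal and its semantics here are invented for illustration, not taken from the pipeline):

    import numpy as np

    # e.g. a per-chunk patch-threshold stream where 0 marks a block reset
    signal = np.array([0, 120, 150, 0, 90, 0])
    zero_idx = np.flatnonzero(signal == 0)  # array([0, 3, 5])
    # a 0 at index 0 means the block may have begun at the end of the
    # previous chunk, so that chunk's data has to be checked as well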
1 change: 1 addition & 0 deletions aeon/util.py
@@ -1,4 +1,5 @@
 """Utility functions."""
+from __future__ import annotations
 
 from typing import Any
 
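The one-line addition to `aeon/util.py` turns on postponed evaluation of annotations (PEP 563): annotations are kept as strings rather than evaluated at import time, so modern syntax such as `int | None` or `list[int]` can appear in signatures without breaking older Python versions. An illustrative example, not code from this repo:

    from __future__ import annotations

    def first_or_none(items: list[int] | None) -> int | None:
        # Without the __future__ import, evaluating `list[int] | None` at
        # import time raises on Python 3.8/3.9; with it, the annotation is
        # never evaluated at runtime.
        return items[0] if items else None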
