Skip to content

Commit

Permalink
Merge pull request #225 from yujiahu415/master
Browse files Browse the repository at this point in the history
v2.6.0
  • Loading branch information
yujiahu415 authored Oct 25, 2024
2 parents 771132e + a4b5f2c commit 687f26b
Show file tree
Hide file tree
Showing 7 changed files with 413 additions and 29 deletions.
2 changes: 1 addition & 1 deletion LabGym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@



__version__='2.5.6'
__version__='2.6.0'



27 changes: 23 additions & 4 deletions LabGym/analyzebehavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,10 +521,11 @@ def craft_data(self):
print('Data crafting completed!')


def categorize_behaviors(self,path_to_categorizer,uncertain=0):
def categorize_behaviors(self,path_to_categorizer,uncertain=0,min_length=None):

# path_to_categorizer: path to the Categorizer
# uncertain: a threshold between the highest and the 2nd highest probability of behaviors to determine whether to output an 'NA' in behavior classification
# min_length: the minimum length (in frames) a behavior should last, can be used to filter out the brief false positives

print('Categorizing behaviors...')
print(datetime.datetime.now())
Expand Down Expand Up @@ -591,17 +592,29 @@ def categorize_behaviors(self,path_to_categorizer,uncertain=0):
self.event_probability[n][i]=[behavior_names[1],prediction[0]]
if prediction[0]<0.5:
if (1-prediction[0])-prediction[0]>uncertain:
self.event_probability[n][i]=[behavior_names[0],1-prediction[0]]
self.event_probability[n][i]=[behavior_names[0],1-prediction[0]]
else:
if sorted(prediction)[-1]-sorted(prediction)[-2]>uncertain:
self.event_probability[n][i]=[behavior_names[np.argmax(prediction)],max(prediction)]

idx+=1
i+=1

del predictions
gc.collect()

if min_length is not None:
for n in IDs:
i=self.length+self.register_counts[n]
continued_length=1
while i<len(self.event_probability[n]):
if self.event_probability[n][i][0]==self.event_probability[n][i-1][0]:
continued_length+=1
else:
if continued_length<min_length:
self.event_probability[n][i-continued_length:i]=[['NA',-1]]*continued_length
continued_length=1
i+=1

print('Behavioral categorization completed!')


Expand Down Expand Up @@ -691,7 +704,13 @@ def annotate_video(self,behavior_to_include,show_legend=True,interact_all=False)

cx=self.animal_centers[i][frame_count_analyze][0]
cy=self.animal_centers[i][frame_count_analyze][1]
cv2.circle(self.background,(cx,cy),int(text_tk),(abs(int(color_diff*(total_animal_number-current_animal_number)-255)),int(color_diff*current_animal_number/2),int(color_diff*(total_animal_number-current_animal_number)/2)),-1)

if self.animal_centers[i][max(frame_count_analyze-1,0)] is not None:
cxp=self.animal_centers[i][max(frame_count_analyze-1,0)][0]
cyp=self.animal_centers[i][max(frame_count_analyze-1,0)][1]
cv2.line(self.background,(cx,cy),(cxp,cyp),(abs(int(color_diff*(total_animal_number-current_animal_number)-255)),int(color_diff*current_animal_number/2),int(color_diff*(total_animal_number-current_animal_number)/2)),int(text_tk))
else:
cv2.circle(self.background,(cx,cy),int(text_tk),(abs(int(color_diff*(total_animal_number-current_animal_number)-255)),int(color_diff*current_animal_number/2),int(color_diff*(total_animal_number-current_animal_number)/2)),-1)

if interact_all is False:
cv2.putText(frame,str(i),(cx-10,cy-10),cv2.FONT_HERSHEY_SIMPLEX,text_scl,(255,255,255),text_tk)
Expand Down
29 changes: 25 additions & 4 deletions LabGym/analyzebehavior_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,10 +946,11 @@ def craft_data(self):
print('Data crafting completed!')


def categorize_behaviors(self,path_to_categorizer,uncertain=0):
def categorize_behaviors(self,path_to_categorizer,uncertain=0,min_length=None):

# path_to_categorizer: path to the Categorizer
# uncertain: a threshold between the highest and the 2nd highest probability of behaviors to determine whether to output an 'NA' in behavior classification
# min_length: the minimum length (in frames) a behavior should last, can be used to filter out the brief false positives

print('Categorizing behaviors...')
print(datetime.datetime.now())
Expand Down Expand Up @@ -1027,13 +1028,26 @@ def categorize_behaviors(self,path_to_categorizer,uncertain=0):
else:
if sorted(prediction)[-1]-sorted(prediction)[-2]>uncertain:
self.event_probability[animal_name][n][i]=[behavior_names[np.argmax(prediction)],max(prediction)]

idx+=1
i+=1

del predictions
gc.collect()

if min_length is not None:
for animal_name in self.animal_kinds:
for n in IDs:
i=self.length+self.register_counts[animal_name][n]
continued_length=1
while i<len(self.event_probability[animal_name][n]):
if self.event_probability[animal_name][n][i][0]==self.event_probability[animal_name][n][i-1][0]:
continued_length+=1
else:
if continued_length<min_length:
self.event_probability[animal_name][n][i-continued_length:i]=[['NA',-1]]*continued_length
continued_length=1
i+=1

print('Behavioral categorization completed!')


Expand Down Expand Up @@ -1171,8 +1185,15 @@ def annotate_video(self,animal_to_include,behavior_to_include,show_legend=True):

cx=self.animal_centers[animal_name][i][frame_count_analyze][0]
cy=self.animal_centers[animal_name][i][frame_count_analyze][1]
cv2.circle(self.background,(cx,cy),int(text_tk),(abs(int(color_diff*(total_animal_number-current_animal_number)-255)),int(color_diff*current_animal_number/2),int(color_diff*(total_animal_number-current_animal_number)/2)),-1)
cv2.circle(background,(cx,cy),int(text_tk),(abs(int(color_diff*(total_animal_number-current_animal_number)-255)),int(color_diff*current_animal_number/2),int(color_diff*(total_animal_number-current_animal_number)/2)),-1)

if self.animal_centers[animal_name][i][max(frame_count_analyze-1,0)] is not None:
cxp=self.animal_centers[animal_name][i][max(frame_count_analyze-1,0)][0]
cyp=self.animal_centers[animal_name][i][max(frame_count_analyze-1,0)][1]
cv2.line(self.background,(cx,cy),(cxp,cyp),(abs(int(color_diff*(total_animal_number-current_animal_number)-255)),int(color_diff*current_animal_number/2),int(color_diff*(total_animal_number-current_animal_number)/2)),int(text_tk))
cv2.line(background,(cx,cy),(cxp,cyp),(abs(int(color_diff*(total_animal_number-current_animal_number)-255)),int(color_diff*current_animal_number/2),int(color_diff*(total_animal_number-current_animal_number)/2)),int(text_tk))
else:
cv2.circle(self.background,(cx,cy),int(text_tk),(abs(int(color_diff*(total_animal_number-current_animal_number)-255)),int(color_diff*current_animal_number/2),int(color_diff*(total_animal_number-current_animal_number)/2)),-1)
cv2.circle(background,(cx,cy),int(text_tk),(abs(int(color_diff*(total_animal_number-current_animal_number)-255)),int(color_diff*current_animal_number/2),int(color_diff*(total_animal_number-current_animal_number)/2)),-1)

if self.behavior_mode!=1:
cv2.circle(frame,(cx,cy),int(text_tk*3),(255,0,0),-1)
Expand Down
32 changes: 20 additions & 12 deletions LabGym/categorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .tools import *
import matplotlib
matplotlib.use("Agg")
matplotlib.use('Agg')
import os
import cv2
import datetime
Expand Down Expand Up @@ -924,8 +924,9 @@ def train_pattern_recognizer(self,data_path,model_path,out_path=None,dim=64,chan
predictions=model.predict(testX,batch_size=batch_size)

if len(self.classnames)==2:
print(classification_report(testY.argmax(axis=1),predictions.argmax(axis=1),target_names=[self.classnames[0]]))
report=classification_report(testY.argmax(axis=1),predictions.argmax(axis=1),target_names=[self.classnames[0]],output_dict=True)
predictions=[round(i[0]) for i in predictions]
print(classification_report(testY,predictions,target_names=self.classnames))
report=classification_report(testY,predictions,target_names=self.classnames,output_dict=True)
else:
print(classification_report(testY.argmax(axis=1),predictions.argmax(axis=1),target_names=self.classnames))
report=classification_report(testY.argmax(axis=1),predictions.argmax(axis=1),target_names=self.classnames,output_dict=True)
Expand All @@ -934,7 +935,7 @@ def train_pattern_recognizer(self,data_path,model_path,out_path=None,dim=64,chan
if out_path is not None:
pd.DataFrame(report).transpose().to_excel(os.path.join(out_path,'training_metrics.xlsx'),float_format='%.2f')

plt.style.use('seaborn-bright')
plt.style.use('classic')
plt.figure()
plt.plot(H.history['loss'],label='train_loss')
plt.plot(H.history['val_loss'],label='val_loss')
Expand Down Expand Up @@ -1078,8 +1079,9 @@ def train_animation_analyzer(self,data_path,model_path,out_path=None,dim=64,chan
predictions=model.predict(testX,batch_size=batch_size)

if len(self.classnames)==2:
print(classification_report(testY.argmax(axis=1),predictions.argmax(axis=1),target_names=[self.classnames[0]]))
report=classification_report(testY.argmax(axis=1),predictions.argmax(axis=1),target_names=[self.classnames[0]],output_dict=True)
predictions=[round(i[0]) for i in predictions]
print(classification_report(testY,predictions,target_names=self.classnames))
report=classification_report(testY,predictions,target_names=self.classnames,output_dict=True)
else:
print(classification_report(testY.argmax(axis=1),predictions.argmax(axis=1),target_names=self.classnames))
report=classification_report(testY.argmax(axis=1),predictions.argmax(axis=1),target_names=self.classnames,output_dict=True)
Expand All @@ -1088,7 +1090,7 @@ def train_animation_analyzer(self,data_path,model_path,out_path=None,dim=64,chan
if out_path is not None:
pd.DataFrame(report).transpose().to_excel(os.path.join(out_path,'training_metrics.xlsx'),float_format='%.2f')

plt.style.use('seaborn-bright')
plt.style.use('classic')
plt.figure()
plt.plot(H.history['loss'],label='train_loss')
plt.plot(H.history['val_loss'],label='val_loss')
Expand Down Expand Up @@ -1225,8 +1227,9 @@ def train_combnet(self,data_path,model_path,out_path=None,dim_tconv=32,dim_conv=
predictions=model.predict([test_animations,test_pattern_images],batch_size=batch_size)

if len(self.classnames)==2:
print(classification_report(testY.argmax(axis=1),predictions.argmax(axis=1),target_names=[self.classnames[0]]))
report=classification_report(testY.argmax(axis=1),predictions.argmax(axis=1),target_names=[self.classnames[0]],output_dict=True)
predictions=[round(i[0]) for i in predictions]
print(classification_report(testY,predictions,target_names=self.classnames))
report=classification_report(testY,predictions,target_names=self.classnames,output_dict=True)
else:
print(classification_report(testY.argmax(axis=1),predictions.argmax(axis=1),target_names=self.classnames))
report=classification_report(testY.argmax(axis=1),predictions.argmax(axis=1),target_names=self.classnames,output_dict=True)
Expand All @@ -1235,7 +1238,7 @@ def train_combnet(self,data_path,model_path,out_path=None,dim_tconv=32,dim_conv=
if out_path is not None:
pd.DataFrame(report).transpose().to_excel(os.path.join(out_path,'training_metrics.xlsx'),float_format='%.2f')

plt.style.use('seaborn-bright')
plt.style.use('classic')
plt.figure()
plt.plot(H.history['loss'],label='train_loss')
plt.plot(H.history['val_loss'],label='val_loss')
Expand Down Expand Up @@ -1386,8 +1389,13 @@ def test_categorizer(self,groundtruth_path,model_path,result_path=None):
else:
predictions=model.predict([animations,pattern_images],batch_size=32)

print(classification_report(labels,predictions.argmax(axis=1),target_names=classnames))
report=classification_report(labels,predictions.argmax(axis=1),target_names=classnames,output_dict=True)
if len(classnames)==2:
predictions=[round(i[0]) for i in predictions]
print(classification_report(labels,predictions,target_names=classnames))
report=classification_report(labels,predictions,target_names=classnames,output_dict=True)
else:
print(classification_report(labels,predictions.argmax(axis=1),target_names=classnames))
report=classification_report(labels,predictions.argmax(axis=1),target_names=classnames,output_dict=True)

if result_path is not None:
pd.DataFrame(report).transpose().to_excel(os.path.join(result_path,'testing_reports.xlsx'),float_format='%.2f')
Expand Down
Loading

0 comments on commit 687f26b

Please sign in to comment.