Skip to content

Commit

Permalink
Merge pull request #228 from yujiahu415/master
Browse files Browse the repository at this point in the history
v2.6.1
  • Loading branch information
yujiahu415 authored Nov 7, 2024
2 parents 95ebce8 + 7ba532b commit 4613b73
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 17 deletions.
2 changes: 1 addition & 1 deletion LabGym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@



__version__='2.6.0'
__version__='2.6.1'



6 changes: 3 additions & 3 deletions LabGym/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ def train(self,path_to_annotation,path_to_trainingimages,path_to_detector,iterat
print('Animal names in annotation file: '+str(model_parameters_dict['animal_names']))

cfg=get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.merge_from_file(model_zoo.get_config_file('COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml'))
cfg.OUTPUT_DIR=path_to_detector
cfg.DATASETS.TRAIN=('LabGym_detector_train',)
cfg.DATASETS.TEST=()
cfg.DATALOADER.NUM_WORKERS=4
cfg.MODEL.WEIGHTS=model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.WEIGHTS=model_zoo.get_checkpoint_url('COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml')
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE=128
cfg.MODEL.ROI_HEADS.NUM_CLASSES=int(len(classnames))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST=0.5
Expand All @@ -95,7 +95,7 @@ def train(self,path_to_annotation,path_to_trainingimages,path_to_detector,iterat
cfg.SOLVER.GAMMA=0.5
cfg.SOLVER.IMS_PER_BATCH=4
cfg.MODEL.DEVICE=self.device
cfg.SOLVER.CHECKPOINT_PERIOD=10000000
cfg.SOLVER.CHECKPOINT_PERIOD=10000000000
cfg.INPUT.MIN_SIZE_TEST=int(inference_size)
cfg.INPUT.MAX_SIZE_TEST=int(inference_size)
cfg.INPUT.MIN_SIZE_TRAIN=(int(inference_size),)
Expand Down
27 changes: 14 additions & 13 deletions LabGym/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,7 +1299,8 @@ def calculate_distances(path_to_folder,filename,behavior_to_include,out_path):
behavior=event[0]
if behavior!='NA':
if frame_index is None:
frame_index=n
if behavior in behavior_to_include:
frame_index=n
if behavior not in behavior_names[idx]:
behavior_names[idx].append(behavior)
start_centers[idx][behavior]=centers[idx][n]
Expand Down Expand Up @@ -1338,26 +1339,26 @@ def calculate_distances(path_to_folder,filename,behavior_to_include,out_path):
centers_for_calculation.append(start_centers[idx][behavior])
indices_for_calculation.append(start_indices[idx][behavior])

n=0
while n<len(centers_for_calculation):
if n!=len(centers_for_calculation)-1:
shortest_distance+=math.dist(centers_for_calculation[n],centers_for_calculation[n+1])
cv2.circle(frame,(centers_for_calculation[n]),5,(255,max(0,255-int(n*diff)),min(255,int(n*diff))),-1)
cv2.line(frame,centers_for_calculation[n],centers_for_calculation[n+1],(255,max(0,255-int(n*diff)),min(255,int(n*diff))),5)
n+=1

centers_traveled=centers[idx][indices_for_calculation[0]:indices_for_calculation[-1]+1]

n=0
while n<len(centers_traveled)-1:
if centers_traveled[n] is not None:
if centers_traveled[n+1] is not None:
cv2.line(frame,centers_traveled[n],centers_traveled[n+1],(0,max(0,255-int(idx*diff_animal)),0),2)
cv2.line(frame,centers_traveled[n],centers_traveled[n+1],(255,0,max(0,255-int(idx*diff_animal))),2)
traveling_distance+=math.dist(centers_traveled[n],centers_traveled[n+1])
else:
cv2.circle(frame,(centers_traveled[n]),2,(0,max(0,255-int(idx*diff_animal)),0),-1)
cv2.circle(frame,(centers_traveled[n]),2,(255,0,max(0,255-int(idx*diff_animal))),-1)
n+=1

n=0
while n<len(centers_for_calculation):
if n!=len(centers_for_calculation)-1:
shortest_distance+=math.dist(centers_for_calculation[n],centers_for_calculation[n+1])
cv2.circle(frame,(centers_for_calculation[n]),4,(max(0,255-int(n*diff)),max(0,255-int(n*diff)),0),-1)
cv2.line(frame,centers_for_calculation[n],centers_for_calculation[n+1],(max(0,255-int(n*diff)),max(0,255-int(n*diff))),4)
n+=1

centers_traveled=centers[idx][indices_for_calculation[0]:indices_for_calculation[-1]+1]

shortest_distances[idx]=shortest_distance
traveling_distances[idx]=traveling_distance
distance_ratios[idx]=shortest_distance/traveling_distance
Expand Down

0 comments on commit 4613b73

Please sign in to comment.