diff --git a/tf_bodypix/model.py b/tf_bodypix/model.py index 6ff4a13..808cf3a 100644 --- a/tf_bodypix/model.py +++ b/tf_bodypix/model.py @@ -343,7 +343,7 @@ def get_colored_part_mask( part_colors=part_colors ) - def get_poses(self) -> List[Pose]: + def get_poses(self, maxPoseDetections=10) -> List[Pose]: assert self.heatmap_logits is not None assert self.short_offsets is not None assert self.displacement_fwd is not None @@ -354,7 +354,7 @@ def get_poses(self) -> List[Pose]: displacementsFwdBuffer=np.asarray(self.displacement_fwd[0]), displacementsBwdBuffer=np.asarray(self.displacement_bwd[0]), outputStride=self.output_stride, - maxPoseDetections=2 + maxPoseDetections=maxPoseDetections ) scaled_poses = scaleAndFlipPoses( poses,