diff --git a/image_recognition_age_gender/scripts/face_properties_node b/image_recognition_age_gender/scripts/face_properties_node index 69ddac9f..28169406 100755 --- a/image_recognition_age_gender/scripts/face_properties_node +++ b/image_recognition_age_gender/scripts/face_properties_node @@ -12,13 +12,13 @@ from image_recognition_util import image_writer class FacePropertiesNode: - def __init__(self, weights_file_path, img_size, depth, width, save_images_folder, use_gpu): + def __init__(self, weights_file_path, img_size, save_images_folder, use_gpu): """ ROS node that wraps the PyTorch age gender estimator """ self._bridge = CvBridge() self._properties_srv = rospy.Service('get_face_properties', GetFaceProperties, self._get_face_properties_srv) - self._estimator = AgeGenderEstimator(weights_file_path, img_size, depth, width, use_gpu) + self._estimator = AgeGenderEstimator(weights_file_path, img_size, use_gpu) if save_images_folder: self._save_images_folder = os.path.expanduser(save_images_folder) @@ -30,8 +30,6 @@ class FacePropertiesNode: rospy.loginfo("PytorchFaceProperties node initialized:") rospy.loginfo(" - weights_file_path=%s", weights_file_path) rospy.loginfo(" - img_size=%s", img_size) - rospy.loginfo(" - depth=%s", depth) - rospy.loginfo(" - width=%s", width) rospy.loginfo(" - save_images_folder=%s", save_images_folder) rospy.loginfo(" - use_gpu=%s", use_gpu) @@ -82,8 +80,6 @@ if __name__ == '__main__': default_weights_path = os.path.expanduser('~/data/pytorch_models/best-epoch47-0.9314.onnx') weights_file_path = rospy.get_param("~weights_file_path", default_weights_path) img_size = rospy.get_param("~image_size", 64) - depth = rospy.get_param("~depth", 16) - width = rospy.get_param("~width", 8) save_images = rospy.get_param("~save_images", True) use_gpu = rospy.get_param("~use_gpu", False) diff --git a/image_recognition_age_gender/src/image_recognition_age_gender/age_gender_estimator.py b/image_recognition_age_gender/src/image_recognition_age_gender/age_gender_estimator.py index 62cc9d22..f216ca05 100644 --- a/image_recognition_age_gender/src/image_recognition_age_gender/age_gender_estimator.py +++ b/image_recognition_age_gender/src/image_recognition_age_gender/age_gender_estimator.py @@ -8,11 +8,13 @@ class AgeGenderEstimator(object): - def __init__(self, weights_file_path, img_size=64, depth=16, width=8, use_gpu=False): + def __init__(self, weights_file_path, img_size=64, use_gpu=False): """ Estimate the age and gender of the incoming image :param weights_file_path: path to a pre-trained network in onnx format + :param img_size: Images are resized to a square image of (img_size X img_size) + :param use_gpu: Use GPU or CPU """ weights_file_path = os.path.expanduser(weights_file_path) @@ -22,8 +24,6 @@ def __init__(self, weights_file_path, img_size=64, depth=16, width=8, use_gpu=Fa self._model = None self._weights_file_path = weights_file_path self._img_size = img_size - self._depth = depth - self._width = width self._use_gpu = use_gpu def estimate(self, np_images): @@ -52,7 +52,7 @@ def estimate(self, np_images): results = [] for np_image in np_images: - inputs = np.transpose(cv2.resize(np_image, (64, 64)), (2, 0, 1)) + inputs = np.transpose(cv2.resize(np_image, (self._img_size, self._img_size)), (2, 0, 1)) inputs = np.expand_dims(inputs, 0).astype(np.float32) / 255. predictions = self._model.run(['output'], input_feed={'input': inputs})[0][0] # age p(male) p(female) diff --git a/image_recognition_age_gender/test/test_face_properties.py b/image_recognition_age_gender/test/test_face_properties.py index f8883f16..6cf6dc72 100644 --- a/image_recognition_age_gender/test/test_face_properties.py +++ b/image_recognition_age_gender/test/test_face_properties.py @@ -28,7 +28,7 @@ def age_is_female_from_asset_name(asset_name): images_gt = [(cv2.imread(os.path.join(assets_path, asset)), age_is_female_from_asset_name(asset)) for asset in os.listdir(assets_path)] - estimations = AgeGenderEstimator(local_path, 64, 16, 8).estimate([image for image, _ in images_gt]) + estimations = AgeGenderEstimator(local_path, 64).estimate(image for image, _ in images_gt) for (_, (age_gt, is_female_gt)), (age, gender) in zip(images_gt, estimations): age = int(age) is_female = gender[0] > 0.5