diff --git a/datasets/vggface2.py b/datasets/vggface2.py index a36d258c1..8d84d6def 100644 --- a/datasets/vggface2.py +++ b/datasets/vggface2.py @@ -23,8 +23,8 @@ from torchvision import transforms import cv2 -import face_detection import kornia.geometry.transform as GT +from batch_face import RetinaFace from PIL import Image from skimage import transform as trans from tqdm import tqdm @@ -38,7 +38,7 @@ class VGGFace2(Dataset): VGGFace2 Dataset """ def __init__(self, root_dir, d_type, mode, transform=None, - teacher_transform=None, img_size=(112, 112)): + teacher_transform=None, img_size=(112, 112), args=None): if d_type not in ('test', 'train'): raise ValueError("d_type can only be set to 'test' or 'train'") @@ -47,6 +47,7 @@ def __init__(self, root_dir, d_type, mode, transform=None, raise ValueError("mode can only be set to 'detection', 'identification'," "or 'identification_dr'") + self.device = args.device self.root_dir = root_dir self.d_type = d_type self.transform = transform @@ -99,8 +100,11 @@ def __extract_gt(self): """ Extracts the ground truth from the dataset """ - detector = face_detection.build_detector("RetinaNetResNet50", confidence_threshold=.5, - nms_iou_threshold=.4) + if self.device == 'cuda': + detector = RetinaFace(gpu_id=torch.cuda.current_device(), network="resnet50") + else: + detector = RetinaFace(gpu_id=-1, network="resnet50") + img_paths = list(glob.glob(os.path.join(self.d_path + '/**/', '*.jpg'), recursive=True)) nf_number = 0 words_count = 0 @@ -111,22 +115,17 @@ def __extract_gt(self): boxes = [] image = cv2.imread(jpg) - img_max = max(image.shape[0], image.shape[1]) - if img_max > 1320: - continue - bboxes, lndmrks = detector.batched_detect_with_landmarks(np.expand_dims(image, 0)) - bboxes = bboxes[0] - lndmrks = lndmrks[0] + faces = detector(image) - if (bboxes.shape[0] == 0) or (lndmrks.shape[0] == 0): + if len(faces) == 0: nf_number += 1 continue - for box in bboxes: + for face in faces: + box = face[0] box = np.clip(box[:4], 0, None) boxes.append(box) - - lndmrks = lndmrks[0] + lndmrks = faces[0][1] dir_name = os.path.dirname(jpg) lbl = os.path.relpath(dir_name, self.d_path) @@ -343,7 +342,7 @@ def VGGFace2_FaceID_get_datasets(data, load_train=True, load_test=True, img_size train_dataset = VGGFace2(root_dir=data_dir, d_type='train', mode='identification', transform=train_transform, teacher_transform=teacher_transform, - img_size=img_size) + img_size=img_size, args=args) print(f'Train dataset length: {len(train_dataset)}\n') else: @@ -355,7 +354,7 @@ def VGGFace2_FaceID_get_datasets(data, load_train=True, load_test=True, img_size test_dataset = VGGFace2(root_dir=data_dir, d_type='test', mode='identification', transform=test_transform, teacher_transform=teacher_transform, - img_size=img_size) + img_size=img_size, args=args) print(f'Test dataset length: {len(test_dataset)}\n') else: @@ -378,7 +377,7 @@ def VGGFace2_FaceID_dr_get_datasets(data, load_train=True, load_test=True, img_s if load_train: train_dataset = VGGFace2(root_dir=data_dir, d_type='train', mode='identification_dr', - transform=train_transform, img_size=img_size) + transform=train_transform, img_size=img_size, args=args) print(f'Train dataset length: {len(train_dataset)}\n') else: @@ -389,7 +388,7 @@ def VGGFace2_FaceID_dr_get_datasets(data, load_train=True, load_test=True, img_s ai8x.normalize(args=args)]) test_dataset = VGGFace2(root_dir=data_dir, d_type='test', mode='identification_dr', - transform=test_transform, img_size=img_size) + transform=test_transform, img_size=img_size, args=args) print(f'Test dataset length: {len(test_dataset)}\n') else: @@ -409,7 +408,7 @@ def VGGFace2_Facedet_get_datasets(data, load_train=True, load_test=True, img_siz ai8x.normalize(args=args)]) train_dataset = VGGFace2(root_dir=data_dir, d_type='train', mode='detection', - transform=train_transform, img_size=img_size) + transform=train_transform, img_size=img_size, args=args) print(f'Train dataset length: {len(train_dataset)}\n') else: @@ -419,7 +418,7 @@ def VGGFace2_Facedet_get_datasets(data, load_train=True, load_test=True, img_siz test_transform = transforms.Compose([ai8x.normalize(args=args)]) test_dataset = VGGFace2(root_dir=data_dir, d_type='test', mode='detection', - transform=test_transform, img_size=img_size) + transform=test_transform, img_size=img_size, args=args) print(f'Test dataset length: {len(test_dataset)}\n') else: diff --git a/requirements.txt b/requirements.txt index dbf6502ac..a1662a779 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ Pillow>=7 PyYAML>=5.1.1 albumentations>=1.3.0 faiss-cpu==1.7.4 -face-detection==0.2.2 +batch-face>=1.4.0 h5py>=3.7.0 kornia==0.6.8 librosa>=0.7.2