get_image_embeddings.py
# SPDX-FileCopyrightText: 2024 Idiap Research Institute <[email protected]>
# SPDX-FileContributor: Alina Elena Baia <[email protected]>
#
# SPDX-License-Identifier: CC-BY-NC-SA-4.0
import os
import argparse

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import imagebind.data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Instantiate the pretrained ImageBind (huge) model in inference mode
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-dataset', type=str, default='./data/dataset_train')
    parser.add_argument("-csv", "--csv_name", help="csv file with the image ids", type=str, default="./data/dataset_train_info.csv")
    parser.add_argument("-output", "--output_file_name", help="file with the generated embeddings", type=str, default="image_embeddings")
    parser.add_argument("-b", "--batch_size", help="batch size", type=int, default=20)
    args = parser.parse_args()

    if not os.path.exists("./generated_data"):
        os.makedirs("./generated_data")

    dataset_dir = args.dataset
    csv_name = args.csv_name
    output = args.output_file_name
    batch_size = args.batch_size

    # Resolve the full path of every image listed in the CSV metadata
    dataset_info = pd.read_csv(csv_name)
    images_name = list(dataset_info["image_name"])
    image_paths = [os.path.join(dataset_dir, img_path) for img_path in images_name]

    # Embed the images in batches to keep GPU memory usage bounded
    nr_iterations = int(np.ceil(len(image_paths) / batch_size))
    image_embeddings = []
    for i in tqdm(range(nr_iterations)):
        start_index = i * batch_size
        end_index = start_index + batch_size

        inputs = {
            ModalityType.VISION: imagebind.data.load_and_transform_vision_data(image_paths[start_index:end_index], device),
        }

        with torch.no_grad():
            embeddings = model(inputs)

        image_embeddings.extend(embeddings[ModalityType.VISION].cpu().numpy().tolist())

    # Save all embeddings as a single (num_images, embedding_dim) array
    np.save("./generated_data/{}.npy".format(output), np.array(image_embeddings))
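
# Usage (a minimal sketch; the values below are just this script's argparse
# defaults, so adjust the paths to your own dataset layout):
#
#   python get_image_embeddings.py -dataset ./data/dataset_train \
#       -csv ./data/dataset_train_info.csv -output image_embeddings -b 20
#
# The saved array can then be reloaded for downstream use, e.g.:
#
#   embeddings = np.load("./generated_data/image_embeddings.npy")
#   print(embeddings.shape)  # expected: (num_images, embedding_dim)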