forked from matsui528/sis
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfeature_extractor.py
31 lines (26 loc) · 1.17 KB
/
feature_extractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import os
# Fix CUDA not being detected on Windows: since Python 3.8, DLL dependencies
# of extension modules are no longer resolved via PATH, so the CUDA toolkit's
# bin directory must be registered explicitly *before* TensorFlow is imported.
# The call is guarded because os.add_dll_directory only exists on Windows
# (AttributeError elsewhere) and raises OSError if the directory is missing,
# e.g. when a different CUDA version is installed.
CUDA_BIN_DIR = "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.4/bin"
if hasattr(os, "add_dll_directory") and os.path.isdir(CUDA_BIN_DIR):
    os.add_dll_directory(CUDA_BIN_DIR)
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.models import Model
import numpy as np
import tensorflow as tf
# See https://keras.io/api/applications/ for details
class FeatureExtractor:
    """Extract L2-normalized image descriptors from VGG16's ``fc1`` layer.

    The backbone is Keras' VGG16 with ImageNet weights, truncated at the
    first fully connected layer (``fc1``). Each image is mapped to that
    layer's activation vector, scaled to unit length.
    """

    # Spatial size (height, width) images are resized to before inference.
    INPUT_SIZE = (224, 224)

    def __init__(self):
        base_model = VGG16(weights='imagenet')
        # Cut the network at fc1; its activations serve as the image feature.
        self.model = Model(inputs=base_model.input,
                           outputs=base_model.get_layer('fc1').output)

    def predict_and_normalize(self, img):
        """Run the model on a preprocessed batch and return a unit-length feature.

        Args:
            img: an input batch already passed through ``preprocess_input``;
                presumably shaped (1, 224, 224, 3). Only the first row of the
                model output is used.

        Returns:
            1-D numpy array: the first fc1 activation vector divided by its
            L2 norm. If that vector is all zeros it is returned unchanged
            instead of being divided by zero (which would produce NaNs).
        """
        feature = self.model.predict(img)[0]  # (1, 4096) -> (4096, )
        norm = np.linalg.norm(feature)
        if norm == 0:
            # Nothing to normalize; avoid 0/0 yielding a NaN vector that would
            # break every downstream distance computation.
            return feature
        return feature / norm  # Normalize

    def process_predict_and_normalize(self, image_path):
        """Load an image file and return its normalized fc1 feature.

        JPEG files are decoded with truncated-file recovery enabled (at least
        half of the scanlines must be present). Every other file is decoded by
        ``tf.io.decode_image``, which covers PNG as well as GIF (first frame)
        and BMP. The image is forced to 3 channels, stretched to
        ``INPUT_SIZE`` (aspect ratio is not preserved), given a batch
        dimension, and run through VGG16's ``preprocess_input``.

        Args:
            image_path: path of the image file to read.

        Returns:
            Same as :meth:`predict_and_normalize`.
        """
        raw = tf.io.read_file(image_path)
        x = tf.cond(
            tf.image.is_jpeg(raw),
            lambda: tf.image.decode_jpeg(
                raw, channels=3, try_recover_truncated=True, acceptable_fraction=0.5
            ),
            # expand_animations=False keeps a GIF's first frame as a 3-D
            # tensor, matching the JPEG branch's (H, W, 3) uint8 output.
            lambda: tf.io.decode_image(raw, channels=3, expand_animations=False),
        )
        x = tf.image.resize(x, self.INPUT_SIZE)
        x = tf.expand_dims(x, axis=0)  # add batch dimension -> (1, H, W, 3)
        x = preprocess_input(x)
        return self.predict_and_normalize(x)