-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathembedding.py
96 lines (72 loc) · 2.83 KB
/
embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import sys
import os
# Hard-coded location of a local checkout of the "visual-semantic-embedding"
# repository. It is appended to sys.path so that modules living there can be
# imported (presumably the `demo` and `tools` modules imported further down
# this file -- confirm against the checkout).
vse_path = '/home/student/pylib/visual-semantic-embedding/'
if os.path.exists(vse_path):
    sys.path.append(vse_path)
else:
    # Fail at import time with an explicit message rather than letting a
    # later import fail with a less helpful error.
    # Adjacent string literals are used instead of a backslash continuation
    # inside the literal, which dropped the separator between "repo" and "on".
    raise IOError(
        'You need to have "visual-semantic-embedding" git repo '
        'on local path'
    )
import numpy as np
import theano
from theano import tensor as T
import lasagne
from lasagne import layers as L
import demo, tools
class VisualSemanticEmbedder:
    """Joint embedding model for images and text using deep neural nets.

    This class is a wrapper around ryankiros' implementation of the
    'visual-semantic-embedding' functionality.

    It uses Lasagne and Theano for the pre-trained networks, so both must be
    installed to use this class. Since the pre-trained models are deep, at
    least 3 GiB of memory is required to load the model.

    This wrapper only uses the MS COCO model, for efficiency.

    The implementation is from "Unifying Visual-Semantic Embeddings with
    Multimodal Neural Language Models" (Kiros, Salakhutdinov, Zemel. 2014).
    """

    def __init__(self, model_path_dict):
        """Compile the image feature extractor and load the embedding model.

        Args:
            model_path_dict: Mapping of model paths. Its 'vse_model' entry is
                handed to ``tools.load_model`` as ``path_to_model``.
        """
        self.model_path = model_path_dict

        # Compile the image feature extractor: a Theano function mapping the
        # network's input variable to the deterministic output of its 'fc7'
        # layer.
        self.vggnet = demo.build_convnet()
        self._get_image_features = theano.function(
            inputs=[self.vggnet['input'].input_var],
            outputs=L.get_output(self.vggnet['fc7'], deterministic=True),
            allow_input_downcast=True
        )

        # Load the pre-trained VSEM model.
        self.model = tools.load_model(
            path_to_model=self.model_path['vse_model']
        )

    @staticmethod
    def _as_collection(items, error_message):
        """Wrap a single string in a list; pass other iterables through.

        The string check comes first because on Python 3 ``str`` defines
        ``__iter__``; testing ``hasattr(items, '__iter__')`` first would let
        a lone string through and it would be iterated character by
        character.

        Args:
            items: A single string or an iterable.
            error_message: Message for the ValueError raised on bad input.

        Returns:
            ``[items]`` when ``items`` is a string, otherwise ``items``
            itself, unchanged.

        Raises:
            ValueError: If ``items`` is neither a string nor an iterable.
        """
        if isinstance(items, str):
            return [items]
        if not hasattr(items, '__iter__'):
            raise ValueError(error_message)
        return items

    def get_image_embedding(self, file_names):
        """Embed one or more image files into the joint embedding space.

        Args:
            file_names: A single file name or an iterable of file names.

        Returns:
            The result of ``tools.encode_images`` applied to the float32
            'fc7' features of the loaded images.

        Raises:
            ValueError: If ``file_names`` is neither a string nor an
                iterable.
        """
        file_names = self._as_collection(
            file_names, 'File names must be a iterable of strings!'
        )

        # Keep the first element of each load_image result and stack them.
        # Presumably the stacked array is (n_images, rgb, width, height) --
        # TODO confirm against demo.load_image.
        # A list comprehension is used instead of map(): on Python 3 map()
        # returns an iterator, which np.array would wrap as a 0-d object
        # array instead of stacking the images.
        X = np.array([demo.load_image(name)[0] for name in file_names])

        # Calculate the VGG image features ('fc7' layer output).
        Y = self._get_image_features(X).astype(np.float32)

        # Project them into the VSEM embedding space.
        Z = tools.encode_images(self.model, Y)
        return Z

    def get_sentence_embedding(self, sentences):
        """Embed one or more sentences into the joint embedding space.

        Args:
            sentences: A single sentence or an iterable of sentences.

        Returns:
            The result of ``tools.encode_sentences`` for the given
            sentences.

        Raises:
            ValueError: If ``sentences`` is neither a string nor an
                iterable.
        """
        sentences = self._as_collection(
            sentences, 'Sentences must be a iterable of strings!'
        )
        Z = tools.encode_sentences(self.model, sentences)
        return Z

    def score(self, image_embeddings, sentence_embeddings, method='dot'):
        """Score image embeddings against sentence embeddings.

        Args:
            image_embeddings: Array of image embeddings.
            sentence_embeddings: Array of sentence embeddings; it is
                transposed before the product. Presumably both arrays are
                2-D with one embedding per row -- TODO confirm with callers.
            method: Scoring method. Only 'dot' is supported.

        Returns:
            ``np.dot(image_embeddings, sentence_embeddings.T)``.

        Raises:
            ValueError: If ``method`` is not a supported scoring method
                (previously this case silently returned None).
        """
        if method == 'dot':
            return np.dot(image_embeddings, sentence_embeddings.T)
        raise ValueError('Unsupported scoring method: %r' % (method,))