forked from albanie/collaborative-experts
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sent_feat_demo.py
66 lines (45 loc) · 2.11 KB
/
sent_feat_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def _w2v_sent_feat(text):
    """Return the stacked word2vec vectors of the in-vocabulary words of *text*."""
    import gensim
    import numpy as np

    model = gensim.models.KeyedVectors.load_word2vec_format(
        '/scratch/shared/slow/yangl/w2v/GoogleNews-vectors-negative300.bin',
        binary=True,
    )
    # `word in model` is used instead of `word in model.vocab`: the `vocab`
    # attribute was removed in gensim 4, while the containment check on the
    # KeyedVectors object itself works across gensim 3.x and 4.x.
    # `split()` (no argument) also splits on tabs/newlines and drops empty
    # tokens, so stray whitespace does not cause vocabulary misses.
    word_vecs = [
        model.get_vector(word)
        for word in text.split()
        if word != 'a' and word in model
    ]
    if not word_vecs:
        # Keep the result two-dimensional even when no word was matched, so
        # callers can always rely on a (num_words, vector_size) layout.
        return np.zeros((0, model.vector_size), dtype=np.float32)
    return np.asarray(word_vecs)


def _openai_sent_feat(text):
    """Return the OpenAI GPT hidden states for *text* as a numpy array."""
    import logging

    import torch
    from pytorch_pretrained_bert import OpenAIGPTModel, OpenAIGPTTokenizer

    logging.basicConfig(level=logging.INFO)

    # Use the GPU when one is present, otherwise run on the CPU instead of
    # failing on the hard-coded 'cuda' device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load pre-trained model tokenizer (vocabulary) and weights.
    tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
    model = OpenAIGPTModel.from_pretrained('openai-gpt')
    model.eval()
    model.to(device)

    # Tokenize, map tokens to vocabulary indices and wrap them in a batch of
    # size one on the same device as the model.
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    tokens_tensor = torch.tensor([indexed_tokens]).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        hidden_states = model(tokens_tensor)
    # Index 0 selects the first entry of the model output; with a batch of
    # one this is presumably the (num_tokens, hidden_dim) matrix for `text`
    # -- TODO confirm against the installed pytorch_pretrained_bert version.
    return hidden_states[0].cpu().numpy()


def sent_feat(text, feat_type):
    """Extract per-token features for a sentence.

    Args:
        text: The sentence to embed, as a single string.
        feat_type: Which feature extractor to use; either ``'w2v'``
            (GoogleNews word2vec vectors loaded through gensim) or
            ``'openai'`` (the pre-trained ``openai-gpt`` model).

    Returns:
        A numpy array of features. For ``'w2v'`` there is one row per word
        of ``text`` that is in the word2vec vocabulary (the word ``'a'`` is
        always skipped). For ``'openai'`` it is the first entry of the model
        output, moved to the CPU.

    Raises:
        ValueError: If ``feat_type`` is not one of the supported values.

    Note:
        The underlying model is loaded from disk on every call, which is
        slow; callers embedding many sentences may want to cache the result.
    """
    if feat_type == 'w2v':
        return _w2v_sent_feat(text)
    if feat_type == 'openai':
        return _openai_sent_feat(text)
    # Previously this case only printed a message and then crashed with an
    # UnboundLocalError on the return statement; fail explicitly instead.
    raise ValueError(
        "Unrecognised FEAT_TYPE: {!r}. "
        "FEAT_TYPE can be selected from ['w2v', 'openai']".format(feat_type)
    )
if __name__ == '__main__':
    # Demo: embed one example sentence with every supported feature type and
    # report the shape of each resulting feature array.
    demo_query = 'a cartoon animals runs through an ice cave in a video game'
    # (feature type passed to sent_feat, message template for its shape)
    shape_reports = (
        ('w2v', "word2vec shape is: {}"),
        ('openai', "openai shape is: {}"),
    )

    print("Query: {}".format(demo_query))
    print("FEAT_TYPE can be selected from ['w2v', 'openai']")
    for feat_type, template in shape_reports:
        feats = sent_feat(demo_query, feat_type)
        print(template.format(feats.shape))