-
Notifications
You must be signed in to change notification settings - Fork 7
/
demo.py
executable file
·91 lines (69 loc) · 2.34 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
import numpy
import torch
from torch.autograd import Variable
import data
def softmax(x, axis=1):
    """Compute softmax values for each set of scores in ``x``.

    The per-slice maximum is subtracted before exponentiating so that large
    scores do not overflow ``numpy.exp``. Softmax is invariant to adding a
    constant to every score of a set, so the result is unchanged.

    Args:
        x: array-like of scores.
        axis: axis along which the scores of one set lie. Defaults to 1
            (one set per row of a 2-D array), which was the only behaviour
            before this parameter existed. Pass ``-1`` for 1-D input.

    Returns:
        numpy array with the same shape as ``x`` whose entries along ``axis``
        are non-negative and sum to 1.
    """
    e_x = numpy.exp(x - numpy.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)
# Compact numpy printing: two decimals, no scientific notation, and a very
# wide (5000-character) line width.
numpy.set_printoptions(linewidth=5000, suppress=True, precision=2)

# Command-line interface of the demo script.
parser = argparse.ArgumentParser(description='PyTorch NLI Language Model')
# Model parameters.
parser.add_argument(
    '--data',
    default='../datasets/nli_data/',
    type=str,
    help='location of the data corpus',
)
parser.add_argument(
    '--checkpoint',
    default='./model/model.pt',
    type=str,
    help='model checkpoint to use',
)
parser.add_argument(
    '--seed',
    default=1111,
    type=int,
    help='random seed',
)
args = parser.parse_args()
def build_tree(depth, sen):
    """Build a nested-list binary parse tree from per-word split scores.

    The word with the largest ``depth`` score becomes the split point; ties
    go to the leftmost maximum (``numpy.argmax`` semantics). The words to its
    left are parsed recursively into one subtree. The split word is then
    attached, paired with the recursively parsed words to its right when
    there are any.

    Args:
        depth: sequence (list or 1-D numpy array) of scores, one per word.
        sen: sequence of words, the same length as ``depth``.

    Returns:
        ``sen[0]`` for a one-word sentence. Otherwise a nested list whose
        leaves are the words of ``sen`` in their original order.

    Raises:
        ValueError: if ``depth`` and ``sen`` differ in length or are empty.
    """
    # Explicit checks instead of ``assert``: asserts disappear under
    # ``python -O``, and a mismatch would then surface as a confusing
    # IndexError deep in the recursion or as a silently wrong tree.
    if len(depth) != len(sen):
        raise ValueError(
            'depth and sen must have the same length (got %d and %d)'
            % (len(depth), len(sen)))
    if len(sen) == 0:
        raise ValueError('cannot build a parse tree for an empty sentence')
    if len(sen) == 1:
        return sen[0]

    split = numpy.argmax(depth)
    left_words = sen[:split]
    right_words = sen[split + 1:]

    # The split word heads the right-hand part of this node.
    node = sen[split]
    if len(right_words) > 0:
        node = [node, build_tree(depth[split + 1:], right_words)]

    # With no words to the left, the right-hand part is the whole tree.
    if len(left_words) == 0:
        return node
    return [build_tree(depth[:split], left_words), node]
def MRG(tr):
    """Render a nested-list parse tree as a space-separated bracketed string.

    A leaf (a ``str``) is emitted as the word followed by one space. Any
    other value is treated as an iterable of subtrees, rendered recursively
    and wrapped in ``'( '`` ... ``') '``.

    Example: ``['a', ['b', 'c']]`` becomes ``'( a ( b c ) ) '``.
    """
    if not isinstance(tr, str):
        inner = ''.join(MRG(child) for child in tr)
        return '( ' + inner + ') '
    return tr + ' '
# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)

# Load the trained model and switch it to inference mode on the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only point --checkpoint at files you trust. Recent PyTorch releases
# default to weights_only=True, which may refuse a whole pickled model
# object; confirm against the installed version before changing this call.
with open(args.checkpoint, 'rb') as f:
    model = torch.load(f)
model.eval()
print(model)
model.cpu()

corpus = data.Corpus(args.data)

# raw_input only exists on Python 2; Python 3 renamed it to input().
try:
    read_line = raw_input
except NameError:
    read_line = input

# Interactive loop: read a sentence, run it through the model, and print the
# parse tree built from the model's gate values.
while True:
    try:
        sens = read_line('Input a sentences:')
    except EOFError:
        # End of input (Ctrl-D or a closed pipe): stop instead of crashing.
        break
    words = sens.strip().split()
    if not words:
        # An empty line would give numpy.array([]), a float64 array with no
        # token ids, which cannot be turned into a model input.
        continue
    # NOTE(review): behaviour for out-of-vocabulary words depends on
    # corpus.dictionary from the data module -- confirm.
    x = numpy.array([corpus.dictionary[w] for w in words])
    # x[:, None] adds a trailing axis of size 1: shape (len(words), 1).
    word_ids = Variable(torch.LongTensor(x[:, None]))
    hidden = model.init_hidden(1)
    _, hidden = model(word_ids, hidden)
    # squeeze() drops every size-1 dimension, so for a one-word sentence it
    # can presumably yield a 0-d array (TODO confirm the layout of
    # model.gates). build_tree cannot take len() of a 0-d array;
    # atleast_1d restores a 1-D array there and is a no-op otherwise.
    gates = numpy.atleast_1d(model.gates.squeeze().data.numpy())
    parse_tree = build_tree(gates, words)
    print(parse_tree)
    print(MRG(parse_tree))