-
Notifications
You must be signed in to change notification settings - Fork 201
/
main.cc
144 lines (118 loc) · 4.34 KB
/
main.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
// OpenMP is required; build with:
// g++-4.8 -ow2v -fopenmp -std=c++0x -Ofast -march=native -funroll-loops main.cc -lpthread
#include "word2vec.h"
#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <initializer_list>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
int accuracy(Word2Vec<std::string>& model, std::string questions, int restrict_vocab = 30000) {
std::ifstream in(questions);
std::string line;
auto lower = [](std::string& data) { std::transform(data.begin(), data.end(), data.begin(), ::tolower);};
size_t count = 0, correct = 0, ignore = 0, almost_correct = 0;
const int topn = 10;
while (std::getline(in, line)) {
if (line[0] == ':') {
printf("%s\n", line.c_str());
continue;
}
std::istringstream iss(line);
std::string a, b, c, expected;
iss >> a >> b >> c >> expected;
lower(a); lower(b); lower(c); lower(expected);
if (!model.has(a) || !model.has(b) || !model.has(c) || !model.has(expected)) {
printf("unhandled: %s %s %s %s\n", a.c_str(), b.c_str(), c.c_str(), expected.c_str());
++ignore;
continue;
}
++count;
std::vector<std::string> positive{b, c}, negative{a};
auto predict = model.most_similar(positive, negative, topn);
if (predict[0].first == expected) { ++ correct; ++almost_correct; }
else {
bool found = false;
for (auto& v: predict) {
if (v.first == expected) { found = true; break; }
}
if (found) ++almost_correct;
else printf("predicted: %s, expected: %s\n", predict[0].first.c_str(), expected.c_str());
}
}
if (count > 0) printf("predict %lu out of %lu (%f%%), almost correct %lu (%f%%) ignore %lu\n", correct, count, correct * 100.0 / count, almost_correct, almost_correct * 100.0 / count, ignore);
return 0;
}
int main(int argc, const char *argv[])
{
Word2Vec<std::string> model(200);
using Sentence = Word2Vec<std::string>::Sentence;
using SentenceP = Word2Vec<std::string>::SentenceP;
model.sample_ = 0;
// model.window_ = 10;
// model.phrase_ = true;
int n_workers = 4;
::srand(::time(NULL));
auto distance = [&model]() {
while (1) {
std::string s;
std::cout << "\nFind nearest word for (:quit to break):";
std::cin >> s;
if (s == ":quit") break;
auto p = model.most_similar(std::vector<std::string>{s}, std::vector<std::string>(), 10);
size_t i = 0;
for (auto& v: p) {
std::cout << i++ << " " << v.first << " " << v.second << std::endl;
}
}
};
bool train = true, test = false;
if (argc > 1 && std::string(argv[1]) == "test") {
std::swap(train, test);
}
if (train) {
std::vector<SentenceP> sentences;
size_t count =0;
const size_t max_sentence_len = 200;
SentenceP sentence(new Sentence);
// wget http://mattmahoney.net/dc/text8.zip
std::ifstream in("text8");
while (true) {
std::string s;
in >> s;
if (s.empty()) break;
++count;
sentence->tokens_.push_back(std::move(s));
if (count == max_sentence_len) {
count = 0;
sentence->words_.reserve(sentence->tokens_.size());
sentences.push_back(std::move(sentence));
sentence.reset(new Sentence);
}
}
if (!sentence->tokens_.empty())
sentences.push_back(std::move(sentence));
auto cstart = std::chrono::high_resolution_clock::now();
model.build_vocab(sentences);
auto cend = std::chrono::high_resolution_clock::now();
printf("load vocab: %.4f seconds\n", std::chrono::duration_cast<std::chrono::microseconds>(cend - cstart).count() / 1000000.0);
cstart = cend;
model.train(sentences, n_workers);
cend = std::chrono::high_resolution_clock::now();
printf("train: %.4f seconds\n", std::chrono::duration_cast<std::chrono::microseconds>(cend - cstart).count() / 1000000.0);
cstart = cend;
model.save("vectors.bin");
model.save_text("vectors.txt");
cend = std::chrono::high_resolution_clock::now();
printf("save model: %.4f seconds\n", std::chrono::duration_cast<std::chrono::microseconds>(cend - cstart).count() / 1000000.0);
// distance();
}
if (test) {
auto cstart = std::chrono::high_resolution_clock::now();
model.load("vectors.bin");
auto cend = std::chrono::high_resolution_clock::now();
printf("load model: %.4f seconds\n", std::chrono::duration_cast<std::chrono::microseconds>(cend - cstart).count() / 1000000.0);
distance();
cstart = cend;
accuracy(model, "questions-words.txt");
cend = std::chrono::high_resolution_clock::now();
printf("test model: %.4f seconds\n", std::chrono::duration_cast<std::chrono::microseconds>(cend - cstart).count() / 1000000.0);
}
return 0;
}