-
Notifications
You must be signed in to change notification settings - Fork 2
/
predict-skl
executable file
·188 lines (145 loc) · 6.33 KB
/
predict-skl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys, os, time
import argparse, fileinput
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import cluster
import pickle
# load stream
def load(stream):
    """Read a whitespace-separated sample stream into labels and features.

    Every line that is neither empty nor starting with ``#`` must hold a
    label followed by at least one numeric feature value. All samples must
    carry the same number of features as the first sample read.

    Args:
        stream: Handed unchanged to ``fileinput.input()`` as its file
            argument.

    Returns:
        A ``(labels, features)`` tuple. ``labels`` is a list with the first
        field of every sample (as strings); ``features`` is the NumPy array
        built from the per-sample lists of floats.

    Exits:
        Writes a message to stderr and calls ``sys.exit(-1)`` when a sample
        has no features, when the feature count differs between samples, or
        when a feature value cannot be converted to ``float``.
    """
    labels = []
    features_l = []
    num_features = 0

    # The historical ``bufsize`` argument is not accepted by current Python
    # versions, so it is no longer passed. The context manager makes sure
    # the (module-global) fileinput state is released again.
    with fileinput.input(stream) as lines:
        for line in lines:
            line = line.strip()

            # skip empty and comment lines
            if not line or line.startswith('#'):
                continue

            fields = line.split()
            values = fields[1:]
            if not values:
                sys.stderr.write("predict-skl: no features?\n")
                sys.exit(-1)

            # make sure the number of data fields is always the same
            if not num_features:
                num_features = len(values)
            elif len(values) != num_features:
                sys.stderr.write(
                    "predict-skl: incorrect number of features:"
                    + str(len(values)) + "!=" + str(num_features) + '\n'
                )
                sys.exit(-1)

            try:
                sample = [float(x) for x in values]
            except ValueError:
                # report malformed numbers the same way as the other
                # input errors instead of dying with a traceback
                sys.stderr.write("predict-skl: non-numeric feature value in: " + line + '\n')
                sys.exit(-1)

            labels.append(fields[0])
            features_l.append(sample)

    return labels, np.array(features_l)
# output graphs showing the results of the last fitting of the given estimator
# TODO: experimental at best, needs to be extended
def graph_result(estimator, labels, features):
    """Plot the samples and the estimator's cluster centers.

    Prints a short summary (sample count, number of distinct labels, number
    of features, the estimator itself) and then shows one figure with the
    X-Y, X-Z and Y-Z projections plus a 3-D scatter plot. Only the first
    three feature columns are plotted; points are colored by label.

    Args:
        estimator: Object exposing ``cluster_centers_``; indexed as a 2-D
            array with ``[:, k]``.
        labels: Sequence of hashable sample labels, one per row of
            ``features``.
        features: 2-D array of samples, indexed with ``[:, k]``.

    Exits:
        Writes a message to stderr and calls ``sys.exit(-1)`` when the
        estimator has no ``cluster_centers_`` attribute or when fewer than
        three features are available to plot.
    """
    # Number the distinct labels in order of first appearance; a dict keeps
    # this linear instead of rescanning a list for every sample.
    label_index = {}
    for label in labels:
        label_index.setdefault(label, len(label_index))
    ilabels = [label_index[label] for label in labels]

    print("samples:", len(labels))
    print("uniques:", len(label_index))
    print("features:", len(features[0]))
    print("estimator:", estimator)

    try:
        centers = estimator.cluster_centers_
    except AttributeError:
        sys.stderr.write("predict-skl: sorry, graphing for the selected estimator is not supported yet\n")
        sys.exit(-1)

    # every panel below indexes columns 0..2, so fail early and readably
    if len(features[0]) < 3:
        sys.stderr.write("predict-skl: graphing requires at least 3 features\n")
        sys.exit(-1)

    fig = plt.figure(1, figsize=(20, 12))
    fig.clf()

    # Estimator name is the text before the first "(" of its repr;
    # partition() keeps the whole string when there is no "(".
    title = 'predict-skl: ' + str(estimator).partition("(")[0]
    # Newer Matplotlib releases no longer offer canvas.set_window_title();
    # the figure manager provides it across versions. Some backends have no
    # manager, in which case there is no window to name.
    manager = getattr(fig.canvas, "manager", None)
    if manager is not None:
        manager.set_window_title(title)

    axis_names = ('X', 'Y', 'Z')
    cmap = plt.get_cmap('gist_rainbow')

    # 2-D projections occupy grid cells 1..3 of the 2x3 layout
    for position, (a, b) in enumerate(((0, 1), (0, 2), (1, 2)), start=1):
        ax = fig.add_subplot(2, 3, position)
        ax.scatter(features[:, a], features[:, b], c=ilabels, cmap=cmap)
        ax.scatter(centers[:, a], centers[:, b], marker='x', color='r', s=150, linewidths=2)
        ax.set_title(axis_names[a] + '-' + axis_names[b])
        ax.set_xlabel(axis_names[a])
        ax.set_ylabel(axis_names[b])

    # 3-D view in the middle of the second row
    ax = fig.add_subplot(2, 3, 5, projection='3d')
    ax.scatter(features[:, 0], features[:, 1], features[:, 2], c=ilabels, cmap=cmap)
    ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2], marker='x', color='r', s=150, linewidths=2)
    ax.set_title('X-Y-Z')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    plt.show()
if __name__ == "__main__":
    # Command-line entry point: load a pickled estimator, read the sample
    # stream and print one "label<TAB>prediction" line per sample. With
    # --info only the estimator is printed; with --graph the result is
    # plotted instead of printed.
    class Formatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter):
        pass

    cmdline = argparse.ArgumentParser(
        description="Unsupervised clustering algorithms from scikit-learn\n[predict module]",
        epilog="Default output: prediction stream with labels and predictions",
        formatter_class=Formatter,
    )
    cmdline.add_argument('estimator', metavar='ESTIMATOR', type=argparse.FileType('rb'), nargs='?', default='-', help="estimator dump, filename or stdin\n")
    cmdline.add_argument('samples', metavar='SAMPLES', type=str, nargs='?', default='-', help="sample stream, format: [label] [[features]]\n")
    cmdline.add_argument('-g', '--graph', help='graph results of estimator fitting\n', action='store_true')
    cmdline.add_argument('-i', '--info', help='print estimator info and quit\n', action='store_true')
    cmdline.add_argument('-r', '--retry', metavar=('X','Y'), type=int, nargs=2, default=(10,1), help='retry loading model X times in Y-second intervals before giving up\n')
    args = cmdline.parse_args()

    if args.estimator is None:
        cmdline.print_help()
        sys.exit(0)

    estimator = None
    attempts, pause = args.retry

    # SECURITY: unpickling executes code embedded in the data; only feed
    # this tool estimator dumps from a trusted source.
    #
    # A text-mode handle (mode 'r') is treated as a stdin stream that
    # starts with a size line followed by the pickle; anything else is
    # unpickled directly from the file.
    # NOTE(review): newer Python versions may hand back the binary stdin
    # buffer for '-', in which case the size-prefixed branch is skipped --
    # confirm against what train-skl writes.
    for _ in range(attempts):
        try:
            if args.estimator.mode == 'r':
                size = int(args.estimator.readline())
                # TODO: -33 because sizeof in train-skl is not accurate
                estimator = pickle.loads(args.estimator.read(size - 33).encode('latin-1'))
            else:
                estimator = pickle.load(args.estimator)
            break
        except EOFError:
            # file is not ready yet, retry after pause
            time.sleep(pause)
            # The failed attempt consumed the handle up to EOF; start the
            # next attempt from the beginning where that is possible.
            if args.estimator.seekable():
                args.estimator.seek(0)

    if estimator is None:
        sys.stderr.write("predict-skl: model loading error\n")
        sys.exit(-1)

    if args.info:
        print(estimator)
        sys.exit(0)

    # load sample stream and fill label and feature containers
    labels, features = load(args.samples)
    if len(labels) == 0:
        sys.stderr.write("predict-skl: no samples loaded\n")
        sys.exit(-1)

    # Predict. Estimators without predict() raise AttributeError, in which
    # case fit_predict() is used instead; a ValueError from either call is
    # reported the same way.
    try:
        try:
            pred = estimator.predict(features)
        except AttributeError:
            pred = estimator.fit_predict(features)
    except ValueError as ex:
        sys.stderr.write("predict-skl: ValueError when predicting with estimator!\n")
        sys.stderr.write(str(ex) + '\n')
        sys.exit(-1)

    # graph
    if args.graph:
        graph_result(estimator, labels, features)
        sys.exit(0)

    # print predictions
    for label, prediction in zip(labels, pred):
        print(label, prediction, sep='\t')