-
Notifications
You must be signed in to change notification settings - Fork 2
/
score-tool
446 lines (376 loc) · 17.5 KB
/
score-tool
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
#!/usr/bin/env python3
import sys, os, glob
import argparse, fileinput
import math, random
import numpy as np
import csv, sqlite3
def unique_ordered(seq, listlist=False):
    """Return the unique items of seq in order of first appearance.

    With listlist=True every item is converted to a tuple first, so that a
    list of lists becomes hashable; the returned items are then tuples.
    Recipe taken from https://www.peterbe.com/plog/uniqifiers-benchmark
    """
    items = [tuple(entry) for entry in seq] if listlist else seq
    already_seen = set()
    uniques = []
    for item in items:
        if item in already_seen:
            continue
        already_seen.add(item)
        uniques.append(item)
    return uniques
# flag that gets set in main if loading the sqlite extension was successful
extensions = False


def aggregate_sql(c, sel_as=False):
    """Build the comma-terminated list of aggregate SQL terms for column c.

    The terms are avg, min, max and count; variance and stdev are added
    after avg when the module-level flag `extensions` is set. With
    sel_as=True each term gets a short alias, appended directly after the
    closing parenthesis (e.g. 'avg(x)as avg'). The returned string ends
    with a trailing comma.
    """
    terms = [('avg', 'avg')]
    if extensions:
        terms.extend([('variance', 'var'), ('stdev', 'stdev')])
    terms.extend([('min', 'min'), ('max', 'max'), ('count', 'count')])

    pieces = []
    for function, alias in terms:
        suffix = 'as ' + alias if sel_as else ''
        pieces.append(function + '(' + c + ')' + suffix + ',')
    return ''.join(pieces)
def db_create_table(cursor, table, fields):
    """Create the table `table` with one untyped column per entry of fields.

    Double quotes are removed from the column names and '-' is replaced by
    '_'. If sqlite rejects the statement, the statement is written to
    stderr and the sqlite3.OperationalError is raised again.
    """
    column_list = ''.join(
        name.replace('"', '').replace('-', '_') + ',' for name in fields
    )
    # cutting the last character drops the trailing comma of the column list
    statement = ('CREATE TABLE ' + table + ' (' + column_list)[:-1] + ');'
    try:
        cursor.execute(statement)
    except sqlite3.OperationalError as error:
        sys.stderr.write('exception while executing sql term:\n' + statement + '\n')
        raise error
def db_add_row(cursor, values, args):
    """Insert one row of values into the table args.table.

    The values are bound through '?' placeholders. If sqlite reports a
    mismatch between the number of table columns and supplied values:
      - more values than columns: the row is skipped (reported on stderr
        if args.verbose is set)
      - fewer values than columns: `values` is extended in place with None
        (NULL) entries and the insert is tried again

    Raises:
        sqlite3.OperationalError: for every other sqlite error, e.g. a
            missing table. These are passed on unchanged instead of failing
            while parsing the error message.
    """
    # build SQL term, one '?' for each value, for use with execute()
    insert_str = 'INSERT INTO ' + args.table + ' VALUES (' + ','.join('?' for _ in values) + ');'
    if args.verbose:
        sys.stderr.write(insert_str + '\n')
    try:
        cursor.execute(insert_str, values)
    except sqlite3.OperationalError as ex:
        # parse the number of table columns and specified values from the
        # error text ('table T has N columns but M values were supplied')
        try:
            words = str(ex).split()
            ncol = int(words[words.index('columns') - 1])
            nval = int(words[words.index('values') - 1])
        except ValueError:
            # not a column count mismatch, hand the original error on
            raise ex from None
        if ncol < nval:
            # more values than the table has columns: skip the row
            err = 'sqlite3.OperationalError: {}\n --> skipped: {}\n'.format(str(ex), str(values))
            if args.verbose:
                sys.stderr.write(err)
        elif ncol > nval:
            # fewer values than the table has columns: pad with NULL values and try again
            values.extend([None] * (ncol - nval))
            if args.verbose:
                sys.stderr.write('extended ' + str(nval) + ' --> ' + str(ncol) + '\n')
            db_add_row(cursor, values, args)
        else:
            raise
def db_print_info(cursor, table):
    """Print the name, the number of rows and the column names of `table`."""
    cursor.execute('SELECT * FROM ' + table)
    row_count = len(cursor.fetchall())
    column_names = [description[0] for description in cursor.description]

    print('table:\n ', table)
    print('entries:\n ', row_count)
    print('columns:')
    for name in column_names:
        print(' ', name)
def db_print_table(cursor, args):
    """Query the loaded table and print the result.

    Unless args.sql is given, the SQL term is built from args.table,
    args.select_columns (when given as names), args.where, and - only if
    args.order_by is set - args.lessthan, args.greaterequalthan and the
    ordering direction. The term that was executed is stored in args.sql.

    Columns selected by integer index are picked from the query result in
    the order given; otherwise the first args.limit_columns columns are
    printed (0 means all). args.limit_rows limits the rows the same way.

    With args.raw the values are printed tab separated, otherwise the
    module-level `tabulate` function is used with args.format.
    """
    # get the selected columns, either by name (handled in SQL) or by index
    sql_select = False
    sel_columns = []
    if args.select_columns:
        try:
            sel_columns = [int(c) for c in args.select_columns.split(',')]
        except ValueError:
            sel_columns = args.select_columns.split(',')
            sql_select = True

    if not args.sql:
        # build the SQL term
        sql = 'SELECT ' + (args.select_columns if sql_select else '*') + ' FROM ' + args.table
        conditions = []
        if args.where:
            conditions.append('(' + args.where + ')')
        if args.order_by:
            # compare against None so that a threshold of 0 is not dropped
            if args.lessthan is not None:
                conditions.append(args.order_by + '<' + str(args.lessthan))
            if args.greaterequalthan is not None:
                conditions.append(args.order_by + '>=' + str(args.greaterequalthan))
        if conditions:
            sql += ' WHERE ' + ' AND '.join(conditions)
        if args.order_by:
            sql += ' ORDER BY ' + args.order_by + (' ASC' if args.ascending else ' DESC')
        args.sql = sql

    if args.verbose:
        sys.stderr.write('\n' + args.sql + '\n')
    cursor.execute(args.sql)
    columns = [d[0] for d in cursor.description]
    rows = cursor.fetchall()
    if args.verbose:
        sys.stderr.write(str(len(rows)) + ' rows in ' + str(len(columns)) + ' columns loaded\n')

    # put limits if specified, a limit of 0 means no limit
    nc = min(len(columns), args.limit_columns)
    if nc == 0:
        nc = len(columns)
    nr = min(len(rows), args.limit_rows)
    if nr == 0:
        nr = len(rows)

    # an index selection overrides the column limit
    if not sql_select and sel_columns:
        scols = sel_columns
    else:
        scols = range(nc)
    srows = range(nr)

    print()
    # if raw only print values without formatting
    if args.raw:
        for c in scols:
            print(columns[c], end='\t')
        print()
        for r in srows:
            for c in scols:
                print(rows[r][c], end='\t')
            print()
        return

    # pick headers and cells in the same (selection) order so that every
    # value ends up under its own header
    tab_cols = [columns[c] for c in scols]
    tab_rows = [[rows[r][c] for c in scols] for r in srows]
    # ordering indicator: up/down pointing triangle on the order-by column
    if args.order_by in tab_cols:
        ix = tab_cols.index(args.order_by)
        tab_cols[ix] += ' \u25B2' if args.ascending else ' \u25BC'
    # print table with tabulate
    print(tabulate(tab_rows, headers=tab_cols, tablefmt=args.format))
def db_print_aggregate(cursor, args):
    """Print the aggregation values for each column in args.select_columns.

    The columns have to be given by name. The program exits with status -1
    if no columns are selected or if any selected column is an integer
    index. One query per column is run with the terms of aggregate_sql();
    args.where and - only if args.order_by is set - args.lessthan and
    args.greaterequalthan restrict the aggregated rows. The results are
    printed with the module-level `tabulate` function, one row per column
    with the column name in front.
    """
    if not args.select_columns:
        sys.stderr.write('No columns selected!\n')
        sys.exit(-1)
    # make sure every selected column is a name, not an index
    sel_columns = args.select_columns.split(',')
    for name in sel_columns:
        try:
            int(name)
        except ValueError:
            continue
        sys.stderr.write('Select columns by name, not index!\n')
        sys.exit(-1)

    print('aggregation results for columns in ' + args.table)

    # the row filter is the same for every column
    conditions = []
    if args.where:
        conditions.append('(' + args.where + ')')
    if args.order_by:
        # compare against None so that a threshold of 0 is not dropped
        if args.lessthan is not None:
            conditions.append(args.order_by + '<' + str(args.lessthan))
        if args.greaterequalthan is not None:
            conditions.append(args.order_by + '>=' + str(args.greaterequalthan))
    where_sql = ' WHERE ' + ' AND '.join(conditions) if conditions else ''

    columns = []
    rows = []
    for col in sel_columns:
        # build the SQL term with the aggregating functions,
        # [:-1] removes the trailing comma of the term list
        args.sql = 'SELECT ' + aggregate_sql(col, True)[:-1] + ' FROM ' + args.table + where_sql
        if args.verbose:
            sys.stderr.write('\n' + args.sql + '\n')
        cursor.execute(args.sql)
        columns = [d[0] for d in cursor.description]
        # each result per selected column has only one row (aggregation),
        # the column name is put in front of it
        rows.append([col] + list(cursor.fetchall()[0]))

    print()
    # there is one header less than cells per row (no header for the name)
    print(tabulate(rows, headers=columns, tablefmt=args.format))
    print()
def db_print_distinct(cursor, args):
    """Print the distinct values of each column in args.select_columns.

    The columns have to be given by name. The program exits with status -1
    if no columns are selected or if any selected column is an integer
    index. One ordered DISTINCT query is run per column, restricted by
    args.where if set. The value lists are printed side by side with the
    module-level `tabulate` function; shorter lists are filled with None.
    A column whose query fails is printed as an empty column.
    """
    if not args.select_columns:
        sys.stderr.write('No columns selected!\n')
        sys.exit(-1)
    # make sure every selected column is a name, not an index
    sel_columns = args.select_columns.split(',')
    for name in sel_columns:
        try:
            int(name)
        except ValueError:
            continue
        sys.stderr.write('Select columns by name, not index!\n')
        sys.exit(-1)

    print('distinct values for columns in ' + args.table)
    columns = []
    rows = []
    for ix, col in enumerate(sel_columns):
        # build the SQL term, orders each column
        args.sql = 'SELECT DISTINCT ' + col + ' FROM ' + args.table
        if args.where:
            args.sql += ' WHERE (' + args.where + ')'
        args.sql += ' ORDER BY ' + col + (' ASC' if args.ascending else ' DESC')
        if args.verbose:
            sys.stderr.write('\n' + args.sql + '\n')
        # add empty column if error
        try:
            cursor.execute(args.sql)
        except sqlite3.OperationalError as ex:
            if args.verbose:
                sys.stderr.write('query failed, column left empty: ' + str(ex) + '\n')
            columns.append(col)
            continue
        columns += [d[0] for d in cursor.description]
        new = cursor.fetchall()
        # extend the table if new column has more rows
        if len(rows) < len(new):
            rows.extend([None] * len(sel_columns) for _ in range(len(new) - len(rows)))
        for jx, result in enumerate(new):
            rows[jx][ix] = result[0]

    print()
    print(tabulate(rows, headers=columns, tablefmt=args.format))
    print()
def parse_filekey(key, file=None):
    """Split a file name according to a key of the form '[sep][key1][sep][key2]...'.

    The first character of key is the separator. If no file is specified
    the field names of the key are returned (a merged key keeps its
    brackets). If a file is specified its name is split at the separator
    and the parts are returned in key order.

    One key can be a merged key, written in brackets; it takes all parts
    that are not needed by the keys before and after it, for example:
        key:  _cat1_[cat2]_cat3
        file: this_is_an_example
        cat1: this
        cat2: is_an
        cat3: example

    If the file name has fewer parts than the key has fields, the missing
    values are returned as empty strings, so that there is always at least
    one value per key field.
    """
    sep = key[0]
    fields = key.split(sep)[1:]
    if not file:
        return fields

    parts = file.split(sep)
    # index of the first merged key, -1 if there is none
    merge_ix = next(
        (i for i, name in enumerate(fields) if '[' in name and ']' in name), -1
    )
    if merge_ix < 0:
        # pad short names, otherwise all following table columns would shift
        parts.extend([''] * (len(fields) - len(parts)))
        return parts

    # first fill the fields before the merge index from the front, then
    # those after it from the back, the rest becomes the one merged field
    for i in range(merge_ix):
        fields[i] = parts.pop(0) if parts else ''
    for i in range(len(fields) - 1, merge_ix, -1):
        fields[i] = parts.pop() if parts else ''
    fields[merge_ix] = sep.join(parts)
    return fields
def load_scores(files, cursor, args):
    """Load score files, generated by [grt score -f], into the table args.table.

    Directories, nonexisting and empty files are skipped, as are blank
    lines. The first line starting with '#' over all files provides the
    column headers and creates the table; its first field (the '#') becomes
    the 'file' column. All other '#' lines are ignored. Every remaining
    line becomes one row: the file name, followed by the whitespace
    separated values of the line. Values that can be parsed as numbers are
    stored as floats.

    With args.total only the first 8 header fields and the first 7 values
    per line are loaded. With args.file_key the fields parsed from the file
    name are inserted on the left of the table.
    """
    init = False
    for f in files:
        # skip directories, empty and nonexisting files
        if not os.path.isfile(f):
            if args.verbose:
                sys.stderr.write("skipped not file " + f + '\n')
            continue
        if os.path.getsize(f) == 0:
            if args.verbose:
                sys.stderr.write("skipped empty file " + f + '\n')
            continue
        name = os.path.basename(f)
        with open(f) as lines:
            for line in lines:
                line = line.strip()
                # blank lines carry no values
                if not line:
                    continue
                if line[0] == '#':
                    # first line with # has the column headers
                    if not init:
                        fields = line.split()
                        if args.total:
                            fields = fields[:8]  # only load total scores, ignore per class
                        fields[0] = 'file'  # first field is the #, change to file name
                        if args.file_key:  # insert file key columns on the left of the table
                            fields[:0] = parse_filekey(args.file_key)
                        # create new table with loaded columns
                        db_create_table(cursor, args.table, fields)
                        init = True
                    # skip further comment lines
                    continue
                # first value is the filename for the 'file' column
                values = [name]
                if args.file_key:  # parse the file key values from the filename
                    values[:0] = parse_filekey(args.file_key, name)
                # load rest of values
                if args.total:
                    values.extend(line.split()[:7])
                else:
                    values.extend(line.split())
                # cast number strings to floats, leave everything else as it is
                for i, value in enumerate(values):
                    try:
                        values[i] = float(value)
                    except (ValueError, TypeError):
                        pass
                # add loaded values to table
                db_add_row(cursor, values, args)
def db_sandbox(cursor, args):
    """Hook for custom code that is run with --sandbox.

    Can be filled in to execute complex custom SQL queries and calculations
    on the loaded tables, e.g. if they are not supported via sqlite syntax.
    Does nothing as it stands.
    """
    return None
# Script entry: parse the command line, load the score files (or SQL dumps)
# into an in-memory sqlite database and print the requested view of the table.
if __name__=="__main__":
    class Formatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter): pass
    cmdline = argparse.ArgumentParser(description="parse scoring files\nfiles generated via [grt score -f]", epilog="Default output: all rows and columns of the loaded table", formatter_class=Formatter)
    cmdline.add_argument('files', type=str, nargs='*', help='files to parse (stdin reads one path per line, e.g. per [find])')
    cmdline.add_argument('-v', '--verbose', help='be verbose\n', action='store_true')
    cmdline.add_argument('--recursive', help='enable recursive file glob\n', action='store_true')
    cmdline.add_argument('-i', '--info', help='print info on loaded table\n', action='store_true')
    cmdline.add_argument('--raw', help='raw output, no tabulate\n', action='store_true')
    cmdline.add_argument('-t', '--total', help='only load total scores, no per-class scores\n', action='store_true')
    cmdline.add_argument('--sql', metavar='SQL', type=str, help='perform SQL term on loaded table\n')
    cmdline.add_argument('--where', metavar='SQL', type=str, help='add where clause to sql term\n')
    cmdline.add_argument('--dump', metavar='FILE', type=str, help='dump created database to SQL text format\n')
    cmdline.add_argument('--load', metavar='FILE', type=str, nargs='*', help='load database from SQL dump\n')
    cmdline.add_argument('--sqlitefunctions', metavar='FILE', type=str, default='./libsqlitefunctions', help='SQL extension functions; path to shared object file\n')
    cmdline.add_argument('--sandbox', help='execute sandbox code; see code/comments; for advanced usage\n', action='store_true')
    cmdline_table_group = cmdline.add_argument_group('table arguments')
    cmdline_table_group.add_argument('--table', metavar='NAME', type=str, default='scores', help='specify the table name\n')
    cmdline_table_group.add_argument('-c','--limit-columns', metavar='N', type=int, default=0, help='limit table printing to first N columns\n')
    cmdline_table_group.add_argument('-r','--limit-rows', metavar='N', type=int, default=0, help='limit table printing to first N rows\n')
    cmdline_table_group.add_argument('-sc','--select-columns', metavar='COLS', type=str, help='limit table printing to selected columns, format:\n- 1,2,4,6,...\n- [column1],[column2],...\noverrides --limit-columns\n')
    cmdline_table_group.add_argument('-a', '--aggregate', help='aggregate on selected columns\n', action='store_true')
    cmdline_table_group.add_argument('-d', '--distinct', help='get distinct values from selected columns\n', action='store_true')
    cmdline_table_group.add_argument('-o', '--order-by', metavar='COL', type=str, help='order table by COL\n')
    cmdline_table_group.add_argument('-asc', '--ascending', help='ascending order instead of descending\n', action='store_true')
    cmdline_table_group.add_argument('-lt', '--lessthan', metavar='N', type=float, help='only print rows with [--order-by] value less than N\n')
    cmdline_table_group.add_argument('-get', '--greaterequalthan', metavar='N', type=float, help='only print rows with [--order-by] value greater or equal than N\n')
    cmdline_table_group.add_argument('--file-key', metavar='KEY', type=str, help='if set, will try to break the file names according to the specified key\nformat: [sep][key1][sep][key2]...\nexactly one key can be a merged key (as \'[key]\'), which will merge remaining name even if it contains [sep]\n')
    cmdline_table_group.add_argument('--format', metavar='FMT', type=str, default='simple', help='tabulate table format, no effect if --raw\n')
    args = cmdline.parse_args()

    # input checks
    if not len(args.files) and not args.load:
        sys.stderr.write('specify "-", at least one file, or the --load option\n')
        sys.exit(-1)
    if len(args.files) and args.load:
        print('warning! if --load is set other specified files will be ignored')
    # lt and get are applied to order-by column; compare against None so
    # that a threshold of 0 is recognised as set
    if (args.lessthan is not None or args.greaterequalthan is not None) and not args.order_by:
        sys.stderr.write('warning: -lt and -get need --order-by to be set!\n')

    files = list()
    # TODO: also allow direct passing of scores via stdin
    # load score filenames from stdin
    if not args.load and args.files[0] == '-':
        for line in sys.stdin:
            line = line.strip()
            # skip empty lines, echo and skip comment lines
            if line == "":
                continue
            if line[0] == '#':
                print(line)
                continue
            files.append(line)
    # load score filenames from glob
    else:
        for f in args.files:
            files.extend(glob.glob(f, recursive=args.recursive))
    if not args.load and len(files) == 0:
        sys.stderr.write('No files loaded!\n')
        sys.exit(-1)

    # create in-memory database
    db = sqlite3.connect(':memory:')
    # need some aggregating functions
    if args.aggregate:
        db.enable_load_extension(True)
        try:
            db.load_extension(args.sqlitefunctions)
            extensions = True
        except sqlite3.OperationalError as ex:
            sys.stderr.write('Error loading sqlite extension ' + args.sqlitefunctions + '\nsee extension-functions.c for instructions, compile into shared object and set via --sqlitefunctions\n')
            sys.stderr.write(str(ex) + '\ncontinuing without aggregates: var, stdev\n\n')
        db.enable_load_extension(False)

    # reference db cursor position
    cursor = db.cursor()
    # load score files, or sql dump files
    if not args.load:
        load_scores(files, cursor, args)
    else:
        for l in args.load:
            with open(l, 'r') as f:
                cursor.executescript(f.read())
        db.commit()
    # check table names, adjust table arg if only one
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = cursor.fetchall()
    if len(tables) == 1:
        args.table = tables[0][0]

    # table info
    if args.info:
        db_print_info(cursor, args.table)
        sys.exit()
    # dump loaded table
    if args.dump:
        with open(args.dump, 'w') as f:
            for line in db.iterdump():
                f.write('%s\n' % line)
        sys.exit()

    # need tabulate for formatted table output; the aggregate and distinct
    # views are always printed with tabulate, even if --raw is set
    if not args.raw or args.aggregate or args.distinct:
        from tabulate import tabulate
    if args.aggregate:
        db_print_aggregate(cursor, args)
    elif args.distinct:
        db_print_distinct(cursor, args)
    elif args.sandbox:
        db_sandbox(cursor, args)
    else:
        db_print_table(cursor, args)