-
Notifications
You must be signed in to change notification settings - Fork 20
/
ckpt2json
executable file
·176 lines (145 loc) · 4.94 KB
/
ckpt2json
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#!/usr/bin/env python3
#==============================================================================
# Functionality
#==============================================================================
import pdb
import sys
import os
import re
# utility funcs, classes, etc go here.
def asserting(cond):
    """Assert that *cond* is truthy, opening the debugger first if it is not.

    When *cond* is falsy, ``pdb.set_trace()`` is called so the failure can be
    inspected interactively; the ``assert`` that follows then fails once the
    debugger session is resumed.
    """
    failed = not cond
    if failed:
        pdb.set_trace()
    assert cond
def has_stdin():
    """Return True when standard input is not attached to a terminal."""
    attached_to_terminal = sys.stdin.isatty()
    return not attached_to_terminal
def reg(pat, flags=0):
    """Compile *pat* as a verbose-mode regular expression.

    ``re.VERBOSE`` is always enabled; any extra *flags* are OR-ed in on top.
    Returns the compiled pattern object.
    """
    combined_flags = flags | re.VERBOSE
    return re.compile(pat, combined_flags)
def matches(s, pat, *, case='smart', verbose=False):
    """Report whether string *s* satisfies the search pattern *pat*.

    The pattern may match anywhere in *s*.  A leading ``!`` negates the
    test: ``'!foo'`` is satisfied by strings that do NOT contain a match for
    ``foo``.  The pattern is compiled in ``re.VERBOSE`` mode (the same mode
    the ``reg`` helper uses), so unescaped whitespace inside it is ignored.

    Args:
        s: The string to test.
        pat: Regular expression, optionally prefixed with ``!``.
        case: ``'off'`` to always ignore case, ``'on'`` to always respect
            case, or ``'smart'`` to ignore case only when the pattern
            contains no uppercase character.
        verbose: Print which case mode was selected.

    Returns:
        True if *s* satisfies *pat* (taking negation into account).

    Raises:
        ValueError: If *case* is not one of the three known modes.
    """
    # True for a plain pattern, False when the caller asked for "!pattern".
    want_match = True
    if pat[:1] == "!":
        want_match = False
        pat = pat[1:]

    flags = 0
    if case == 'off':
        if verbose: print('case==off')
        flags = re.IGNORECASE
    elif case == 'on':
        if verbose: print('case==on')
    elif case == 'smart':
        if verbose: print('case==smart')
        if not any(ch.isupper() for ch in pat):
            if verbose: print('smartcase-ignore')
            flags = re.IGNORECASE
    else:
        raise ValueError('unknown case specification %s' % case)

    # Use search() on the bare pattern rather than match() on '(?:.*)' + pat:
    # with the prefix, a top-level alternation such as 'foo|bar' parses as
    # '(?:.*)foo' OR 'bar', leaving the second branch anchored at position 0,
    # and inline global flags like '(?i)' no longer sit at the pattern start.
    found = re.compile(pat, re.VERBOSE | flags).search(s)
    return want_match == (found is not None)
def narrow(strings, searches, **kws):
    """Keep only the strings that satisfy every search pattern.

    Args:
        strings: A collection of strings, or a single string.
        searches: A pattern, a list of patterns, or None (no filtering).
        **kws: Forwarded unchanged to ``matches`` for every test.

    Returns:
        For a collection: the strings that pass all patterns (the input
        itself when there are no patterns).  For a single string: True if
        it passes all patterns, otherwise False.
    """
    if isinstance(strings, str):
        # Single-string form: filter a one-element list and report survival.
        survivors = narrow([strings], searches, **kws)
        return len(survivors) > 0

    if searches is None:
        patterns = []
    elif isinstance(searches, str):
        patterns = [searches]
    else:
        patterns = searches

    remaining = strings
    for pattern in patterns:
        # Each pattern narrows whatever the previous patterns left over.
        remaining = [candidate for candidate in remaining
                     if matches(candidate, pattern, **kws)]
    return remaining
#==============================================================================
# Cmdline
#==============================================================================
import argparse
# Command-line interface: one required positional argument (the checkpoint
# path) plus optional switches.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="""
TODO
""",
)
parser.add_argument(
    'checkpoint',
    help="The path to the Tensorflow checkpoint. E.g. gs://mlpublic-euw4/runs/bigrun97/dec28/run6_evos0_imagenet_dlrmul_0_4/model.ckpt-9658250",
)
parser.add_argument(
    '-v', '--verbose',
    action="store_true",
    help="verbose output",
)
parser.add_argument(
    '-d', '--dry-run',
    action="store_true",
    help="Just list the tensors that would be exported; don't actually download any of them.",
)
# Zero or more patterns may follow -e; the attribute is None when -e is absent.
parser.add_argument(
    '-e', '--pattern',
    nargs='*',
    help="only export layers that match this (python-style) regexp.",
)

# Parsed command-line namespace; remains None until main() populates it.
args = None
#==============================================================================
# Main
#==============================================================================
import json
import tensorflow as tf
import os
import tqdm
from natsort import natsorted
def maketree(path):
    """Create directory *path* and any missing parents, best-effort.

    An already-existing directory is not an error.  Other filesystem
    failures (``OSError``) are deliberately ignored so that callers which
    never write into the directory, such as a dry run, are not interrupted;
    a caller that does need the directory will fail when it opens a file
    inside it.

    Only ``OSError`` is suppressed: ``KeyboardInterrupt``, ``SystemExit``
    and programming errors propagate instead of being silently swallowed.
    """
    try:
        os.makedirs(path, exist_ok=True)
    except OSError:
        pass
def numpy2json(arr):
    """Serialise *arr* to a JSON string.

    *arr* must provide a ``tolist()`` method; the resulting (possibly
    nested) list is what gets encoded.
    """
    as_list = arr.tolist()
    return json.dumps(as_list)
def tensor2json(ckpt, outdir, name):
    """Write one checkpoint tensor to ``<outdir>/<name>/data.json``.

    Args:
        ckpt: Checkpoint reader; only its ``get_tensor(name)`` method is used.
        outdir: Root directory of the export.
        name: Tensor name; each ``/`` becomes a directory level.

    The export is resumable: a tensor whose ``data.json`` already exists and
    is non-empty is skipped.  To keep that check trustworthy, the JSON is
    written to a temporary sibling file and moved into place only once it is
    complete, so an interrupted run cannot leave a truncated ``data.json``
    that a later run would mistake for a finished one.
    """
    tensor_dir = os.path.join(outdir, name.replace('/', os.path.sep))
    maketree(tensor_dir)
    outpath = os.path.join(tensor_dir, 'data.json')
    if os.path.exists(outpath) and os.path.getsize(outpath) > 0:
        return

    # Fetch and serialise before touching the filesystem, so a failed read
    # leaves no empty output file behind.
    payload = numpy2json(ckpt.get_tensor(name))

    tmppath = outpath + '.tmp'
    with open(tmppath, 'w') as f:
        print(payload, file=f)
    # Replaces any existing (empty) data.json with the finished file.
    os.replace(tmppath, outpath)
from contextlib import contextmanager
@contextmanager
def nullcontext():
    """Context manager that does nothing; ``with ... as x`` binds ``x`` to None."""
    yield None
def maybe_open(path, *argz, **kws):
    """Open *path* unless this is a dry run.

    In a normal run this is ``open(path, *argz, **kws)``.  When
    ``args.dry_run`` is set, no file is touched and a do-nothing context
    manager is returned instead, so ``with maybe_open(...) as f`` binds
    ``f`` to None.
    """
    if not args.dry_run:
        return open(path, *argz, **kws)
    return nullcontext()
def ckpt2json(checkpoint_path):
    """Export the tensors of a checkpoint as JSON files under a local tree.

    The output directory is derived from *checkpoint_path* by replacing
    ``://`` with ``/`` and then using the platform path separator.  Two index
    files are written there, ``shapes.jsonl`` and ``types.jsonl``, each
    holding one ``[name, value]`` JSON array per line in natural sort order.
    Every tensor whose name passes ``args.pattern`` is then logged and,
    unless ``args.dry_run`` is set, written via ``tensor2json``.
    """
    reader = tf.train.load_checkpoint(checkpoint_path)

    outdir = checkpoint_path.replace('://', '/').replace('/', os.path.sep)
    maketree(outdir)

    shapes = reader.get_variable_to_shape_map()
    types = {}
    for variable, dtype in reader.get_variable_to_dtype_map().items():
        types[variable] = dtype.name
    names = list(natsorted(list(shapes)))

    # Index files: same layout for both, only the looked-up table differs.
    for filename, table in (('shapes.jsonl', shapes), ('types.jsonl', types)):
        with maybe_open(os.path.join(outdir, filename), 'w') as f:
            if f is None:
                # Dry run: maybe_open gave us a placeholder, nothing to write.
                continue
            for name in names:
                print(json.dumps([name, table[name]]), file=f)

    selected = narrow(names, args.pattern)
    with tqdm.tqdm(selected) as pbar:
        for name in pbar:
            # Log through the progress bar so messages do not garble it.
            args.log = lambda x: pbar.write(x)
            destination = os.path.join(outdir, name, 'data.json')
            line = '{shape!r:>24} {name!r} -> {path}'.format(
                name=name, path=destination, shape=shapes[name], type=types[name])
            args.log(line)
            if args.dry_run:
                continue
            tensor2json(reader, outdir, name)
def run():
    """Export the checkpoint named in the parsed arguments.

    When ``args.verbose`` is set, the parsed argument namespace is printed
    before the export starts.
    """
    if args.verbose:
        print(args)
    checkpoint = args.checkpoint
    ckpt2json(checkpoint)
def main():
    """Script entry point: parse the command line if needed, then run.

    ``args`` is parsed from ``sys.argv`` only when it has not already been
    set.  Returns whatever ``run()`` returns.

    A broken output pipe (for example when the reader of our stdout exits
    early) is handled quietly by closing stdout and stderr.  Only
    ``BrokenPipeError`` is treated this way: on Python 3 ``IOError`` is an
    alias of ``OSError``, so catching it would also hide real failures such
    as a full disk or a permission error and let the script exit as if it
    had succeeded.
    """
    global args
    try:
        if not args:
            args = parser.parse_args()
        return run()
    except BrokenPipeError:
        # http://stackoverflow.com/questions/15793886/how-to-avoid-a-broken-pipe-error-when-printing-a-large-amount-of-formatted-data
        # Closing may itself fail while flushing to the dead pipe; ignore that.
        for stream in (sys.stdout, sys.stderr):
            try:
                stream.close()
            except OSError:
                pass
# Run the exporter only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()