# gen_tf_wts.py -- forked from wang-xinyu/tensorrtx
# (as published: 31 lines, 25 loc, 851 bytes)
from sys import prefix
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
import numpy as np
import struct
# Export every variable of the latest TensorFlow checkpoint in `model_dir`
# to a text weights file. Format, as written below: a first line holding the
# number of tensors, then one line per tensor of the form
#   "<name> <element-count>  <hex> <hex> ...\n"
# where each <hex> is a big-endian IEEE-754 float32 encoded as 8 hex digits.

model_dir = "model"  # directory searched for the checkpoint
wts_path = "psenet.wts"  # output weights file

# Precompiled big-endian float32 packer, reused for every element.
_PACK_F32_BE = struct.Struct(">f").pack


def _wts_record(name, flat):
    """Return one weights-file line for tensor *name*.

    *flat* is a 1-D iterable of numeric values. The line is
    ``"<name> <len(flat)> "`` followed by ``" <hex>"`` per element and a
    trailing newline. The double space after the count is kept on purpose so
    the output is byte-identical to earlier versions of this script.
    """
    hex_words = "".join(" " + _PACK_F32_BE(float(v)).hex() for v in flat)
    return "{} {} {}\n".format(name, len(flat), hex_words)


ckpt = tf.train.get_checkpoint_state(model_dir)
# get_checkpoint_state returns None when no checkpoint index exists; fail
# with a clear message instead of an AttributeError on the next line.
if ckpt is None or not ckpt.model_checkpoint_path:
    raise FileNotFoundError(
        "no TensorFlow checkpoint found in {!r}".format(model_dir))
ckpt_path = ckpt.model_checkpoint_path
reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
param_dict = reader.get_variable_to_shape_map()
keys = param_dict.keys()

# The context manager guarantees the file is flushed and closed even if a
# tensor fails to convert part-way through the export.
with open(wts_path, "w") as f:
    # Header line: number of tensor records that follow.
    f.write("{}\n".format(len(keys)))
    for key in keys:
        weight = reader.get_tensor(key)
        print(key, weight.shape)
        if len(weight.shape) == 4:
            # NOTE(review): presumably reorders TF conv kernels from
            # (H, W, in, out) to (out, in, H, W); confirm this is also
            # right for any non-conv2d 4-D variables (e.g. deconv kernels).
            weight = np.transpose(weight, (3, 2, 0, 1))
            print(weight.shape)
        weight = np.reshape(weight, -1)
        f.write(_wts_record(key, weight))