-
Notifications
You must be signed in to change notification settings - Fork 107
/
main.py
105 lines (77 loc) · 2.83 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# -*- coding: utf-8 -*-
import argparse
import numpy as np
from pprint import pprint
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import torchvision
from torchvision import models, datasets, transforms
print(torch.__version__, torchvision.__version__)
from utils import label_to_onehot, cross_entropy_for_onehot
# Command-line interface: pick the CIFAR sample to attack by index, or point
# at a custom image file that replaces the sample's pixels.
parser = argparse.ArgumentParser(description='Deep Leakage from Gradients.')
# The default is an int literal so it agrees with ``type=int``; argparse would
# have converted the former string default "25" to the same value.
parser.add_argument('--index', type=int, default=25,
                    help='the index for leaking images on CIFAR.')
# An empty string means "no custom image supplied".
parser.add_argument('--image', type=str, default="",
                    help='the path to customized image.')
args = parser.parse_args()
# Use the GPU whenever PyTorch reports that CUDA is available, else the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on %s" % device)
# Ground-truth sample whose gradients will be "leaked".
dst = datasets.CIFAR100("~/.torch", download=True)
tp = transforms.ToTensor()    # PIL image -> tensor
tt = transforms.ToPILImage()  # tensor -> PIL image (used for plotting)

img_index = args.index
# Index the dataset once and reuse both halves of the (image, label) pair.
cifar_image, cifar_label = dst[img_index]

if args.image:
    # A custom image replaces the CIFAR pixels. Any non-empty path counts
    # (the previous ``len(...) > 1`` test ignored one-character paths).
    # The context manager closes the file once the pixels have been read.
    with Image.open(args.image) as custom_image:
        # NOTE(review): converted to RGB so grayscale/RGBA files get the same
        # channel count as the CIFAR samples fed to the same network --
        # confirm that LeNet expects three input channels.
        gt_data = tp(custom_image.convert("RGB")).to(device)
else:
    gt_data = tp(cifar_image).to(device)

# Add a leading batch dimension of size 1.
gt_data = gt_data.view(1, *gt_data.size())

# The label always comes from the CIFAR sample at ``img_index``, even when a
# custom image is supplied.
gt_label = torch.tensor([cifar_label], dtype=torch.long, device=device)
gt_label = gt_label.view(1, )
gt_onehot_label = label_to_onehot(gt_label)

plt.imshow(tt(gt_data[0].cpu()))
from models.vision import LeNet, weights_init

# Loss applied to both the real and the dummy sample.
criterion = cross_entropy_for_onehot

# Build the network, then seed the global RNG before ``weights_init`` is
# applied -- presumably so the re-initialised weights are reproducible
# (confirm against models.vision).
net = LeNet().to(device)
torch.manual_seed(1234)
net.apply(weights_init)

# Reference gradients: the loss on the ground-truth sample differentiated
# with respect to every network parameter. Detached copies are kept so the
# attack compares against fixed targets.
pred = net(gt_data)
y = criterion(pred, gt_onehot_label)
dy_dx = grad(y, net.parameters())
original_dy_dx = [g.detach().clone() for g in dy_dx]
# Attack variables: a random image and random label logits with the same
# sizes as the ground-truth tensors. Both are drawn on the CPU (image first,
# then label) and moved to ``device`` before being marked as trainable.
dummy_data = torch.randn(gt_data.shape).to(device)
dummy_data.requires_grad_()
dummy_label = torch.randn(gt_onehot_label.shape).to(device)
dummy_label.requires_grad_()

# Show the random starting image.
plt.imshow(tt(dummy_data[0].cpu()))

# L-BFGS updates the dummy image and label directly.
optimizer = torch.optim.LBFGS([dummy_data, dummy_label])

# Snapshots of the dummy image, appended by the optimisation loop.
history = []
def closure():
    """Evaluate the gradient-matching objective and backpropagate it.

    The dummy image is passed through ``net``; the dummy label logits are
    turned into a distribution with softmax and used as the target of
    ``criterion``. The resulting loss is differentiated with respect to the
    network parameters, and the summed squared difference between those
    gradients and ``original_dy_dx`` is the objective.

    Returns:
        The scalar objective tensor, after ``backward()`` has been called
        on it.
    """
    optimizer.zero_grad()

    dummy_pred = net(dummy_data)
    dummy_onehot_label = F.softmax(dummy_label, dim=-1)
    dummy_loss = criterion(dummy_pred, dummy_onehot_label)

    # create_graph=True keeps the gradient computation differentiable so the
    # objective below can itself be backpropagated.
    dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)

    grad_diff = sum(
        ((gx - gy) ** 2).sum()
        for gx, gy in zip(dummy_dy_dx, original_dy_dx)
    )
    grad_diff.backward()
    return grad_diff


for iters in range(300):
    # The optimizer is handed the closure so it can re-evaluate the objective.
    optimizer.step(closure)

    if iters % 10:
        continue

    # Every 10th iteration: re-evaluate the objective for logging and keep a
    # snapshot of the current dummy image.
    current_loss = closure()
    print(iters, "%.4f" % current_loss.item())
    history.append(tt(dummy_data[0].cpu()))
# Show every recorded snapshot on a grid that is 10 images wide.
# The row count follows the number of snapshots actually stored, instead of
# a hard-coded 30 that had to be kept in step with the optimisation loop
# (too few snapshots raised IndexError). With 30 snapshots this is the same
# 3 x 10 layout as before.
grid_cols = 10
grid_rows = max(1, -(-len(history) // grid_cols))  # ceiling division

plt.figure(figsize=(12, 8))
for i, snapshot in enumerate(history):
    plt.subplot(grid_rows, grid_cols, i + 1)
    plt.imshow(snapshot)
    # A snapshot is stored every 10th iteration, so entry i is iteration i * 10.
    plt.title("iter=%d" % (i * 10))
    plt.axis('off')

plt.show()