utils_plus.py
import torch


def batch_clip(model, max_norm):
    """
    Clip the per-sample gradients to the threshold max_norm.

    Expects param.grad_sample to hold per-sample gradients of shape
    (batch_size, *param.shape). Each parameter's per-sample gradient is
    clipped independently to L2 norm at most max_norm, and the clipped
    per-sample gradients are then summed over the batch back into
    param.grad_sample.
    """
    grads = [param.grad_sample for param in model.parameters()]
    batch_size = grads[0].size(0)
    # Per-sample, per-parameter L2 norms, shape (batch_size, num_params).
    grad_norms = []
    for grad_p in grads:
        grad_p_flat = grad_p.view(batch_size, -1)
        grad_norms.append(torch.norm(grad_p_flat, dim=1))
    grad_norms = torch.stack(grad_norms, dim=1)
    # Scale factor min(1, max_norm / norm): gradients whose norm exceeds the
    # threshold are shrunk to norm max_norm, the rest are left unchanged.
    ones = torch.ones(size=grad_norms.size(), device=grad_norms.device)
    scale_factors = torch.reciprocal(torch.maximum(grad_norms / max_norm, ones))
    for k, param in enumerate(model.parameters()):
        # Scale each per-sample gradient and sum over the batch dimension.
        param.grad_sample = torch.einsum("i...,i->...", grads[k], scale_factors[:, k])
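
# Illustrative example (hypothetical numbers, not from the code above): with
# max_norm = 1.0 and per-sample norms [0.5, 2.0, 4.0] for one parameter, the
# scale factors are 1 / max(norm / max_norm, 1) = [1.0, 0.5, 0.25], so only
# the two per-sample gradients whose norm exceeds the threshold are shrunk
# before they are summed over the batch.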


def batch_noising_scale(model, clip, noise_multiplier, batch_size):
    """
    Add to the (clipped, summed) gradient of each parameter Gaussian noise
    with standard deviation clip * noise_multiplier, i.e. covariance
    (clip * noise_multiplier)**2 * Identity, then divide by batch_size and
    store the averaged noisy gradient in param.grad.
    """
    for param in model.parameters():
        noise = torch.normal(0, clip * noise_multiplier, param.grad_sample.shape,
                             device=param.grad_sample.device)
        param.grad = (param.grad_sample + noise) / batch_size
        del param.grad_sample


def topk_compress(model, percentile):
    """
    Compress the gradient of each model parameter, keeping only the fraction
    percentile of components with the largest magnitude and zeroing the rest.
    """
    for parameter in model.parameters():
        grad_p_flat = parameter.grad.flatten()
        k = int(len(grad_p_flat) * percentile)
        topk_vals, topk_inds = torch.topk(input=torch.abs(grad_p_flat), k=k)
        # Build a 0/1 mask on the same device and dtype as the gradient.
        mask = torch.zeros_like(grad_p_flat)
        mask[topk_inds] = 1
        parameter.grad = torch.multiply(mask, grad_p_flat).reshape(parameter.grad.shape)
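
# Illustrative example (hypothetical numbers): with percentile = 0.25 and a
# flattened gradient [0.1, -0.4, 0.05, 0.3], k = int(4 * 0.25) = 1, so only
# the largest-magnitude entry is kept and the result is [0.0, -0.4, 0.0, 0.0].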


def topk_mask_single(grad_p, percentile):
    """
    Return a 0/1 mask over grad_p keeping the fraction percentile of
    components with the largest magnitude.
    """
    grad_p_flat = grad_p.flatten()
    k = int(len(grad_p_flat) * percentile)
    topk_vals, topk_inds = torch.topk(input=torch.abs(grad_p_flat), k=k)
    mask = torch.zeros_like(grad_p_flat)
    mask[topk_inds] = 1
    return mask.reshape(grad_p.shape)


def topk_mask_all(grads, percentile):
    """Return one top-k mask (via topk_mask_single) per gradient tensor in grads."""
    masks = []
    for grad_p in grads:
        masks.append(topk_mask_single(grad_p, percentile))
    return masks


def apply_external_mask(model, ext_masks):
    """Multiply each parameter's gradient elementwise by its corresponding mask."""
    for parameter, mask in zip(model.parameters(), ext_masks):
        parameter.grad = torch.multiply(mask, parameter.grad)
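

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original utilities): a minimal
# DP-SGD-style step built from the helpers above. batch_clip and
# batch_noising_scale expect param.grad_sample to be populated with
# per-sample gradients (e.g. by a library such as Opacus); here a naive
# per-example backward loop fills them in so the snippet stays
# self-contained. toy_model and the hyperparameter values are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    toy_model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
    loss_fn = torch.nn.MSELoss()

    x = torch.randn(8, 4)
    y = torch.randn(8, 2)
    batch_size = x.size(0)

    # Fill param.grad_sample with per-example gradients (one backward pass
    # per sample; slow, for illustration only).
    for param in toy_model.parameters():
        param.grad_sample = torch.zeros(batch_size, *param.shape)
    for i in range(batch_size):
        toy_model.zero_grad()
        loss_fn(toy_model(x[i:i + 1]), y[i:i + 1]).backward()
        for param in toy_model.parameters():
            param.grad_sample[i] = param.grad.detach().clone()

    # Clip per-sample gradients, add Gaussian noise and average, then sparsify.
    batch_clip(toy_model, max_norm=1.0)
    batch_noising_scale(toy_model, clip=1.0, noise_multiplier=1.1, batch_size=batch_size)
    topk_compress(toy_model, percentile=0.5)

    optimizer.step()
    print("weight gradient after clip + noise + top-k:", toy_model.weight.grad)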