# yellowfin.py
import math
import numpy as np
import torch
import copy
import logging
import os
import pickle as cp
# eps for numerical stability
eps = 1e-6


class YFOptimizer(object):
  def __init__(self, var_list, lr=0.0001, mu=0.0, clip_thresh=None, weight_decay=0.0,
               beta=0.999, curv_win_width=20, zero_debias=True, sparsity_debias=False, delta_mu=0.0,
               auto_clip_fac=None, force_non_inc_step=False, h_max_log_smooth=True, h_min_log_smooth=True,
               checkpoint_interval=1000, verbose=False, adapt_clip=True, stat_protect_fac=100.0,
               catastrophic_move_thresh=100.0, use_disk_checkpoint=False, checkpoint_dir='./YF_workspace'):
    '''
    clip_thresh is the threshold value on ||lr * gradient||.
    delta_mu can be a placeholder, variable, or python scalar. It is used for additional
    momentum in situations such as asynchronous-parallel training. The default is 0.0
    for basic usage of the optimizer.
    Args:
      lr: python scalar. The initial value of the learning rate; we use 1.0 in our paper.
      mu: python scalar. The initial value of momentum; we use 0.0 in our paper.
      clip_thresh: python scalar. The manually set threshold for gradient-norm clipping.
        If None, automatic clipping is carried out. The automatic clipping feature is
        parameterized by the argument auto_clip_fac; it can be switched off with
        auto_clip_fac = None.
      beta: python scalar. The smoothing parameter for the estimations.
      sparsity_debias: gradient norm and curvature are biased towards larger values when
        calculated with sparse gradients. Debiasing is useful when the model is very sparse,
        e.g. an LSTM with word embeddings. For non-sparse models such as CNNs, turning it
        off can slightly speed things up.
      delta_mu: for extensions. Not necessary for basic use.
      force_non_inc_step: in some very rare cases, it is necessary to force ||lr * gradient||
        not to increase dramatically, for stability after some iterations.
        In practice, if turned on, we enforce lr * sqrt(smoothed ||grad||^2)
        to be less than 2x the historical minimum of the smoothed ||lr * grad||.
        This feature is turned off by default.
      checkpoint_interval: interval at which to checkpoint, for potential recovery from crashes.
      stat_protect_fac: a loose adaptive hard threshold on ||grad||^2. It protects the
        statistics from being destroyed by exploding gradients.
    Other features:
      If you want to manually control the learning rate, self.lr_factor is
      an interface to the outside; it is a multiplier for the internal learning rate
      in YellowFin. It is helpful when you want to do additional hand tuning
      or apply a decaying scheme to the learning rate tuned by YellowFin.
      An example of using lr_factor can be found here:
      https://github.com/JianGoForIt/YellowFin_Pytorch/blob/master/pytorch-cifar/main.py#L109
      A minimal usage sketch is also included at the bottom of this file.
    '''
    self._lr = lr
    self._mu = mu
    self._lr_t = lr
    self._mu_t = mu
    # we convert var_list from a generator to a list so that
    # it can be iterated over multiple times
    self._var_list = list(var_list)
    self._clip_thresh = clip_thresh
    self._auto_clip_fac = auto_clip_fac
    self._beta = beta
    self._curv_win_width = curv_win_width
    self._zero_debias = zero_debias
    self._sparsity_debias = sparsity_debias
    self._force_non_inc_step = force_non_inc_step
    self._optimizer = torch.optim.SGD(self._var_list, lr=self._lr,
                                      momentum=self._mu, weight_decay=weight_decay)
    self._iter = 0
    # global states are the statistics
    self._global_state = {}
    # multiplier for manual learning rate decay, etc.
    self._lr_factor = 1.0
    # smoothing options
    self._h_max_log_smooth = h_max_log_smooth
    self._h_min_log_smooth = h_min_log_smooth
    # checkpoint interval
    self._checkpoint_interval = checkpoint_interval
    self._verbose = verbose
    if self._verbose:
      logging.debug('Verbose mode with debugging info logged.')
    # clipping of exploding gradients
    self._adapt_clip = adapt_clip
    self._exploding_grad_clip_thresh = 1e3
    self._exploding_grad_clip_target_value = 1e3
    self._stat_protect_fac = stat_protect_fac
    self._catastrophic_move_thresh = catastrophic_move_thresh
    self._exploding_grad_detected = False
    # workspace creation
    self._use_disk_checkpoint = use_disk_checkpoint
    self._checkpoint_dir = checkpoint_dir
    if use_disk_checkpoint:
      if not os.path.exists(self._checkpoint_dir):
        os.makedirs(self._checkpoint_dir)
      self._checkpoint_file = "checkpoint_pid_" + str(os.getpid())

  def state_dict(self):
    # for checkpoint saving
    sgd_state_dict = self._optimizer.state_dict()
    # for recovering the model internally in case of a numerical issue
    model_state_list = [p.data
                        for group in self._optimizer.param_groups for p in group['params']]
    global_state = self._global_state
    lr_factor = self._lr_factor
    iter = self._iter
    lr = self._lr
    mu = self._mu
    clip_thresh = self._clip_thresh
    beta = self._beta
    curv_win_width = self._curv_win_width
    zero_debias = self._zero_debias
    h_min = self._h_min
    h_max = self._h_max
    return {
      "sgd_state_dict": sgd_state_dict,
      "model_state_list": model_state_list,
      "global_state": global_state,
      "lr_factor": lr_factor,
      "iter": iter,
      "lr": lr,
      "mu": mu,
      "clip_thresh": clip_thresh,
      "beta": beta,
      "curv_win_width": curv_win_width,
      "zero_debias": zero_debias,
      "h_min": h_min,
      "h_max": h_max
    }

  def load_state_dict(self, state_dict):
    # for checkpoint loading
    self._optimizer.load_state_dict(state_dict['sgd_state_dict'])
    # recover the model internally in case a numerical issue happened
    param_id = 0
    for group in self._optimizer.param_groups:
      for p in group["params"]:
        p.data.copy_(state_dict["model_state_list"][param_id])
        param_id += 1
    self._global_state = state_dict['global_state']
    self._lr_factor = state_dict['lr_factor']
    self._iter = state_dict['iter']
    self._lr = state_dict['lr']
    self._mu = state_dict['mu']
    self._clip_thresh = state_dict['clip_thresh']
    self._beta = state_dict['beta']
    self._curv_win_width = state_dict['curv_win_width']
    self._zero_debias = state_dict['zero_debias']
    self._h_min = state_dict["h_min"]
    self._h_max = state_dict["h_max"]
    return

  def load_state_dict_perturb(self, state_dict):
    # for checkpoint loading, with a tiny perturbation of the parameters
    self._optimizer.load_state_dict(state_dict['sgd_state_dict'])
    # recover the model internally in case a numerical issue happened
    param_id = 0
    for group in self._optimizer.param_groups:
      for p in group["params"]:
        p.data.copy_(state_dict["model_state_list"][param_id])
        p.data += 1e-8
        param_id += 1
    self._global_state = state_dict['global_state']
    self._lr_factor = state_dict['lr_factor']
    self._iter = state_dict['iter']
    self._lr = state_dict['lr']
    self._mu = state_dict['mu']
    self._clip_thresh = state_dict['clip_thresh']
    self._beta = state_dict['beta']
    self._curv_win_width = state_dict['curv_win_width']
    self._zero_debias = state_dict['zero_debias']
    self._h_min = state_dict["h_min"]
    self._h_max = state_dict["h_max"]
    return

  def set_lr_factor(self, factor):
    self._lr_factor = factor
    return

  def get_lr_factor(self):
    return self._lr_factor

  def zero_grad(self):
    self._optimizer.zero_grad()
    return

  def zero_debias_factor(self):
    return 1.0 - self._beta ** (self._iter + 1)

  def zero_debias_factor_delay(self, delay):
    # for exponentially averaged stat which starts at non-zero iter
    return 1.0 - self._beta ** (self._iter - delay + 1)
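
  # Note (added commentary, not part of the original implementation): the
  # zero-debias factor above is the same bias correction used by Adam. A running
  # average a_t = beta * a_{t-1} + (1 - beta) * v with a_0 = 0 underestimates the
  # true mean early on; dividing by (1 - beta^(t+1)) removes that bias. A quick
  # worked example with beta = 0.999 and a constant signal v = 5.0:
  #   iter 0: avg = 0.001 * 5.0 = 0.005, debias factor = 1 - 0.999^1 = 0.001,
  #           debiased estimate = 0.005 / 0.001 = 5.0  (exact)
  #   iter 1: avg = 0.999 * 0.005 + 0.001 * 5.0 = 0.009995,
  #           debias factor = 1 - 0.999^2 = 0.001999,
  #           debiased estimate = 0.009995 / 0.001999 = 5.0  (exact)
  # The delayed variant simply starts the correction at the iteration where the
  # statistic was first recorded.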

  def curvature_range(self):
    global_state = self._global_state
    if self._iter == 0:
      global_state["curv_win"] = torch.FloatTensor(self._curv_win_width, 1).zero_()
    curv_win = global_state["curv_win"]
    grad_norm_squared = self._global_state["grad_norm_squared"]
    # curv_win[self._iter % self._curv_win_width] = np.log(grad_norm_squared + eps)
    curv_win[self._iter % self._curv_win_width] = grad_norm_squared
    valid_end = min(self._curv_win_width, self._iter + 1)
    # we use a running average over the log scale, accelerating
    # h_max / h_min in the beginning to follow the varying trend of the curvature.
    beta = self._beta
    if self._iter == 0:
      global_state["h_min_avg"] = 0.0
      global_state["h_max_avg"] = 0.0
      self._h_min = 0.0
      self._h_max = 0.0
    if self._h_min_log_smooth:
      global_state["h_min_avg"] = \
        global_state["h_min_avg"] * beta + (1 - beta) * torch.min(torch.log(curv_win[:valid_end] + eps))
    else:
      global_state["h_min_avg"] = \
        global_state["h_min_avg"] * beta + (1 - beta) * torch.min(curv_win[:valid_end])
    if self._h_max_log_smooth:
      global_state["h_max_avg"] = \
        global_state["h_max_avg"] * beta + (1 - beta) * torch.max(torch.log(curv_win[:valid_end] + eps))
    else:
      global_state["h_max_avg"] = \
        global_state["h_max_avg"] * beta + (1 - beta) * torch.max(curv_win[:valid_end])
    if self._zero_debias:
      debias_factor = self.zero_debias_factor()
      if self._h_min_log_smooth:
        self._h_min = np.exp(global_state["h_min_avg"] / debias_factor)
      else:
        self._h_min = global_state["h_min_avg"] / debias_factor
      if self._h_max_log_smooth:
        self._h_max = np.exp(global_state["h_max_avg"] / debias_factor)
      else:
        self._h_max = global_state["h_max_avg"] / debias_factor
    else:
      if self._h_min_log_smooth:
        self._h_min = np.exp(global_state["h_min_avg"])
      else:
        self._h_min = global_state["h_min_avg"]
      if self._h_max_log_smooth:
        self._h_max = np.exp(global_state["h_max_avg"])
      else:
        self._h_max = global_state["h_max_avg"]
    if self._sparsity_debias:
      self._h_min *= self._sparsity_avg
      self._h_max *= self._sparsity_avg
    return
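
  # Commentary (not part of the original code): the window holds the most recent
  # ||grad||^2 values, and the smoothed min/max over that window serve as the
  # curvature-range proxies (h_min, h_max) used by the tuner. Smoothing on the
  # log scale amounts to tracking a geometric rather than arithmetic mean, which
  # behaves better for quantities spanning several orders of magnitude, e.g.
  # averaging 1e-4 and 1e2 gives roughly 50 on the linear scale but
  # exp((log(1e-4) + log(1e2)) / 2) = 1e-1 on the log scale.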

  def grad_variance(self):
    global_state = self._global_state
    beta = self._beta
    self._grad_var = np.array(0.0, dtype=np.float32)
    for group_id, group in enumerate(self._optimizer.param_groups):
      for p_id, p in enumerate(group['params']):
        if p.grad is None:
          continue
        grad = p.grad.data
        state = self._optimizer.state[p]
        if self._iter == 0:
          state["grad_avg"] = torch.zeros_like(grad)
          state["grad_avg_squared"] = 0.0
        state["grad_avg"].mul_(beta).add_(grad, alpha=1 - beta)
        self._grad_var += torch.sum(state["grad_avg"] * state["grad_avg"])
    if self._zero_debias:
      debias_factor = self.zero_debias_factor()
    else:
      debias_factor = 1.0
    self._grad_var /= -(debias_factor**2)
    self._grad_var += global_state['grad_norm_squared_avg'] / debias_factor
    # in case of negative variance: the two terms use different debias factors
    self._grad_var = max(self._grad_var, eps)
    if self._sparsity_debias:
      self._grad_var *= self._sparsity_avg
    return
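
  # Commentary (not part of the original code): the variance estimate follows
  #   C ~= E[||g||^2] - ||E[g]||^2,
  # with E[||g||^2] taken from grad_norm_squared_avg and E[g] from the per-parameter
  # grad_avg buffers. The subtraction is done by first accumulating sum(grad_avg^2),
  # negating it via the division by -(debias_factor^2), and then adding the debiased
  # squared-norm average. The final clamp to eps guards against a negative estimate,
  # which can occur because the two terms use different debias factors.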

  def dist_to_opt(self):
    global_state = self._global_state
    beta = self._beta
    if self._iter == 0:
      global_state["grad_norm_avg"] = 0.0
      global_state["dist_to_opt_avg"] = 0.0
    global_state["grad_norm_avg"] = \
      global_state["grad_norm_avg"] * beta + (1 - beta) * math.sqrt(global_state["grad_norm_squared"])
    global_state["dist_to_opt_avg"] = \
      global_state["dist_to_opt_avg"] * beta \
      + (1 - beta) * global_state["grad_norm_avg"] / (global_state['grad_norm_squared_avg'] + eps)
    if self._zero_debias:
      debias_factor = self.zero_debias_factor()
      self._dist_to_opt = global_state["dist_to_opt_avg"] / debias_factor
    else:
      self._dist_to_opt = global_state["dist_to_opt_avg"]
    if self._sparsity_debias:
      self._dist_to_opt /= (np.sqrt(self._sparsity_avg) + eps)
    return
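
  # Commentary (not part of the original code): the distance-to-optimum proxy is
  # estimated as D ~= E[||g||] / E[||g||^2]: grad_norm_avg smooths ||g||,
  # grad_norm_squared_avg smooths ||g||^2, and their ratio is smoothed once more
  # before being zero-debiased. D, together with h_min and the variance C, feeds
  # the cubic solver in get_cubic_root; h_max additionally enters get_mu.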

  def grad_sparsity(self):
    global_state = self._global_state
    if self._iter == 0:
      global_state["sparsity_avg"] = 0.0
    non_zero_cnt = 0.0
    all_entry_cnt = 0.0
    for group in self._optimizer.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue
        grad = p.grad.data
        grad_non_zero = grad.nonzero()
        if grad_non_zero.dim() > 0:
          non_zero_cnt += grad_non_zero.size()[0]
        all_entry_cnt += torch.numel(grad)
    beta = self._beta
    global_state["sparsity_avg"] = beta * global_state["sparsity_avg"] \
      + (1 - beta) * non_zero_cnt / float(all_entry_cnt)
    self._sparsity_avg = \
      global_state["sparsity_avg"] / self.zero_debias_factor()
    if self._verbose:
      logging.debug("sparsity %f, sparsity avg %f", non_zero_cnt / float(all_entry_cnt), self._sparsity_avg)
    return

  def lr_grad_norm_avg(self):
    # this is for enforcing lr * grad_norm not
    # increasing dramatically in case of instability.
    # Not necessary for basic use.
    global_state = self._global_state
    beta = self._beta
    if "lr_grad_norm_avg" not in global_state:
      global_state['grad_norm_squared_avg_log'] = 0.0
    global_state['grad_norm_squared_avg_log'] = \
      global_state['grad_norm_squared_avg_log'] * beta \
      + (1 - beta) * np.log(global_state['grad_norm_squared'] + eps)
    if "lr_grad_norm_avg" not in global_state:
      global_state["lr_grad_norm_avg"] = \
        0.0 * beta + (1 - beta) * np.log(self._lr * np.sqrt(global_state['grad_norm_squared']) + eps)
      # we monitor the minimal smoothed ||lr * grad||
      global_state["lr_grad_norm_avg_min"] = \
        np.exp(global_state["lr_grad_norm_avg"] / self.zero_debias_factor())
    else:
      global_state["lr_grad_norm_avg"] = global_state["lr_grad_norm_avg"] * beta \
        + (1 - beta) * np.log(self._lr * np.sqrt(global_state['grad_norm_squared']) + eps)
      global_state["lr_grad_norm_avg_min"] = \
        min(global_state["lr_grad_norm_avg_min"],
            np.exp(global_state["lr_grad_norm_avg"] / self.zero_debias_factor()))

  def before_apply(self):
    # compute running average of gradient and norm of gradient
    beta = self._beta
    global_state = self._global_state
    if self._iter == 0:
      global_state["grad_norm_squared_avg"] = 0.0

    global_state["grad_norm_squared"] = 0.0
    for group_id, group in enumerate(self._optimizer.param_groups):
      for p_id, p in enumerate(group['params']):
        if p.grad is None:
          continue
        grad = p.grad.data
        param_grad_norm_squared = torch.sum(grad * grad)
        global_state['grad_norm_squared'] += param_grad_norm_squared
        if self._verbose:
          logging.debug("Iteration %f", self._iter)
          logging.debug("param grad squared gid %d, pid %d, %f, log scale: %f", group_id, p_id, param_grad_norm_squared,
                        np.log(param_grad_norm_squared + 1e-10) / np.log(10))

    if self._iter >= 1:
      self._exploding_grad_clip_thresh = self._h_max
      self._exploding_grad_clip_target_value = np.sqrt(self._h_max)
      if global_state['grad_norm_squared'] >= self._exploding_grad_clip_thresh:
        self._exploding_grad_detected = True
      else:
        self._exploding_grad_detected = False

    global_state['grad_norm_squared_avg'] = \
      global_state['grad_norm_squared_avg'] * beta + (1 - beta) * global_state['grad_norm_squared']
    if self._verbose:
      logging.debug("overall grad norm squared %f, log scale: %f",
                    global_state['grad_norm_squared'], np.log(global_state['grad_norm_squared'] + 1e-10) / np.log(10))

    if self._sparsity_debias:
      self.grad_sparsity()
    self.curvature_range()
    self.grad_variance()
    self.dist_to_opt()
    if self._verbose:
      logging.debug("h_max %f ", self._h_max)
      logging.debug("h_min %f ", self._h_min)
      logging.debug("dist %f ", self._dist_to_opt)
      logging.debug("var %f ", self._grad_var)

    if self._iter > 0:
      self.get_mu()
      self.get_lr()
      self._lr = beta * self._lr + (1 - beta) * self._lr_t
      self._mu = beta * self._mu + (1 - beta) * self._mu_t
      if self._verbose:
        logging.debug("lr_t %f", self._lr_t)
        logging.debug("mu_t %f", self._mu_t)
        logging.debug("lr %f", self._lr)
        logging.debug("mu %f", self._mu)
    return

  def get_lr(self):
    self._lr_t = (1.0 - math.sqrt(self._mu_t))**2 / (self._h_min + eps)
    # slow start of lr to prevent a huge lr when only a few iterations have finished
    self._lr_t = min(self._lr_t, self._lr_t * (self._iter + 1) / float(10.0 * self._curv_win_width))
    return
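
  # Commentary (not part of the original code): for momentum SGD on a 1-D
  # quadratic with curvature h, the contraction rate is sqrt(mu) whenever
  # (1 - sqrt(mu))^2 <= lr * h <= (1 + sqrt(mu))^2. Choosing
  # lr = (1 - sqrt(mu))^2 / h_min places the smallest curvature exactly at the
  # lower edge of that interval, so every h in [h_min, h_max] contracts at rate
  # sqrt(mu) once mu is at least the value computed in get_mu below.
  # The extra min() above simply ramps the learning rate in linearly over the
  # first 10 * curv_win_width iterations.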

  def get_cubic_root(self):
    # We minimize the expression x^2 D^2 + (1-x)^4 * C / h_min^2 over x,
    # where x = sqrt(mu).
    # Setting the derivative to zero and substituting x = y + 1 gives
    # y^3 + p*y = q
    # where p = (D^2 h_min^2)/(2*C) and q = -p.
    # We use Vieta's substitution to compute the root.
    # There is only one real solution y (which is in [0, 1]).
    # http://mathworld.wolfram.com/VietasSubstitution.html
    # eps in the numerator is to prevent momentum = 1 in case of zero gradient
    if np.isnan(self._dist_to_opt) or np.isnan(self._h_min) or np.isnan(self._grad_var) \
        or np.isinf(self._dist_to_opt) or np.isinf(self._h_min) or np.isinf(self._grad_var):
      logging.warning("Input to cubic solver has invalid nan/inf value!")
      raise Exception("Input to cubic solver has invalid nan/inf value!")
    p = (self._dist_to_opt + eps)**2 * (self._h_min + eps)**2 / 2 / (self._grad_var + eps)
    w3 = (-math.sqrt(p**2 + 4.0 / 27.0 * p**3) - p) / 2.0
    w = math.copysign(1.0, w3) * math.pow(math.fabs(w3), 1.0 / 3.0)
    y = w - p / 3.0 / (w + eps)
    x = y + 1
    if self._verbose:
      logging.debug("p %f, denominator %f", p, self._grad_var + eps)
      logging.debug("w3 %f ", w3)
      logging.debug("y %f, denominator %f", y, w + eps)
    if np.isnan(x) or np.isinf(x):
      logging.warning("Output from cubic is invalid nan/inf value!")
      raise Exception("Output from cubic is invalid nan/inf value!")
    return x
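
  # Worked example (added commentary, not part of the original code), using
  # hypothetical values D = 1, h_min = sqrt(2), C = 1 so that p = 1:
  #   cubic:  y^3 + y + 1 = 0
  #   w3 = (-sqrt(1 + 4/27) - 1) / 2 ~= -1.03576
  #   w  ~= -1.01178  (real cube root of w3)
  #   y  = w - p / (3 w) ~= -0.68233
  #   x  = y + 1 ~= 0.31767, and the momentum candidate x^2 ~= 0.1009.
  # Plugging y back in: (-0.68233)^3 + (-0.68233) + 1 ~= 0, as expected.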

  def get_mu(self):
    root = self.get_cubic_root()
    dr = max((self._h_max + eps) / (self._h_min + eps), 1.0 + eps)
    self._mu_t = max(root**2, ((np.sqrt(dr) - 1) / (np.sqrt(dr) + 1))**2)
    return

  def update_hyper_param(self):
    for group in self._optimizer.param_groups:
      group['momentum'] = self._mu_t
      # group['momentum'] = max(self._mu, self._mu_t)
      if not self._force_non_inc_step:
        group['lr'] = self._lr_t * self._lr_factor
        # a loose clamp to prevent catastrophically large moves. If the proposed
        # move is too large, we shrink lr so that ||lr * grad|| stays at the threshold
        # and rely mostly on the momentum to move
        if self._adapt_clip and (group['lr'] * np.sqrt(self._global_state['grad_norm_squared']) >= self._catastrophic_move_thresh):
          group['lr'] = self._catastrophic_move_thresh / np.sqrt(self._global_state['grad_norm_squared'] + eps)
          if self._verbose:
            logging.warning("clip catastrophic move!")
      elif self._iter > self._curv_win_width:
        # force lr * grad_norm not to increase dramatically.
        # Not necessary for basic use. Please refer to the comments
        # in YFOptimizer.__init__ for more details
        self.lr_grad_norm_avg()
        debias_factor = self.zero_debias_factor()
        group['lr'] = min(self._lr * self._lr_factor,
                          2.0 * self._global_state["lr_grad_norm_avg_min"]
                          / (np.sqrt(np.exp(self._global_state['grad_norm_squared_avg_log'] / debias_factor)) + eps))
    return

  def auto_clip_thresh(self):
    # Heuristic to automatically prevent sudden exploding gradients.
    # Not necessary for basic use.
    return math.sqrt(self._h_max) * self._auto_clip_fac

  def step(self):
    # add weight decay
    for group in self._optimizer.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue
        grad = p.grad.data
        if group['weight_decay'] != 0:
          grad = grad.add(p.data, alpha=group['weight_decay'])

    if self._clip_thresh is not None:
      torch.nn.utils.clip_grad_norm_(self._var_list, self._clip_thresh)
    elif self._iter != 0 and self._auto_clip_fac is not None:
      # do not clip the first iteration
      torch.nn.utils.clip_grad_norm_(self._var_list, self.auto_clip_thresh())
    # loose threshold to prevent exploding gradients from destroying the statistics
    if self._adapt_clip and (self._iter > 1):
      torch.nn.utils.clip_grad_norm_(self._var_list, np.sqrt(self._stat_protect_fac * self._h_max) + eps)

    try:
      # before apply
      self.before_apply()
      # update learning rate and momentum
      self.update_hyper_param()
      # periodically save the model and optimizer state
      if self._iter % self._checkpoint_interval == 0:
        if self._use_disk_checkpoint and os.path.exists(self._checkpoint_dir):
          checkpoint_path = self._checkpoint_dir + "/" + self._checkpoint_file
          with open(checkpoint_path, "wb") as f:
            cp.dump(self.state_dict(), f, protocol=2)
        else:
          self._state_checkpoint = copy.deepcopy(self.state_dict())
      # protection from exploding gradients
      if self._exploding_grad_detected and self._verbose:
        logging.warning("exploding gradient detected: grad norm detection thresh %f, grad norm %f, grad norm after clip %f",
                        np.sqrt(self._exploding_grad_clip_thresh),
                        np.sqrt(self._global_state['grad_norm_squared']),
                        self._exploding_grad_clip_target_value)
      if self._adapt_clip and self._exploding_grad_detected:
        # print("exploding gradient detected: grad norm detection thresh ", np.sqrt(self._exploding_grad_clip_thresh),
        #   "grad norm", np.sqrt(self._global_state['grad_norm_squared'] ),
        #   "grad norm after clip ", self._exploding_grad_clip_target_value)
        torch.nn.utils.clip_grad_norm_(self._var_list, self._exploding_grad_clip_target_value + eps)
      self._optimizer.step()
      self._iter += 1
    except Exception:
      # load the last checkpoint
      logging.warning("Numerical issue triggered restore with backup. Resuming from last checkpoint.")
      if self._use_disk_checkpoint and os.path.exists(self._checkpoint_dir):
        checkpoint_path = self._checkpoint_dir + "/" + self._checkpoint_file
        with open(checkpoint_path, "rb") as f:
          self.load_state_dict_perturb(cp.load(f))
      else:
        self.load_state_dict_perturb(copy.deepcopy(self._state_checkpoint))
    return
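

# ----------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# optimizer). It fits a small, hypothetical linear-regression model on random
# data just to show the expected call pattern: construct YFOptimizer over
# model.parameters(), then call zero_grad() / backward() / step() as with any
# torch optimizer, and optionally decay the tuned learning rate via
# set_lr_factor(). The model, data, and decay schedule below are placeholders.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
  torch.manual_seed(0)
  model = torch.nn.Linear(10, 1)   # hypothetical toy model
  x = torch.randn(128, 10)         # hypothetical random inputs
  y = torch.randn(128, 1)          # hypothetical random targets
  loss_fn = torch.nn.MSELoss()

  optimizer = YFOptimizer(model.parameters())
  num_steps = 200
  for step_id in range(num_steps):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    # optional: hand-tuned decay applied on top of the auto-tuned learning rate
    optimizer.set_lr_factor(1.0 - step_id / float(num_steps))
    optimizer.step()
    if step_id % 50 == 0:
      print("step %d, loss %.4f" % (step_id, loss.item()))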