-
Notifications
You must be signed in to change notification settings - Fork 3
/
config.py
24 lines (22 loc) · 888 Bytes
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
class PruneConfig:
    """Base container for pruning settings shared by concrete pruner configs.

    Every setting defaults to the value this class has always used, so
    ``PruneConfig()`` produces exactly the same attributes as before. The
    keyword-only arguments let callers override a default (for example
    ``device='cpu'`` on a machine without CUDA) at construction time instead
    of mutating the instance afterwards.

    Args:
        n_points_per_layer: Stored as ``self.n_points_per_layer``.
            NOTE(review): its exact meaning is defined by the pruner that
            reads it; confirm there.
        prunable_layer_types: Module classes considered prunable. Defaults to
            ``[Conv2d, Linear]``. A fresh list is always stored, so the
            caller's sequence is never aliased and instances never share one.
        calib_batch: Stored as ``self.calib_batch``. Presumably the number of
            calibration batches (or a batch size); confirm with the pruner.
        device: Device string stored as ``self.device`` (default ``'cuda'``).
        policy: Stored as ``self.policy``; ``None`` by default.
        fmap_save: Stored as ``self.fmap_save``. Presumably toggles saving
            feature maps; confirm with the consumer.
        fmap_save_path: Stored as ``self.fmap_save_path`` (default ``'./'``).
    """

    def __init__(
        self,
        *,
        n_points_per_layer=1,
        prunable_layer_types=None,
        calib_batch=50,
        device='cuda',
        policy=None,
        fmap_save=True,
        fmap_save_path='./',
    ):
        # A None sentinel, not a list literal default, so that each instance
        # gets its own list (a list default would be shared across calls).
        if prunable_layer_types is None:
            prunable_layer_types = [
                torch.nn.modules.conv.Conv2d,
                torch.nn.modules.linear.Linear,
            ]
        self.n_points_per_layer = n_points_per_layer
        # Copy so later mutation of the caller's list cannot alter this config.
        self.prunable_layer_types = list(prunable_layer_types)
        self.calib_batch = calib_batch
        self.device = device
        self.policy = policy
        self.fmap_save = fmap_save
        self.fmap_save_path = fmap_save_path
class HSICLassoPruneConfig(PruneConfig):
    """Pruning configuration for the HSIC-Lasso pruner (per the class name).

    Starts from the defaults set by ``PruneConfig.__init__`` and then stores
    every constructor argument, unchanged, as an attribute of the same name.
    The ``policy`` argument overrides the base class's ``policy = None``.
    """

    def __init__(self, name, model, ckpt, train_dataloader, pruner="lasso",
                 val_dataloader=None, criterion=None, policy=None,
                 fmap_path=None):
        # The base initializer must run first: it sets ``policy`` to None,
        # which the caller-supplied value below has to overwrite.
        super().__init__()

        # What is being pruned.
        self.name, self.model, self.ckpt = name, model, ckpt

        # Data loaders and criterion handed to the pruner.
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.criterion = criterion

        # Pruner selection and options.
        self.pruner = pruner
        self.policy = policy
        # NOTE(review): this is a separate attribute from the inherited
        # ``fmap_save_path``; confirm which one the consumers actually read.
        self.fmap_path = fmap_path