Official repository of the paper "Receding Neuron Importances for Structured Pruning" by Mihai Suteu and Yike Guo (2022): https://arxiv.org/abs/2204.06404
Structured pruning efficiently compresses networks by identifying and removing unimportant neurons. While this can be elegantly achieved by applying sparsity-inducing regularisation on BatchNorm parameters, an L1 penalty would shrink all scaling factors rather than just those of superfluous neurons. To tackle this issue, we introduce a simple BatchNorm variation with bounded scaling parameters, based on which we design a novel regularisation term that suppresses only neurons with low importance. Under our method, the weights of unnecessary neurons effectively recede, producing a polarised bimodal distribution of importances. We show that neural networks trained this way can be pruned to a larger extent and with less deterioration. We one-shot prune VGG and ResNet architectures at different ratios on CIFAR and ImageNet datasets. In the case of VGG-style networks, our method significantly outperforms existing approaches, particularly under a severe pruning regime.
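The exact bounded parameterisation and regularisation term are defined in the paper; the snippet below is only a minimal PyTorch sketch of the idea, assuming the effective BatchNorm scale of a channel is `sigmoid(w)` of a raw parameter `w` and that the shift `b` (see `train` below) controls where the penalty fades out. The sign convention and the exact form of the penalty here are assumptions for illustration.

```python
import torch

def receding_penalty(raw_scales, b=3.0):
    # Illustrative sparsity term on sigmoid-bounded scaling factors (not
    # necessarily the paper's exact formula). With s = sigmoid(w + b), the
    # penalty gradient s * (1 - s) is already tiny for important channels
    # (the sigmoid has saturated near 1) and peaks for low-importance ones,
    # which keep receding until their effective scale sigmoid(w) is
    # essentially zero, giving a polarised, bimodal distribution.
    return torch.sigmoid(raw_scales + b).sum()

w = torch.randn(64, requires_grad=True)   # stand-in for per-channel raw scales
receding_penalty(w, b=3.0).backward()
print(w.grad)                             # largest for channels with a low sigmoid(w)
```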
- Datasets: CIFAR10, CIFAR100
- Methods: RNI (ours), UCS (baseline), L1 (Slimming)
- Architectures: VGG-16, ResNet-56
- In the hope of making things more user-friendly, all training, pruning and fine-tuning can be run from an IPython notebook using only two functions.
- Both training and pruning/fine-tuning return the best model and the training logs.
- Sample usage for the CIFAR scenarios can be found in Results.ipynb.
- Trains a new full/baseline model from scratch, which can be pruned later.
```python
model, logs = train(
    dataset="cifar10",   # "cifar10", "cifar100"
    arch="vgg16",        # "vgg16", "resnet56"
    method="rni",        # "rni", "ucs", "l1"
    reg_weight=1e-4,     # strength of sparsity regularisation, turned off for ucs
    b=3,                 # "shift" hyper-parameter for rni, ignored for l1 and ucs
    epochs=160,
    lr=0.1,              # scheduled to be divided by 10 at 50% and 75% of epochs
    batch_size=64,
    seed=0,
    device="cuda:0",
)
```
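For instance, assuming the returned model is an ordinary `torch.nn.Module`, the baseline checkpoint can be stored for later pruning runs (the file name below is arbitrary):

```python
import torch

# keep the trained baseline so it can be pruned later
torch.save(model.state_dict(), "vgg16_cifar10_rni_baseline.pt")

# ...and restore it into a freshly built model of the same architecture
# before calling prune_finetune
model.load_state_dict(torch.load("vgg16_cifar10_rni_baseline.pt"))
```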
- Prunes and fine-tunes a given model.
- The pruning strategy is dictated by the method used: local for ucs, global for l1 and rni.
```python
model, logs = prune_finetune(
    model,
    dataset="cifar10",
    method="rni",
    prune_pc=0.5,        # percentage of filters to prune (0.5 = 50%)
    epochs=160,
    lr=0.1,
    batch_size=64,
    seed=0,
    device="cuda:0",
)
```
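The repository's actual channel selection lives in the model modules (see the implementation notes below). Purely to illustrate the difference between a global and a per-layer (local) criterion, here is a hypothetical helper, not part of the repo's API, that ranks channels by the magnitude of their BatchNorm scaling factors:

```python
import torch
import torch.nn as nn

def bn_prune_masks(model, prune_pc=0.5, scope="global"):
    """Illustrative channel selection from BatchNorm scaling magnitudes.

    scope="global": one threshold across all layers (as for l1/rni).
    scope="local":  prune the same fraction within every layer (as for ucs).
    Returns {module: boolean mask of channels to keep}.
    """
    bns = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
    masks = {}
    if scope == "global":
        scores = torch.cat([m.weight.detach().abs() for m in bns])
        thresh = torch.quantile(scores, prune_pc)
        for m in bns:
            masks[m] = m.weight.detach().abs() > thresh
    else:  # local: threshold each layer separately
        for m in bns:
            s = m.weight.detach().abs()
            masks[m] = s > torch.quantile(s, prune_pc)
    return masks
```

Under rni the bounded importances end up polarised towards 0 and 1, which is what makes a single global threshold separate kept and pruned channels cleanly.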
- Written in PyTorch.
- Very simple implementation: a single underlying trainer function is used for all methods, architectures and datasets.
- The pruning logic is defined in the model modules, as it depends on the architecture.
- The Sigmoid BatchNorm implementation is based on PyTorch's BatchNorm code.
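The module below is only a simplified sketch of what a sigmoid-bounded BatchNorm could look like, assuming the effective per-channel scale is `sigmoid(weight)`; the repository's actual implementation, adapted from PyTorch's BatchNorm code, may differ in details.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SigmoidBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d whose effective per-channel scale is bounded in (0, 1).

    The raw `weight` parameter is squashed through a sigmoid, so together
    with a suitable penalty, unimportant channels can be driven to a scale
    near zero while important ones saturate near one.
    """

    def forward(self, x):
        # normalise with batch/running statistics, without the affine part
        y = F.batch_norm(
            x,
            self.running_mean,
            self.running_var,
            weight=None,
            bias=None,
            training=self.training or not self.track_running_stats,
            momentum=self.momentum,
            eps=self.eps,
        )
        # the bounded scaling factor acts as the channel importance
        scale = torch.sigmoid(self.weight).view(1, -1, 1, 1)
        shift = self.bias.view(1, -1, 1, 1)
        return y * scale + shift
```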