PyTorch linear over-parameterization layers with automatic graph reduction.
Official codebase used in:
The Low-Rank Simplicity Bias in Deep Networks
Minyoung Huh, Hossein Mobahi, Richard Zhang, Brian Cheung, Pulkit Agrawal, Phillip Isola
MIT CSAIL, Google Research, Adobe Research, MIT BCS
TMLR 2023 (arXiv 2021).
Developed on
- Python 3.7 🐍
- PyTorch 1.7 🔥
> git clone https://github.com/minyoungg/overparam
> cd overparam
> pip install .
The layers work exactly like any `torch.nn` layer.
import torch
from overparam import OverparamLinear
layer = OverparamLinear(16, 32, width=1, depth=2)
x = torch.randn(1, 16)
y = layer(x)  # output shape (1, 32), as with nn.Linear(16, 32)
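As a rough illustration of why such a layer can later be collapsed, here is a minimal NumPy sketch. It assumes that a depth-2 linear over-parameterization means two affine maps applied back to back with no nonlinearity in between; the hidden size `hidden = 16` is a hypothetical choice and is not claimed to be how the library interprets `width`.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features, hidden = 16, 32, 16  # `hidden` is a hypothetical choice

# Two stacked affine maps with no activation in between: 16 -> hidden -> 32
W1, b1 = rng.standard_normal((hidden, in_features)), rng.standard_normal(hidden)
W2, b2 = rng.standard_normal((out_features, hidden)), rng.standard_normal(out_features)

x = rng.standard_normal((1, in_features))  # a batch of one, like torch.randn(1, 16)

# Expanded form: apply each layer in turn (row-vector convention, as in nn.Linear)
y_expanded = (x @ W1.T + b1) @ W2.T + b2

# Collapsed form: one affine map with the effective weight and bias
W_eff = W2 @ W1        # shape (32, 16), the same shape as nn.Linear(16, 32).weight
b_eff = W2 @ b1 + b2   # shape (32,)
y_collapsed = x @ W_eff.T + b_eff

print(W_eff.shape, np.allclose(y_expanded, y_collapsed))
# prints: (32, 16) True
```

The extra parameters change how the weights are trained, but the function the stack computes is still a single affine map.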
from overparam import OverparamConv2d
import numpy as np
We can construct three Conv2d layers with kernel dimensions of 5x5, 3x3, and 1x1:
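kernel_sizes = [5, 3, 1]
# Same padding for the effective kernel (5 + 3 + 1 - 3 + 1 = 7, so padding = 3)
padding = max((np.sum(kernel_sizes) - len(kernel_sizes) + 1) // 2, 0)
layer = OverparamConv2d(2, 4, kernel_sizes=kernel_sizes, padding=padding,
                        depth=len(kernel_sizes))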
# Get the effective kernel size
print(layer.kernel_size)
When `kernel_sizes` is an integer, all layers after the first are assumed to have a kernel size of 1x1.
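To see where the effective kernel size and the padding formula come from, here is a minimal NumPy sketch. It assumes single-channel kernels, stride 1, no padding, and no bias, and the helper names (`effective_kernel_size`, `same_padding`, `full_conv2d`, `xcorr2d_valid`) are hypothetical, not part of `overparam`. Stacking cross-correlations (what `Conv2d` computes) with kernel sizes k_1, ..., k_n is equivalent to a single cross-correlation whose kernel is the full convolution of all the kernels, so its size is sum(k_i) - n + 1.

```python
import numpy as np


def effective_kernel_size(kernel_sizes):
    """Kernel size of the single conv equivalent to stacked stride-1 convs."""
    return int(np.sum(kernel_sizes)) - len(kernel_sizes) + 1


def same_padding(kernel_sizes):
    """Same formula as the README snippet: half the effective kernel size."""
    return max(effective_kernel_size(kernel_sizes) // 2, 0)


def full_conv2d(a, b):
    """Full 2D convolution of two kernels; output shape (ha+hb-1, wa+wb-1)."""
    out = np.zeros((a.shape[0] + b.shape[0] - 1, a.shape[1] + b.shape[1] - 1))
    for i in range(b.shape[0]):
        for j in range(b.shape[1]):
            out[i:i + a.shape[0], j:j + a.shape[1]] += b[i, j] * a
    return out


def xcorr2d_valid(x, k):
    """'Valid' cross-correlation, i.e. a stride-1, unpadded, bias-free Conv2d."""
    kh, kw = k.shape
    out = np.empty((x.shape[0] - kh + 1, x.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * k)
    return out


rng = np.random.default_rng(0)
kernel_sizes = [5, 3, 1]
kernels = [rng.standard_normal((k, k)) for k in kernel_sizes]
x = rng.standard_normal((16, 16))

# Expanded form: apply the three convs one after another (16 -> 12 -> 10 -> 10)
y_expanded = x
for k in kernels:
    y_expanded = xcorr2d_valid(y_expanded, k)

# Collapsed form: fold the kernels into one 7x7 kernel and apply it once
k_eff = kernels[0]
for k in kernels[1:]:
    k_eff = full_conv2d(k_eff, k)
y_collapsed = xcorr2d_valid(x, k_eff)

print(k_eff.shape, effective_kernel_size(kernel_sizes), same_padding(kernel_sizes))
# prints: (7, 7) 7 3
print(np.allclose(y_expanded, y_collapsed))
# prints: True

# An integer kernel size k behaves like [k, 1, ..., 1]: the 1x1 layers add no extent
print(effective_kernel_size([3, 1, 1]))
# prints: 3
```

With multiple channels the same idea applies per input/output channel pair, summing over the intermediate channels.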
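# Forward pass (expanded form)
# `x` must match the layer's input: e.g. torch.randn(1, 16) for OverparamLinear(16, 32),
# or torch.randn(1, 2, 16, 16) for OverparamConv2d(2, 4, ...)
layer.train()
y = layer(x)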
When `eval()` is called, the model automatically reduces the computation graph to its effective single-layer counterpart, and the forward pass in `eval` mode uses the effective weights instead.
# Forward pass (collapsed form) [automatic]
layer.eval()
y = layer(x)
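The train/eval switch can be mimicked with a short NumPy sketch. `ToyOverparamLinear` is a hypothetical class, not the `overparam` implementation: it assumes a stack of affine maps with no nonlinearity, runs them one by one in training mode, and in eval mode computes the collapsed weight and bias once and reuses them.

```python
import numpy as np


class ToyOverparamLinear:
    """Hypothetical stand-in: a stack of affine maps that collapses in eval mode."""

    def __init__(self, in_features, out_features, hidden, depth=2, seed=0):
        rng = np.random.default_rng(seed)
        dims = [in_features] + [hidden] * (depth - 1) + [out_features]
        # One (weight, bias) pair per layer; weights use nn.Linear's (out, in) shape
        self.layers = [
            (rng.standard_normal((d_out, d_in)) / np.sqrt(d_in),
             0.1 * rng.standard_normal(d_out))
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ]
        self.training = True
        self._collapsed = None

    def train(self):
        self.training = True
        self._collapsed = None  # weights may be updated while training; drop the cache
        return self

    def eval(self):
        self.training = False
        self._collapsed = self.effective_weights()  # collapse once, reuse afterwards
        return self

    def effective_weights(self):
        W_eff, b_eff = self.layers[0]
        for W, b in self.layers[1:]:
            W_eff, b_eff = W @ W_eff, W @ b_eff + b
        return W_eff, b_eff

    def __call__(self, x):
        if self.training:
            for W, b in self.layers:  # expanded form: one matmul per layer
                x = x @ W.T + b
            return x
        W_eff, b_eff = self._collapsed  # collapsed form: a single matmul
        return x @ W_eff.T + b_eff


layer = ToyOverparamLinear(16, 32, hidden=64, depth=3)
x = np.random.default_rng(1).standard_normal((4, 16))
y_train = layer.train()(x)
y_eval = layer.eval()(x)
print(np.allclose(y_train, y_eval))
# prints: True
```

Both modes compute the same function; eval mode simply replaces `depth` matrix multiplications with one.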
You can access the effective weights as follows:
print(layer.weight)
print(layer.bias)
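For a stack of d affine layers x -> W_i x + b_i with no nonlinearity in between (the assumption behind the collapsed form; the symbols W_i, b_i are introduced here for illustration), the effective weight and bias unroll as follows, where the empty product for i = d is the identity, so that y = W_eff x + b_eff:

```latex
\begin{aligned}
W_{\text{eff}} &= W_d W_{d-1} \cdots W_1, \\
b_{\text{eff}} &= \sum_{i=1}^{d} \left( W_d W_{d-1} \cdots W_{i+1} \right) b_i .
\end{aligned}
```

For d = 2 this reduces to W_eff = W_2 W_1 and b_eff = W_2 b_1 + b_2.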
To over-parameterize an existing model automatically:
import torchvision.models as models
from overparam.utils import overparameterize
model = models.alexnet() # Replace this with YOUR_PYTORCH_MODEL()
model = overparameterize(model, depth=2)
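One way to implement a conversion like this is to walk the module tree recursively and swap layers in place. The pure-Python sketch below shows that pattern with hypothetical stand-in classes (`Module`, `Linear`, `OverparamLinearSpec`) and a hypothetical `overparameterize_sketch` function; it is not how `overparameterize` is implemented. In PyTorch, the same walk is usually written with `nn.Module.named_children()` and `setattr`.

```python
class Module:
    """Hypothetical container with named children (stand-in for nn.Module)."""

    def __init__(self, **children):
        self.children = dict(children)


class Linear:
    """Hypothetical leaf layer (stand-in for nn.Linear)."""

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features


class OverparamLinearSpec:
    """Hypothetical replacement that keeps the original shape plus a depth."""

    def __init__(self, linear, depth):
        self.in_features = linear.in_features
        self.out_features = linear.out_features
        self.depth = depth


def overparameterize_sketch(module, depth):
    """Recursively replace every Linear leaf; containers are walked, not replaced."""
    for name, child in module.children.items():
        if isinstance(child, Linear):
            # Reassigning an existing key is safe while iterating over the dict
            module.children[name] = OverparamLinearSpec(child, depth)
        elif isinstance(child, Module):
            overparameterize_sketch(child, depth)
    return module


model = Module(features=Module(fc1=Linear(16, 32)), classifier=Linear(32, 10))
model = overparameterize_sketch(model, depth=2)

clf = model.children["classifier"]
print(type(clf).__name__, clf.in_features, clf.out_features, clf.depth)
# prints: OverparamLinearSpec 32 10 2
```

Keeping each replacement's input and output shapes identical to the original layer is what lets the rest of the model run unchanged.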
We also provide support for batch-norm and linear residual connections.
- batch-normalization (a pseudo-linear layer: it is linear in eval mode)
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2,
                        batch_norm=True)
- residual connection
# every 2 layers, a residual connection is added
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2,
                        residual=True, residual_intervals=2)
- multiple residual connections
# a residual connection is added every 1, every 2, and every 3 layers
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2,
                        residual=True, residual_intervals=[1, 2, 3])
- batch-norm and residual connection
# mimics `BasicBlock` in ResNets
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2,
                        batch_norm=True, residual=True, residual_intervals=2)
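Why these options keep a block collapsible can be checked numerically: in eval mode batch norm uses fixed running statistics, so it is a per-channel affine map that folds into the preceding linear layer, and a linear residual connection adds the identity to the collapsed weight. The NumPy sketch below is a dense (non-convolutional) analogue of the `BasicBlock`-style configuration. It assumes linear -> BN -> linear -> BN with one skip around both layers (an assumed reading of `residual_intervals=2`), random statistics, and a hypothetical helper `bn_eval_affine`.

```python
import numpy as np

rng = np.random.default_rng(0)
n, eps = 8, 1e-5  # feature count and a BN-style epsilon (illustrative values)


def bn_eval_affine(gamma, beta, mean, var):
    """Eval-mode batch norm written as an affine map: y = S @ x + t."""
    s = gamma / np.sqrt(var + eps)
    return np.diag(s), beta - s * mean


# Two bias-free linear layers, each followed by an eval-mode batch norm
layers = []
for _ in range(2):
    W = rng.standard_normal((n, n))
    gamma, beta = rng.standard_normal(n), rng.standard_normal(n)
    mean, var = rng.standard_normal(n), rng.uniform(0.5, 2.0, n)
    layers.append((W, gamma, beta, mean, var))

x = rng.standard_normal(n)

# Expanded form: linear -> BN -> linear -> BN, then add the identity skip
h = x
for W, gamma, beta, mean, var in layers:
    h = W @ h
    h = gamma * (h - mean) / np.sqrt(var + eps) + beta
y_expanded = h + x

# Collapsed form: fold each BN into its linear layer, multiply, then add I
W_eff, b_eff = np.eye(n), np.zeros(n)
for W, gamma, beta, mean, var in layers:
    S, t = bn_eval_affine(gamma, beta, mean, var)
    W_eff, b_eff = S @ W @ W_eff, S @ W @ b_eff + t
W_eff = W_eff + np.eye(n)  # the residual connection contributes the identity
y_collapsed = W_eff @ x + b_eff

print(np.allclose(y_expanded, y_collapsed))
# prints: True
```

In training mode batch norm normalizes with per-batch statistics, so this folding only holds once the statistics are frozen, which matches the "linear in eval mode" description above.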
@article{huh2023simplicitybias,
title={The Low-Rank Simplicity Bias in Deep Networks},
author={Minyoung Huh and Hossein Mobahi and Richard Zhang and Brian Cheung and Pulkit Agrawal and Phillip Isola},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2023},
url={https://openreview.net/forum?id=bCiNWDmlY2},
}