This repository (re)implements the code for the paper *Overparameterized ReLU Neural Networks Learn the Simplest Model: Neural Isometry and Phase Transitions* (Wang et al., 2022).
Suppose that we fit labels $\mathbf{y} \in \mathbb{R}^n$ on a data matrix $\mathbf{X} \in \mathbb{R}^{n \times d}$ with a two-layer network of one of three kinds: a plain ReLU network

$$ f^\mathrm{ReLU}(\mathbf{X};\Theta) = \sum_{i=1}^m (\mathbf{X}\mathbf{w}^{(1)}_{i})_+ \, w^{(2)}_{i}, $$

a ReLU network with a linear skip connection

$$ f^\mathrm{ReLU\text{-}skip}(\mathbf{X};\Theta) = \mathbf{X}\mathbf{w}^{(1)}_{1} w^{(2)}_{1} + \sum_{i=2}^m (\mathbf{X}\mathbf{w}^{(1)}_{i})_+ \, w^{(2)}_{i}, $$

or a ReLU network with normalization

$$ f^\mathrm{ReLU\text{-}norm}(\mathbf{X};\Theta) = \sum_{i=1}^m \operatorname{NM}_{\alpha_i}\big((\mathbf{X}\mathbf{w}^{(1)}_{i})_+\big) \, w^{(2)}_{i}, $$

where $(\mathbf{u})_+ = \max(\mathbf{u}, \mathbf{0})$ is the elementwise ReLU, $\Theta = \{(\mathbf{w}^{(1)}_{i}, w^{(2)}_{i})\}_{i=1}^m$ collects the first- and second-layer weights, and $\operatorname{NM}_{\alpha_i}(\mathbf{u}) = \alpha_i \mathbf{u} / \lVert \mathbf{u} \rVert_2$ normalizes its argument and rescales it by a trainable $\alpha_i$.

We consider the regularized training problem

$$ \min_{\Theta} \; \frac{1}{2} \left\lVert f(\mathbf{X};\Theta) - \mathbf{y} \right\rVert_2^2 + \frac{\beta}{2} \sum_{i=1}^m \left( \lVert \mathbf{w}^{(1)}_{i} \rVert_2^2 + \big(w^{(2)}_{i}\big)^2 \right). $$

When the regularization strength $\beta \to 0$, this tends to the minimal-norm interpolation problem: minimize the sum of squared weight norms subject to $f(\mathbf{X};\Theta) = \mathbf{y}$.
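For concreteness, here is a minimal PyTorch sketch of the ReLU-skip model and the regularized objective above. This is an illustration only, not the repo's `noncvx_networks.py`; the class and function names are ours.

```python
import torch
import torch.nn as nn

class ReLUSkip(nn.Module):
    """Two-layer network whose first neuron is a linear skip connection."""

    def __init__(self, d: int, m: int):
        super().__init__()
        self.W1 = nn.Parameter(torch.randn(d, m) / d**0.5)  # columns are w^(1)_i
        self.w2 = nn.Parameter(torch.randn(m) / m**0.5)     # entries are w^(2)_i

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        Z = X @ self.W1                                          # (n, m)
        A = torch.cat([Z[:, :1], torch.relu(Z[:, 1:])], dim=1)  # neuron 1 skips the ReLU
        return A @ self.w2

def regularized_loss(model: ReLUSkip, X, y, beta: float):
    """Squared loss plus weight decay on both layers, as in the objective above."""
    residual = model(X) - y
    penalty = model.W1.square().sum() + model.w2.square().sum()
    return 0.5 * residual.square().sum() + 0.5 * beta * penalty
```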
We include code to solve the convex optimization formulations of the minimal-norm problem, to train the neural networks discussed in the paper, and to plot the phase-transition graphs shown in the paper.
More details about the numerical experiments can be found in the appendix of the paper.
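To give a flavor of what such a convex formulation looks like, below is a hedged CVXPY sketch of the minimal-norm program for the ReLU-skip model, following the group-norm formulation in the paper. The function name, the random sampling of ReLU sign patterns, and the default arguments are our own choices and need not match what `training/` actually does.

```python
import cvxpy as cp
import numpy as np

def relu_skip_min_norm(X, y, n_patterns=50, seed=0):
    """Sketch of minimal-norm interpolation with a skip connection."""
    n, d = X.shape
    rng = np.random.default_rng(seed)
    # Sample ReLU sign patterns D = diag(1[X g >= 0]) from random directions g.
    patterns = np.unique((X @ rng.standard_normal((d, n_patterns))) >= 0, axis=1)
    p = patterns.shape[1]

    w0 = cp.Variable(d)       # skip-connection weights
    V = cp.Variable((d, p))   # neurons with positive second-layer weight
    U = cp.Variable((d, p))   # neurons with negative second-layer weight
    fit = X @ w0
    constraints = []
    for i in range(p):
        Di = patterns[:, i].astype(float)
        fit = fit + cp.multiply(Di, X @ (V[:, i] - U[:, i]))
        # (2 D_i - I) X v >= 0 keeps each neuron consistent with its sign pattern.
        sign = 2 * Di - 1
        constraints += [cp.multiply(sign, X @ V[:, i]) >= 0,
                        cp.multiply(sign, X @ U[:, i]) >= 0]
    constraints.append(fit == y)
    group_norms = sum(cp.norm(V[:, i]) + cp.norm(U[:, i]) for i in range(p))
    prob = cp.Problem(cp.Minimize(cp.norm(w0) + group_norms), constraints)
    prob.solve()
    return w0.value, V.value, U.value
```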
Solving the convex programs requires CVXPY (version >= 1.1.13). The MOSEK solver is preferred; to use MOSEK, you will need to register for an account and download a license. You can also switch to a different solver, as described in the CVXPY documentation.
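For instance, with any CVXPY problem `prob` you can pass the solver explicitly (the toy problem here is just a stand-in):

```python
import cvxpy as cp

x = cp.Variable(3)
prob = cp.Problem(cp.Minimize(cp.sum_squares(x - 1)))
prob.solve(solver=cp.MOSEK)  # requires a MOSEK license
# prob.solve(solver=cp.SCS)  # any other installed solver also works
```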
Training the neural networks discussed in the paper requires PyTorch (version >= 1.10.0).
| Learned model | Formulation | Equation |
|---|---|---|
| ReLU | exact | 11 (top of p. 8) |
| ReLU | approx | 211 (skip connection removed) |
| ReLU-skip | exact | 6 (top of p. 5)* |
| ReLU-skip | nonconvex | 9 (bottom of p. 7) |
| ReLU-skip | relaxed | 15 (bottom of p. 8) |
| ReLU-skip | approx | 211 (bottom of p. 57) |
| ReLU-norm | exact | 16 (top of p. 9) |
| ReLU-norm | relaxed | 17 (middle of p. 9) |
| ReLU-norm | approx | 212 (bottom of p. 9) |
\*We implement this with the norm of $\mathbf{w}_0$ added to the objective function; we suspect its omission from the equation was a typo in the paper.
| Figure | Condition | Command (after `python nic.py`) |
|---|---|---|
| 10 | NNIC-k | |
The codebase is structured as follows:
- `main.py`: the main CLI.
- `nic.py`: Neural Isometry Condition implementations.
- `plot.py`: plotting utilities, plus a CLI that plots results from a NumPy file containing an array of shape `(n, d, sample)` (see the sketch after this list).
- `training/`: utilities for solving the actual optimization problems.
  - `common.py`: helpers for formulating the convex optimization problems.
  - `cvx_base.py`: the abstract superclass for implementing the convex problems.
  - `cvx_normalized.py`: the convex formulation of a ReLU network with batch normalization.
  - `cvx_skip.py`: the convex formulation of a ReLU network with a skip connection; can also be used for a plain network with no skip connection.
  - `noncvx_networks.py`: simple neural networks implemented in PyTorch.
  - `noncvx_network_train.py`: nonconvex neural network training code with PyTorch.
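As a hedged illustration of the array layout `plot.py` expects: our reading of the shape `(n, d, sample)` is that the first two axes index the grid of problem sizes and the last indexes random trials. The file name and fill logic below are hypothetical.

```python
import numpy as np

# Hypothetical grid of recovery indicators: axis 0 ranges over the number of
# samples n, axis 1 over the dimension d, axis 2 over independent random trials.
results = np.zeros((20, 20, 10))  # shape (n, d, sample)
# ... set results[i, j, k] = 1.0 if trial k at grid point (i, j) succeeded ...
np.save("results.npy", results)

# A phase-transition plot shows the success rate averaged over trials:
success_rate = results.mean(axis=-1)  # shape (n, d)
```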
Originally implemented by:
- Yixuan Hua ([email protected])
- Yifei Wang ([email protected])
This fork is a reimplementation by:
- Alexander Cai ([email protected])
- Max Nadeau ([email protected])
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
Please make sure to update tests as appropriate.