These are the results from forking tysam-code/hlb-CIFAR10. See snimu/hlb-CIFAR10 for the fork.
Below are, first, the losses and accuracies plotted for the model, and then a graph of the model itself (generated using `torchview`).
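For reference, such a graph can be produced with `torchview`'s `draw_graph`. This is a minimal sketch, not the exact call used in the repo; the model below is a placeholder stand-in for the actual hlb-CIFAR10 network.

```python
import torch
from torchview import draw_graph

# Placeholder model; the real network lives in the hlb-CIFAR10 fork.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 12, kernel_size=3),
    torch.nn.Flatten(),
)

# Trace the model on a dummy CIFAR10-sized input and render the graph.
graph = draw_graph(model, input_size=(1, 3, 32, 32), device="cpu")
graph.visual_graph.render("model_graph", format="png")  # writes model_graph.png
```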
I have results for two algorithms:
- `PermutationCoordinateDescent`
- `MergeMany`
I'll test `PermutationCoordinateDescent` over two parameters:
- Filter-size: According to the paper, larger filter-sizes work better. I'll test filter-sizes of 3x3, 6x6, 9x9, 12x12, 15x15, and 18x18.
- L2-Regularizer: This isn't part of the paper, but it makes sense to me that an algorithm that works by permuting weights to be as similar as possible to those of another model would benefit from all weights being in the same value-range. I'll use the `weight_decay` parameter of `torch.optim.Adam` to regularize the weights; values will range from `0.0` to `0.9` (see the sketch after this list).
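As a concrete illustration of the regularization setup, here is a minimal sketch assuming Adam is used for training; the layer, learning rate, and exact grid of `weight_decay` values are placeholders, not the repo's actual settings.

```python
import torch

# Placeholder layer and learning rate; only the weight_decay sweep matters here.
model = torch.nn.Conv2d(3, 12, kernel_size=3)

for weight_decay in (0.0, 0.1, 0.3, 0.5, 0.7, 0.9):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=weight_decay)
    # ... train model_a and model_b with this optimizer, then apply
    # PermutationCoordinateDescent and interpolate between them ...
```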
This is at a fixed `weight_decay > 0.0`.
A few things immediately jump out to me from the plots above:
- The method works somewhat: interpolation between `model_a` and `model_b (rebasin)` is much better than between `model_a` and `model_b (original)`. (A minimal sketch of what interpolation means here is shown after this list.)
- Applying the method to `model_b` and then interpolating between `model_b (original)` and `model_b (rebasin)` yields better results than interpolating between `model_a` and `model_b (original)`.
- The git re-basin method works very well for the accuracy of the model! At least for this model, interpolation between `model_a` and `model_b (rebasin)` leads to almost flat accuracies.
- Larger filter-sizes are said to work better in the paper, but it is unclear to me whether that is actually the case here. Let's look at that in more detail below.
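To make precise what "interpolating between two models" means above, here is a minimal sketch of linear interpolation over matching parameters. It is an illustration only, not the exact code used to produce these plots.

```python
import copy
import torch

def interpolate(model_a: torch.nn.Module, model_b: torch.nn.Module, alpha: float) -> torch.nn.Module:
    """Return a model whose floating-point parameters are (1 - alpha) * A + alpha * B.

    Minimal sketch; the actual evaluation code lives in the linked repositories.
    Non-float buffers (e.g. BatchNorm's num_batches_tracked) are taken from model_a.
    """
    sd_a, sd_b = model_a.state_dict(), model_b.state_dict()
    sd_mix = {
        key: (1 - alpha) * value + alpha * sd_b[key] if value.is_floating_point() else value
        for key, value in sd_a.items()
    }
    mixed = copy.deepcopy(model_a)
    mixed.load_state_dict(sd_mix)
    return mixed
```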
Below, I plot the losses and accuracies when interpolating between `model_a` and `model_b (rebasin)`. I do so for all filter-sizes.
It is not clear to me that larger filter-sizes lead to better results. However, larger filter-sizes do degrade model performance in general (the larger the filter-size, the worse the performance of the first and last model, which are the original `model_a` and `model_b (rebasin)`). So, to compare how rebasin affects the behavior of interpolation between models, I plot the losses and accuracies again below, but this time I shift all start-points (i.e. `model_a`) to the results of the 3x3-filter; a sketch of that shift follows.
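Roughly, the shift looks like this; `curves` is a hypothetical mapping from filter-size to the measured loss or accuracy values, not a structure from the repo.

```python
import numpy as np

def shift_to_3x3_start(curves: dict) -> dict:
    """Shift every curve so its first point (model_a) matches the 3x3 curve's first point.

    Illustrative sketch only; `curves` maps e.g. "3x3", "6x6", ... to np.ndarrays.
    """
    reference = curves["3x3"][0]
    return {name: values + (reference - values[0]) for name, values in curves.items()}
```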
The behavior seems noisy, though slightly better for larger filter-sizes.
This is at a fixed `filter_size == 28*28 == 784`.
These results are very weird. As expected, increasing the L2-Regularizer improves the results of `PermutationCoordinateDescent` significantly; in fact, the algorithm only really works for fairly high `weight_decay` values, where it doesn't fully remove the loss-barrier but lowers it by a lot, and almost fully removes the accuracy-barrier (this stands in opposition to the results of the paper, where `PermutationCoordinateDescent` removed the loss-barrier but not the accuracy-barrier).
However, I cannot explain the behavior of the interpolation between `model_b (original)` and `model_b (rebasin)`. At low `weight_decay` values, it is extremely good, much better than that between `model_a` and `model_b (rebasin)`.
The first possible explanation might be that `model_b (original) === model_b (rebasin)`. In other words, `model_b (original)` was already optimally aligned with `model_a`, such that it didn't get permuted. However, this is clearly not the case, as the following facts about the interpolation at `weight_decay == 0.0` show:
- The accuracies between `model_b (original)` and `model_b (rebasin)` change from step to step, which means that the interpolated models are different from either of the two models, which in turn means that the two models are different.
- The interpolation between `model_a` and `model_b (original)` is different from that between `model_a` and `model_b (rebasin)`, which means that `model_b (original) != model_b (rebasin)`. (A sketch of a direct parameter comparison is shown after this list.)
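One could also check the second point directly by comparing parameters; here is a minimal sketch of such a check, not part of the original experiments.

```python
import torch

def models_are_identical(model_x: torch.nn.Module, model_y: torch.nn.Module, atol: float = 1e-8) -> bool:
    """Return True if all named parameters of the two models match (illustrative sketch)."""
    params_x = dict(model_x.named_parameters())
    params_y = dict(model_y.named_parameters())
    return all(torch.allclose(params_x[name], params_y[name], atol=atol) for name in params_x)
```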
An alternative explanation might be that high regularization led to `model_b (original)` being in a narrower loss-basin, such that permuting it and interpolating between the results leads to poor results. This might explain why the absolute accuracy falls with rising `weight_decay` (at least when going from `0.0` to `0.1`). However, that may also be explained by the fact that the model wasn't tuned for any of the given `weight_decay` values, and that the differences in accuracy are just random.
The explanation that seems most likely to me is that at low `weight_decay` values, only one or two weights are permuted, leading to good interpolation between `model_b (original)` and `model_b (rebasin)`, but barely any difference between the interpolations between `model_a` and `model_b (original)` & `model_b (rebasin)`.
Ultimately, though, I don't know.
BatchNorm statistics have to be re-calculated after permuting the weights. This is a big problem: if every interpolated model must have its BatchNorm statistics re-calculated, then the interpolation takes a lot of compute.
So a natural question to ask is: how well can you predict the loss and accuracy of an interpolated model after re-calculating the BatchNorm statistics from the loss and accuracy of the same model before re-calculating the BatchNorm statistics?
Here are the results of evaluating the loss and accuracy of the interpolated models in four different modes:
- `BN model A`: BatchNorm statistics as in model A
- `BN model B`: BatchNorm statistics as in model B
- `BN reset`: all BatchNorms have `reset_running_stats` called on them
- `BN recalc`: all BatchNorms have their statistics re-calculated using the entire CIFAR10 dataset
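As a minimal sketch of the `BN reset` and `BN recalc` modes (the dataloader and device handling are assumptions, not the repo's actual setup):

```python
import torch
from torch import nn

BATCHNORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

@torch.no_grad()
def reset_bn(model: nn.Module) -> None:
    """'BN reset': clear the running statistics of every BatchNorm layer."""
    for module in model.modules():
        if isinstance(module, BATCHNORM_TYPES):
            module.reset_running_stats()

@torch.no_grad()
def recalc_bn(model: nn.Module, dataloader, device: str = "cpu") -> None:
    """'BN recalc': re-estimate the running statistics with a forward pass over the data.

    Sketch only; `dataloader` is assumed to yield (inputs, targets) batches of CIFAR10.
    """
    reset_bn(model)
    model.train()  # BatchNorm only updates its running stats in train mode
    for inputs, _ in dataloader:
        model(inputs.to(device))
    model.eval()
```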
Here are the results for different `weight_decay` values at training:
It seems like the final loss and accuracy are somewhat predictable from the loss and accuracy of the models before re-calculating the BatchNorm statistics. `BN reset` seems like the best setting for that, but not for very high `weight_decay` values. `BN model A` and `BN model B` seem almost useless.
It is, however, very interesting to see that they almost perfectly align with each other. The outputs of the models with the two different BatchNorm settings are not precisely the same, and using the BatchNorm statistics from `model_b` degrades the performance very sharply very close to `model_a`, so there almost certainly isn't a programming error here (though you are welcome to check the code yourself :).
I also implemented the `MergeMany` algorithm from the paper. Here are the results of running tests with it on `hlb-CIFAR10`.
I ran three kinds of tests:
- Simple `MergeMany`. Doesn't seem to work well.
- Train-Merge-Train: use `MergeMany` as pre-training / initialization, then train the model further. Doesn't seem to work well.
- Train on different datasets: train models on different datasets, then merge them and retrain on one last dataset. This seems to work well!
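For orientation, the final step of `MergeMany` is a plain parameter average of the aligned models. Here is a minimal sketch of that averaging step only; the permutation/alignment itself is handled by the rebasin code and is not shown.

```python
import copy
import torch

def average_aligned_models(models: list) -> torch.nn.Module:
    """Average the floating-point parameters of models that are already aligned.

    Sketch of MergeMany's final averaging step; it assumes the models have
    already been permuted into a common basin.
    """
    state_dicts = [model.state_dict() for model in models]
    averaged = {}
    for key, value in state_dicts[0].items():
        if value.is_floating_point():
            averaged[key] = torch.stack([sd[key] for sd in state_dicts]).mean(dim=0)
        else:
            averaged[key] = value  # e.g. num_batches_tracked: take it from the first model
    merged = copy.deepcopy(models[0])
    merged.load_state_dict(averaged)
    return merged
```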
I used the `MergeMany` algorithm with the following convolutional kernel-sizes:
- 3x3
- 6x6
- 9x9
I used each of these kernel-sizes when merging 3, 6, 9, and 12 models.
Below are the losses and accuracies.
Clearly, `MergeMany` significantly reduces performance here, instead of improving it as claimed in the paper, and the more models are merged, the worse the performance becomes.
In the paper, an MLP was used for merging, so that might be another interesting thing to try.
After the poor results seen above, I was wondering whether training for a few steps after merging several models might yield better results than just training a single model for the equivalent number of epochs (i.e. 10 epochs, if each of the merged models is trained for 5 epochs and the merged model for another 5 epochs). If so, the `MergeMany` algorithm may still be useful for federated learning.
I use the following training procedure (a sketch follows the list):
- `n` models trained for 5 epochs each, merged, then trained for another 5 epochs
- 1 model trained for 10 epochs (the control model)
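As a sketch of this comparison; `make_model`, `train`, and `merge_many` are hypothetical placeholders, not APIs from the repo.

```python
def train_merge_train(make_model, train, merge_many, n: int, epochs: int = 5):
    """Hypothetical sketch of the Train-Merge-Train setup."""
    # Train n independently initialized models for `epochs` epochs each.
    models = [train(make_model(), epochs=epochs) for _ in range(n)]
    # Merge them with MergeMany, then train the merged model for another `epochs` epochs.
    merged = train(merge_many(models), epochs=epochs)
    # Control: a single model trained for 2 * epochs epochs.
    control = train(make_model(), epochs=2 * epochs)
    return merged, control
```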
I use the following abbreviations:
- `L`: loss (of the merged & retrained model)
- `LC`: loss of the control model
- `A`: accuracy (of the merged & retrained model)
- `AC`: accuracy of the control model
Settings (#models --- filter-size) | L | LC | L/LC | A | AC | A/AC |
---|---|---|---|---|---|---|
3 --- 3x3 | 1.064 | 1.030 | 103.3% | 0.914 | 0.932 | 98.1% |
3 --- 6x6 | 1.053 | 1.018 | 103.4% | 0.906 | 0.924 | 98.1% |
3 --- 9x9 | 1.079 | 1.046 | 103.2% | 0.889 | 0.904 | 98.3% |
6 --- 3x3 | 1.077 | 1.034 | 104.2% | 0.907 | 0.929 | 97.6% |
6 --- 6x6 | 1.069 | 1.018 | 105.0% | 0.897 | 0.923 | 97.2% |
6 --- 9x9 | 1.087 | 1.042 | 104.3% | 0.879 | 0.906 | 97.0% |
9 --- 3x3 | 1.103 | 1.032 | 106.9% | 0.895 | 0.931 | 96.1% |
9 --- 6x6 | 1.078 | 1.022 | 105.5% | 0.893 | 0.921 | 97.0% |
9 --- 9x9 | 1.113 | 1.041 | 106.9% | 0.872 | 0.909 | 95.9% |
12 --- 3x3 | 1.111 | 1.032 | 107.7% | 0.890 | 0.929 | 95.8% |
12 --- 6x6 | 1.143 | 1.019 | 112.2% | 0.859 | 0.921 | 93.2% |
12 --- 9x9 | 1.135 | 1.045 | 108.6% | 0.858 | 0.905 | 94.8% |
The control model was always trained in exactly the same way, using the same filter-size as the merged models it is compared to. This means that the variance seen in the control model across different numbers of models at the same filter-size is due to the random initialization of the model. However, clear trends are still visible.
It is clear that using `MergeMany` as a form of pre-training does not improve performance. In fact, increasing the number of models that are merged decreases the performance of the final model (which is not surprising, given that merging is used as a form of pre-training here, and more merged models mean a worse pre-trained model). Increasing the filter-size, on the other hand, doesn't affect the performance difference between the merged model and the control model.
In general, `MergeMany` doesn't seem to work for pre-training / initialization.
However, if a few models were trained for 5 epochs in one location, merged, and then trained for another 10 epochs at another location, that would likely improve performance compared to simply training for 10 epochs. In other words, using distributed training with `MergeMany` will likely increase performance relative to simply training for a lower number of epochs in one location. On the other hand, using more models for this seems to be worse, not better, and simply not merging before continuing the training would be even better. So even for this purpose, `MergeMany` doesn't seem to be useful (for the given architecture, at least).
There is a caveat to the above: all of these models were trained on the same data. Results might be different if each model were trained on different data.
That means that `MergeMany` might still be useful for effectively increasing the dataset-size through federated learning. This is what I tested next.
To do so, I split CIFAR10 into `n+1` parts, where `n` is the number of models that I would merge. I would train `n` models for `e` epochs, each on a different part of the dataset, then merge them. Then, I would train the merged model for another `e` epochs on the last part of the dataset.
For comparison, I would train a single model for `e` epochs on only the last part of the dataset. This is to see whether the merged model performs better than a single model trained for the same number of epochs. This would already be useful if local training of several models on private datasets, followed by merging and retraining on a public dataset, yielded a better model than just training on the public dataset with equivalent compute time.
I then continued training that control model for a total of `n * e` epochs. This is to see whether the merged model performs better than a single model given the same compute resources as the merged model, including the "pre-training" (i.e. the training of the models that were then merged).
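A minimal sketch of that data split, using `torch.utils.data.random_split`; the actual splitting code in the repo may differ.

```python
from torch.utils.data import Dataset, random_split

def split_into_parts(dataset: Dataset, n: int) -> list:
    """Split a dataset into n + 1 disjoint parts: one per model to be merged,
    plus one final part for retraining the merged model. Sketch only."""
    part_len = len(dataset) // (n + 1)
    lengths = [part_len] * n + [len(dataset) - n * part_len]
    return random_split(dataset, lengths)
```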
In the results, `Model <i>` refers to the i-th model that was used for merging, `Merged Model` refers to the model obtained by merging those models, and `Control Model (<j>)` refers to the control model trained for `j * e` epochs.
There are two different parameters of interest that may influence the results:
- The number of epochs that each model is trained for (`e`).
- The number of models that are merged (`n`).
Here are the results of plotting loss & accuracy over the number of epochs:
A few things become clear immediately:
- The results are very noisy
- The performance of all models improves with the number of epochs (as expected)
- The performance of all models falls with the number of models (because they see a smaller portion of the dataset; fewer datapoints, and less diverse data)
To compensate for this overall change in performance, I normalize the data by dividing each datapoint by the corresponding datapoint of the `Merged Model`. This way, the change in performance of the different models relative to that of the merged model becomes apparent.
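The normalization is simply a pointwise division by the `Merged Model` curve; a sketch, with `curves` as a hypothetical mapping from model name to a list of values:

```python
def normalize_to_merged(curves: dict) -> dict:
    """Divide every curve pointwise by the `Merged Model` curve (illustrative sketch)."""
    merged = curves["Merged Model"]
    return {
        name: [value / reference for value, reference in zip(values, merged)]
        for name, values in curves.items()
    }
```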
Except when using only 3 models, the performance of the `Merged Model` is always better than that of the models that were used for merging, as well as that of `Control Model (1)`. For high `n`, it is even somewhat competitive with `Control Model (n)`.
Here are the results of plotting loss & accuracy over the number of models:
Here, it becomes very clear that increasing the number of models decreases performance. Let's normalize the data again:
The performance of all models drops relative to that of the `Merged Model` with increasing `n`.
This makes me think that `MergeMany` might work if the limit on training is the dataset size, and not the compute budget. In that case, gathering data and training privately on-device, then merging the resulting models and retraining them on a well-tuned central dataset might be very useful. This is especially the case if different data sources show different biases, though I currently have not tested this (and may not do so in the near future).
The model is a simple ResNet:
The model with 3x3-filters has the following permutation structure:
```
PathSequence(
----------------------------------------
LinearPath(
DefaultModule(
module.type: Conv
input.shape: [(5000, 3, 32, 32)]
output.shape: [(5000, 12, 31, 31)]
weight.in_dim.permutation: None
weight.out_dim.permutation.shape: 12
)
DefaultModule(
module.type: Conv
input.shape: [(5000, 12, 31, 31)]
output.shape: [(5000, 32, 31, 31)]
weight.in_dim.permutation.shape: 12
weight.out_dim.permutation.shape: 32
)
DefaultModule(
module.type: Conv
input.shape: [(5000, 32, 31, 31)]
output.shape: [(5000, 64, 31, 31)]
weight.in_dim.permutation.shape: 32
weight.out_dim.permutation.shape: 64
)
OneDimModule(
module.type: BatchNorm
input.shape: [(5000, 64, 15, 15)]
output.shape: [(5000, 64, 15, 15)]
weight.in_dim.permutation.shape: 64
weight.out_dim.permutation.shape: 64
)
)
----------------------------------------
|
|
|
--------------------------------------------------------------
ParallelPaths(
LinearPath( LinearPath()
DefaultModule( |
module.type: Conv |
input.shape: [(5000, 64, 15, 15)] |
output.shape: [(5000, 64, 15, 15)] |
weight.in_dim.permutation.shape: 64 |
weight.out_dim.permutation.shape: 64 |
) |
|
OneDimModule( |
module.type: BatchNorm |
input.shape: [(5000, 64, 15, 15)] |
output.shape: [(5000, 64, 15, 15)] |
weight.in_dim.permutation.shape: 64 |
weight.out_dim.permutation.shape: 64 |
) |
|
) |
)
--------------------------------------------------------------
|
|
|
-----------------------------------------
LinearPath(
DefaultModule(
module.type: Conv
input.shape: [(5000, 64, 15, 15)]
output.shape: [(5000, 256, 15, 15)]
weight.in_dim.permutation.shape: 64
weight.out_dim.permutation.shape: 256
)
OneDimModule(
module.type: BatchNorm
input.shape: [(5000, 256, 7, 7)]
output.shape: [(5000, 256, 7, 7)]
weight.in_dim.permutation.shape: 256
weight.out_dim.permutation.shape: 256
)
)
-----------------------------------------
|
|
|
---------------------------------------------------------------
ParallelPaths(
LinearPath( LinearPath()
DefaultModule( |
module.type: Conv |
input.shape: [(5000, 256, 7, 7)] |
output.shape: [(5000, 256, 7, 7)] |
weight.in_dim.permutation.shape: 256 |
weight.out_dim.permutation.shape: 256 |
) |
|
OneDimModule( |
module.type: BatchNorm |
input.shape: [(5000, 256, 7, 7)] |
output.shape: [(5000, 256, 7, 7)] |
weight.in_dim.permutation.shape: 256 |
weight.out_dim.permutation.shape: 256 |
) |
|
) |
)
---------------------------------------------------------------
|
|
|
----------------------------------------
LinearPath(
DefaultModule(
module.type: Conv
input.shape: [(5000, 256, 7, 7)]
output.shape: [(5000, 512, 7, 7)]
weight.in_dim.permutation.shape: 256
weight.out_dim.permutation: None
)
OneDimModule(
module.type: BatchNorm
input.shape: [(5000, 512, 3, 3)]
output.shape: [(5000, 512, 3, 3)]
weight.in_dim.permutation: None
weight.out_dim.permutation: None
)
)
----------------------------------------
|
|
|
-----------------------------------------------------------
ParallelPaths(
LinearPath( LinearPath()
DefaultModule( |
module.type: Conv |
input.shape: [(5000, 512, 3, 3)] |
output.shape: [(5000, 512, 3, 3)] |
weight.in_dim.permutation: None |
weight.out_dim.permutation: None |
) |
|
OneDimModule( |
module.type: BatchNorm |
input.shape: [(5000, 512, 3, 3)] |
output.shape: [(5000, 512, 3, 3)] |
weight.in_dim.permutation: None |
weight.out_dim.permutation: None |
) |
|
) |
)
-----------------------------------------------------------
|
|
|
------------------------------------
LinearPath(
DefaultModule(
module.type: Linear
input.shape: [(5000, 512)]
output.shape: [(5000, 10)]
weight.in_dim.permutation: None
weight.out_dim.permutation: None
)
)
------------------------------------
)
```
The models with larger filter-sizes are exactly the same, except for the convolutions' weight-shapes.