本仓库为“while True: 卷”队完成飞桨论文复现挑战赛(第四期)中论文#39所使用的仓库
#39 Cream of the Crop: Distilling Prioritized Paths For One-Shot Neural Architecture Search ⭐⭐
One-shot weight sharing methods have recently drawn great attention in neural architecture search due to high efficiency and competitive performance. However, weight sharing across models has an inherent deficiency, i.e., insufficient training of subnetworks in hypernetworks. To alleviate this problem, we present a simple yet effective architecture distillation method. The central idea is that subnetworks can learn collaboratively and teach each other throughout the training process, aiming to boost the convergence of individual models. We introduce the concept of prioritized path, which refers to the architecture candidates exhibiting superior performance during training. Distilling knowledge from the prioritized paths is able to boost the training of subnetworks. Since the prioritized paths are changed on the fly depending on their performance and complexity, the final obtained paths are the cream of the crop. We directly select the most promising one from the prioritized paths as the final architecture, without using other complex search methods, such as reinforcement learning or evolution algorithms. The experiments on ImageNet verify such path distillation method can improve the convergence ratio and performance of the hypernetwork, as well as boosting the training of subnetworks. The discovered architectures achieve superior performance compared to the recent MobileNetV3 and EfficientNet families under aligned settings. Moreover, the experiments on object detection and more challenging search space show the generality and robustness of the proposed method. Code and models are available at https://github.com/microsoft/cream.git2 .
https://aistudio.baidu.com/aistudio/clusterprojectdetail/2381048
详见模型分析
-
- 完成了tImM到pImM库的转写(以及其他依赖库转写)
-
- 完成了所有代码的转写及到训练部分前的测试
-
- 完成了Cream-14的retrain与精度测试
我们已经完成了源论文验证集的运行,以下简要介绍运行情况: (更新于21/8/19)
在论文作者发布的源码中,我们发现了一些瑕疵。为了使代码能够在无cuda的cpu环境下更好地运行,我们对源代码进行了如下更改:
- 根据源码README指引修改了Cream/configs/test/test.yaml,使得程序将根据Cream-14模型(权重已载入,详见权重与数据)进行预测。
- 为了方便进行cuda与cpu相关的修改,将Cream/tools/test.py中
import
的timm.data.loader
文件迁出至Cream/tools/文件夹下,并进行了一系列改动 - 由于torch版本差异,对于Cream/lib/core/test.py文件中现第55行的误差计算函数
accuracy()
,我们不在采用timm库中的版本,而是将库中的函数粘贴到此文件中并将return
行的torch.tensor.view()
函数用torch.tensor.reshape()
替代,以规避用前者将二维张量转换为一维时的报错 - 注释了Cream/lib/core/test.py、Cream/tools/loader.py、Cream/tools/test.py中涉及cuda的部分, 以避免cuda缺失造成的报错
- 在Cream/tools/test.py中引入了
torch.device("cpu")
,并注释了GPU相关的部分,以适应CPU运行环境 - 改动了Cream\tools\main.py、Cream/tools/test.py中部分系统命令相关的语句,将ubuntu命令转换为了Windows下的命令行命令
- 为了方便观察运行进度,修改了Cream/lib/core/test.py文件中现第73行的输出log的条件,使得每一个batch完成后均会在命令行窗口输出日志
由于原文作者所指定的部分版本的依赖已经无法获得,我们更新了一套新的依赖,详见Cream\requirements_new
- 我们已经将Cream-14模型的权重下载至Cream\experiments\workspace\ckps\14.pth.tar。原文作者还在百度网盘(提取码:wqw6)保存了其他可用权重。虽然这些权重同样可以使修改后的源码正常运行,但我们并不是非常建议您使用这些模型权重,尤其是较大的模型权重(如最大的Cream-604)处理检验集,因为如此规模的网络并不适合在CPU环境下运行(我们曾尝试运行Cream-604网络。虽然它在检验集的评估下有明显更优的表现,但其运行时间是Cream-14网络的近20倍)。
- 原文使用了ILSVRC-2012-Task_1(即原README所谓的ImageNet-2012)的检验集进行模型的评估。因为该数据集过大,您需要自行下载并将其中的所有图片移动至Cream\data\imagenet\val
- 原文作者提供的标记文件已随此仓库同步至Github。下载检验集后,请将其一并移动到Cream\data\imagenet\val,并在该文件夹下启动GitBash运行如下命令以运行该文件:
sh valprep.sh
- 该文件会自动将检验集图片按标签分类为文件夹。此过程可能长达数小时,请耐心等待。
若要使用检验集进行评估,请在主文件夹下使用命令行窗口运行如下命令(参考原README)
cd Cream
conda create -n Cream python=3.8
conda activate Cream
pip install -r requirements_new #使用新的依赖
python ./tools/main.py test .\experiments\configs\test\test.yaml #斜杠方向的变化是为了适应Cream\tools\main.py中命令拼接的bug
我们于21/8/19使用Cream-14模型重新运行了检验集,并保留了本次运行的日志。结果与原文结果对比如下:
指标 | 原文的 | 我们的 |
---|---|---|
运行时间 | - | 8分34秒 |
一级准确率(Top-1 Acc.) | 53.8% | 53.9% |
五级准确率(Top-5 Acc.) | 77.2% | 77.4% |
运行结果于原文基本吻合
由于原论文所附的代码强烈依赖于tImM(Pytorch Image Models)库,因此我们不得不另行转写了一个与之对应的PImM(Paddle Image Models)库以更加快捷地转写主程序。
下表给出了原论文用到的所有重要timm库函数/方法在我们的pimm库中的对应调用方式。
timm对象 | pimm对象 |
---|---|
timm.data.Dataset | pimm.data.Dataset |
timm.data.create_loader | pimm.data.create_loader |
timm.loss.LabelSmoothingCrossEntropy | pimm.loss.LabelSmoothingCrossEntropy |
timm.models.efficientnet_blocks.ConvBnAct | pimm.models.efficientnet_blocks.ConvBnAct |
timm.models.efficientnet_blocks.DepthwiseSeparableConv | pimm.models.efficientnet_blocks.DepthwiseSeparableConv |
timm.models.efficientnet_blocks.drop_path | pimm.models.efficientnet_blocks.drop_path |
timm.models.efficientnet_blocks.InvertedResidual | pimm.models.efficientnet_blocks.InvertedResidual |
timm.models.efficientnet_blocks.SqueezeExcite | pimm.models.efficientnet_blocks.SqueezeExcite |
timm.models.layers.activations.hard_sigmoid | pimm.models. |
timm.models.layers.activations.Swish | pimm.models. |
timm.models.layers.create_conv2d | pimm.models. |
timm.models.layers.SelectAdaptivePool2d | pimm.models. |
timm.models.resume_checkpoint | pimm.models.resume_checkpoint |
timm.optim.create_optimizer | pimm.optim.create_optimizer |
timm.scheduler.create_scheduler | pimm.scheduler.create_scheduler |
timm.utils.CheckpointSaver | pimm.utils.CheckpointSaver |
timm.utils.ModelEma | pimm.utils.ModelEma |
timm.utils.reduce_tensor | pimm.utils.reduce_tensor |
当然,我们还转写了其他原文代码用到的依赖库,详见下表
原依赖库名 | 现依赖库名 | 用途 |
---|---|---|
ptflops | pdflops | 用于计算模型中的浮点运算次数,在本模型代码中用作搜索模型依据(本库似乎未被使用) |
thop | phop | 同上 |
在转写原文源码的过程中,我们遵循如下原则:
- 严格遵循原文件结构与代码结构
- 对于Pytorch中的类与函数
- 尽可能使用paddle中的同功能类与函数
- 若存在多种方式实现同一功能,在无差异的情形下,尽可能使用同名类与函数
- 若paddle中无直接实现,尽可能使用paddle官网给出的推荐组合实现
- 不对源码功能进行任何优化与改动
由于转写了部分依赖库,复现代码的运行环境有所简化:
future
PIL
paddle == 2.1.2 #numpy等基础依赖库版本以其依赖为准
yacs
目前测试可稳定运行的硬件环境与软件环境匹配包括:
- Win10-CPU-python==3.8.1
- Win10-单核GPU-python==3.9.6
我们随代码提供了用于测试大部分底层依赖的测试代码,详见复现代码测试.ipynb
复现代码训练与重训练方法与原代码基本相同,但请将命令路径内的分隔符换为反斜杠以规避原代码中的os.path.join
bug。具体代码如下(当然,我们也对设置文件进行了同样的修改):
cd Paddle_Cream
python ./tools/main.py test .\experiments\configs\test\train.yaml
python ./tools/main.py test .\experiments\configs\test\14.yaml
我们提供了重训练完成的14M级模型,一下命令可对其精度进行测试:
cd Paddle_Cream
python ./tools/main.py test .\experiments\configs\test\test.yaml
我们得到的测试结果如下:
指标 | 原文的 | 我们的 |
---|---|---|
一级准确率(Top-1 Acc.) | 53.8% | 53.76% |
五级准确率(Top-5 Acc.) | 77.2% | 77.10% |
测试日志见test.log