新功能:使用笔记本pba.ipynb
可视化 PBA 和应用的增强!
现在支持 Python 3。
基于群体的增强(PBA)是一种快速有效地学习用于神经网络训练的数据增强函数的算法。 PBA 与 CIFAR 上最先进的结果相匹配,而计算量却减少了一千倍,使研究人员和从业人员能够使用单个工作站 GPU 有效地学习新的增强策略。
该存储库包含 TensorFlow 和 Python 中的“基于群体的增强:增强计划的高效学习”(http://arxiv.org/abs/1905.05393) 工作的代码。它包括使用报告的增强计划来训练模型以及发现新的增强策略计划。
请参阅下文,了解我们的增强策略的可视化。
代码支持Python 2和3。
pip install -r requirements.txt
bash datasets/cifar10.sh
bash datasets/cifar100.sh
数据集 | 模型 | 测试误差(%) |
---|---|---|
CIFAR-10 | Wide-ResNet-28-10 | 2.58 |
摇一摇 (26 2x32d) | 2.54 | |
摇一摇 (26 2x96d) | 2.03 | |
摇一摇 (26 2x112d) | 2.03 | |
PyramidNet+ShakeDrop | 1.46 | |
简化的 CIFAR-10 | Wide-ResNet-28-10 | 12.82 |
摇一摇 (26 2x96d) | 10.64 | |
CIFAR-100 | Wide-ResNet-28-10 | 16.73 |
摇一摇 (26 2x96d) | 15.31 | |
PyramidNet+ShakeDrop | 10.94 | |
SVHN | Wide-ResNet-28-10 | 1.18 |
摇一摇 (26 2x96d) | 1.13 | |
减少SVHN | Wide-ResNet-28-10 | 7.83 |
摇一摇 (26 2x96d) | 6.46 |
重现结果的脚本位于scripts/table_*.sh
中。所有脚本都需要一个参数,即模型名称。可用选项是本文表 2 中为每个数据集报告的选项,其中包括: wrn_28_10, ss_32, ss_96, ss_112, pyramid_net
。超参数也位于每个脚本文件内。
例如,要在 Wide-ResNet-28-10 上重现 CIFAR-10 结果:
bash scripts/table_1_cifar10.sh wrn_28_10
要在 Shake-Shake (26 2x96d) 上重现简化的 SVHN 结果:
bash scripts/table_4_svhn.sh rsvhn_ss_96
一个好的起点是 Wide-ResNet-28-10 上的简化 SVHN,它可以在 Titan XP GPU 上在 10 分钟内完成,达到 91% 以上的测试精度。
在 1800 个 epoch 上运行较大的模型可能需要多天的训练。例如,CIFAR-10 PyramidNet+ShakeDrop 在 Tesla V100 GPU 上大约需要 9 天。
使用文件scripts/search.sh
在 Wide-ResNet-40-2 上运行 PBA 搜索。需要一个参数,即数据集名称。选项是rsvhn
或rcifar10
。
指定部分 GPU 大小以在同一 GPU 上启动多个试验。简化的 SVHN 在 Titan XP GPU 上大约需要一个小时,简化的 CIFAR-10 大约需要 5 小时。
CUDA_VISIBLE_DEVICES=0 bash scripts/search.sh rsvhn
搜索中使用的结果调度可以从 Ray 结果目录中检索,并且可以使用pba/utils.py
中的parse_log()
函数将日志文件转换为策略调度。例如,在简化的 CIFAR-10 上学习超过 200 个时期的策略计划被分为概率和幅度超参数值(每个增强操作的两个值被合并)并可视化如下:
随时间变化的概率超参数 | 超参数随时间变化的幅度 |
---|---|
如果您在研究中使用 PBA,请引用:
@inproceedings{ho2019pba,
title = {Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules},
author = {Daniel Ho and
Eric Liang and
Ion Stoica and
Pieter Abbeel and
Xi Chen
},
booktitle = {ICML},
year = {2019}
}