New: ノートブックpba.ipynb
を使用して、PBA と適用された拡張機能を視覚化します。
Python 3 がサポートされるようになりました。
Population Based Augmentation (PBA) は、ニューラル ネットワーク トレーニング用のデータ拡張関数を迅速かつ効率的に学習するアルゴリズムです。 PBA は、1,000 分の 1 のコンピューティングで CIFAR の最先端の結果に匹敵し、研究者や実践者が単一のワークステーション GPU を使用して新しい拡張ポリシーを効果的に学習できるようにします。
このリポジトリには、TensorFlow と Python による作品「Population Based Augmentation: Efficient Learning of Augmentation Schedules」(http://arxiv.org/abs/1905.05393) のコードが含まれています。これには、報告された拡張スケジュールを使用したモデルのトレーニングと、新しい拡張ポリシー スケジュールの検出が含まれます。
拡張戦略の視覚化については、以下を参照してください。
コードは Python 2 と 3 をサポートしています。
pip install -r requirements.txt
bash datasets/cifar10.sh
bash datasets/cifar100.sh
データセット | モデル | テストエラー (%) |
---|---|---|
CIFAR-10 | Wide-ResNet-28-10 | 2.58 |
シェイクシェイク (26 2x32d) | 2.54 | |
シェイクシェイク (26 2x96d) | 2.03 | |
シェイクシェイク (26 2x112d) | 2.03 | |
ピラミッドネット+シェイクドロップ | 1.46 | |
CIFAR-10の削減 | Wide-ResNet-28-10 | 12.82 |
シェイクシェイク (26 2x96d) | 10.64 | |
CIFAR-100 | Wide-ResNet-28-10 | 16.73 |
シェイクシェイク (26 2x96d) | 15.31 | |
ピラミッドネット+シェイクドロップ | 10.94 | |
SVHN | Wide-ResNet-28-10 | 1.18 |
シェイクシェイク (26 2x96d) | 1.13 | |
SVHNの減少 | Wide-ResNet-28-10 | 7.83 |
シェイクシェイク (26 2x96d) | 6.46 |
結果を再現するスクリプトはscripts/table_*.sh
にあります。 1 つの引数、モデル名はすべてのスクリプトに必要です。利用可能なオプションは、この論文の表 2 で各データセットについて報告されているもので、選択肢の中には、 wrn_28_10, ss_32, ss_96, ss_112, pyramid_net
があります。ハイパーパラメータも各スクリプト ファイル内にあります。
たとえば、CIFAR-10 の結果を Wide-ResNet-28-10 で再現するには、次のようにします。
bash scripts/table_1_cifar10.sh wrn_28_10
Shake-Shake (26 2x96d) で縮小 SVHN 結果を再現するには:
bash scripts/table_4_svhn.sh rsvhn_ss_96
まずは Wide-ResNet-28-10 の Reduced SVHN から始めるとよいでしょう。これは、91% 以上のテスト精度に達する Titan XP GPU で 10 分以内に完了できます。
1800 エポックで大規模なモデルを実行するには、数日間のトレーニングが必要になる場合があります。たとえば、CIFAR-10 PyramidNet+ShakeDrop には、Tesla V100 GPU で約 9 日かかります。
ファイルscripts/search.sh
を使用して、 Wide-ResNet-40-2 で PBA 検索を実行します。 1 つの引数、データセット名が必要です。選択肢は、 rsvhn
またはrcifar10
です。
部分的な GPU サイズは、同じ GPU で複数のトライアルを起動するために指定されます。縮小 SVHN には Titan XP GPU で約 1 時間かかり、縮小 CIFAR-10 には約 5 時間かかります。
CUDA_VISIBLE_DEVICES=0 bash scripts/search.sh rsvhn
検索で使用された結果のスケジュールは Ray 結果ディレクトリから取得でき、ログ ファイルはpba/utils.py
のparse_log()
関数を使用してポリシー スケジュールに変換できます。たとえば、200 エポックにわたる Reduced CIFAR-10 で学習されたポリシー スケジュールは、確率と大きさのハイパーパラメータ値に分割され (各拡張操作の 2 つの値がマージされます)、以下に視覚化されます。
時間の経過に伴う確率ハイパーパラメータ | 時間の経過に伴う大きさのハイパーパラメータ |
---|---|
研究で PBA を使用する場合は、以下を引用してください。
@inproceedings{ho2019pba,
title = {Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules},
author = {Daniel Ho and
Eric Liang and
Ion Stoica and
Pieter Abbeel and
Xi Chen
},
booktitle = {ICML},
year = {2019}
}