Baru: Visualisasikan PBA dan terapkan augmentasi dengan notebook pba.ipynb
!
Sekarang dengan dukungan Python 3.
Augmentasi Berbasis Populasi (PBA) adalah algoritma yang dengan cepat dan efisien mempelajari fungsi augmentasi data untuk pelatihan jaringan saraf. PBA mencocokkan hasil tercanggih di CIFAR dengan komputasi seribu kali lebih sedikit, memungkinkan peneliti dan praktisi mempelajari kebijakan augmentasi baru secara efektif menggunakan satu GPU stasiun kerja.
Repositori ini berisi kode untuk pekerjaan "Augmentasi Berbasis Populasi: Pembelajaran Jadwal Augmentasi yang Efisien" (http://arxiv.org/abs/1905.05393) di TensorFlow dan Python. Hal ini mencakup pelatihan model dengan jadwal augmentasi yang dilaporkan dan penemuan jadwal kebijakan augmentasi baru.
Lihat di bawah untuk visualisasi strategi augmentasi kami.
Kode mendukung Python 2 dan 3.
pip install -r requirements.txt
bash datasets/cifar10.sh
bash datasets/cifar100.sh
Kumpulan data | Model | Kesalahan Tes (%) |
---|---|---|
CIFAR-10 | ResNet Lebar-28-10 | 2.58 |
Goyang-Goyang (26 2x32d) | 2.54 | |
Goyang-Goyang (26 2x96d) | 2.03 | |
Goyang-Goyang (26 2x112d) | 2.03 | |
PyramidNet+ShakeDrop | 1.46 | |
Mengurangi CIFAR-10 | ResNet Lebar-28-10 | 12.82 |
Goyang-Goyang (26 2x96d) | 10.64 | |
CIFAR-100 | ResNet Lebar-28-10 | 16.73 |
Goyang-Goyang (26 2x96d) | 15.31 | |
PyramidNet+ShakeDrop | 10.94 | |
SVHN | ResNet Lebar-28-10 | 1.18 |
Goyang-Goyang (26 2x96d) | 1.13 | |
Mengurangi SVHN | ResNet Lebar-28-10 | 7.83 |
Goyang-Goyang (26 2x96d) | 6.46 |
Skrip untuk mereproduksi hasil terletak di scripts/table_*.sh
. Satu argumen, nama model, diperlukan untuk semua skrip. Opsi yang tersedia adalah yang dilaporkan untuk setiap kumpulan data pada Tabel 2 makalah, di antara pilihan: wrn_28_10, ss_32, ss_96, ss_112, pyramid_net
. Hyperparameter juga terletak di dalam setiap file skrip.
Misalnya, untuk mereproduksi hasil CIFAR-10 di Wide-ResNet-28-10:
bash scripts/table_1_cifar10.sh wrn_28_10
Untuk mereproduksi hasil Reduced SVHN pada Shake-Shake (26 2x96d):
bash scripts/table_4_svhn.sh rsvhn_ss_96
Tempat yang baik untuk memulai adalah Mengurangi SVHN di Wide-ResNet-28-10 yang dapat diselesaikan dalam waktu kurang dari 10 menit pada GPU Titan XP yang mencapai akurasi pengujian 91%+.
Menjalankan model yang lebih besar pada periode 1800 mungkin memerlukan pelatihan beberapa hari. Misalnya, CIFAR-10 PyramidNet+ShakeDrop membutuhkan waktu sekitar 9 hari pada GPU Tesla V100.
Jalankan pencarian PBA di Wide-ResNet-40-2 dengan file scripts/search.sh
. Satu argumen, nama kumpulan data, wajib diisi. Pilihannya adalah rsvhn
atau rcifar10
.
Ukuran GPU parsial ditentukan untuk meluncurkan beberapa uji coba pada GPU yang sama. Pengurangan SVHN membutuhkan waktu sekitar satu jam pada GPU Titan XP, dan Pengurangan CIFAR-10 membutuhkan waktu sekitar 5 jam.
CUDA_VISIBLE_DEVICES=0 bash scripts/search.sh rsvhn
Jadwal hasil yang digunakan dalam pencarian dapat diambil dari direktori hasil Ray, dan file log dapat diubah menjadi jadwal kebijakan dengan fungsi parse_log()
di pba/utils.py
. Misalnya, jadwal kebijakan yang dipelajari pada Reduced CIFAR-10 selama 200 periode dibagi menjadi nilai hyperparameter probabilitas dan magnitudo (dua nilai untuk setiap operasi augmentasi digabungkan) dan divisualisasikan di bawah ini:
Probabilitas Hyperparameter dari Waktu ke Waktu | Besaran Hyperparameter dari Waktu ke Waktu |
---|---|
Jika Anda menggunakan PBA dalam penelitian Anda, harap kutip:
@inproceedings{ho2019pba,
title = {Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules},
author = {Daniel Ho and
Eric Liang and
Ion Stoica and
Pieter Abbeel and
Xi Chen
},
booktitle = {ICML},
year = {2019}
}