Nouveau : Visualisez PBA et les augmentations appliquées avec le notebook pba.ipynb
!
Maintenant avec le support de Python 3.
L'augmentation basée sur la population (PBA) est un algorithme qui apprend rapidement et efficacement les fonctions d'augmentation des données pour la formation des réseaux neuronaux. PBA associe les résultats de pointe du CIFAR avec mille fois moins de calcul, permettant aux chercheurs et aux praticiens d'apprendre efficacement de nouvelles politiques d'augmentation à l'aide d'un seul GPU de poste de travail.
Ce référentiel contient le code du travail « Population Based Augmentation : Efficient Learning of Augmentation Schedules » (http://arxiv.org/abs/1905.05393) dans TensorFlow et Python. Cela comprend la formation de modèles avec les calendriers d'augmentation signalés et la découverte de nouveaux calendriers de politiques d'augmentation.
Voir ci-dessous pour une visualisation de notre stratégie d'augmentation.
Le code prend en charge Python 2 et 3.
pip install -r requirements.txt
bash datasets/cifar10.sh
bash datasets/cifar100.sh
Ensemble de données | Modèle | Erreur de test (%) |
---|---|---|
CIFAR-10 | Large-ResNet-28-10 | 2,58 |
Secouer-Secouer (26 2x32d) | 2,54 | |
Secouer-Secouer (26 2x96d) | 2.03 | |
Secouer-Secouer (26 2x112d) | 2.03 | |
PyramidNet+ShakeDrop | 1,46 | |
CIFAR-10 réduit | Large-ResNet-28-10 | 12.82 |
Secouer-Secouer (26 2x96d) | 10.64 | |
CIFAR-100 | Large-ResNet-28-10 | 16.73 |
Secouer-Secouer (26 2x96d) | 15h31 | |
PyramidNet+ShakeDrop | 10.94 | |
SVHN | Large-ResNet-28-10 | 1.18 |
Secouer-Secouer (26 2x96d) | 1.13 | |
SVHN réduite | Large-ResNet-28-10 | 7,83 |
Secouer-Secouer (26 2x96d) | 6.46 |
Les scripts pour reproduire les résultats se trouvent dans scripts/table_*.sh
. Un argument, le nom du modèle, est requis pour tous les scripts. Les options disponibles sont celles rapportées pour chaque ensemble de données dans le tableau 2 de l'article, parmi les choix : wrn_28_10, ss_32, ss_96, ss_112, pyramid_net
. Les hyperparamètres sont également situés dans chaque fichier de script.
Par exemple, pour reproduire les résultats CIFAR-10 sur Wide-ResNet-28-10 :
bash scripts/table_1_cifar10.sh wrn_28_10
Pour reproduire les résultats SVHN réduits sur Shake-Shake (26 2x96d) :
bash scripts/table_4_svhn.sh rsvhn_ss_96
Un bon point de départ est le SVHN réduit sur Wide-ResNet-28-10 qui peut se terminer en moins de 10 minutes sur un GPU Titan XP atteignant une précision de test de plus de 91 %.
L’exécution des modèles les plus grands sur 1 800 époques peut nécessiter plusieurs jours de formation. Par exemple, CIFAR-10 PyramidNet+ShakeDrop prend environ 9 jours sur un GPU Tesla V100.
Exécutez la recherche PBA sur Wide-ResNet-40-2 avec le fichier scripts/search.sh
. Un argument, le nom de l'ensemble de données, est requis. Les choix sont rsvhn
ou rcifar10
.
Une taille de GPU partielle est spécifiée pour lancer plusieurs essais sur le même GPU. Le SVHN réduit prend environ une heure sur un GPU Titan XP, et le CIFAR-10 réduit prend environ 5 heures.
CUDA_VISIBLE_DEVICES=0 bash scripts/search.sh rsvhn
Les planifications résultantes utilisées dans la recherche peuvent être récupérées du répertoire de résultats Ray et les fichiers journaux peuvent être convertis en planifications de politique avec la fonction parse_log()
dans pba/utils.py
. Par exemple, le calendrier politique appris sur le CIFAR-10 réduit sur 200 époques est divisé en valeurs d'hyperparamètres de probabilité et d'ampleur (les deux valeurs pour chaque opération d'augmentation sont fusionnées) et visualisé ci-dessous :
Hyperparamètres de probabilité dans le temps | Hyperparamètres d'ampleur au fil du temps |
---|---|
Si vous utilisez PBA dans votre recherche, veuillez citer :
@inproceedings{ho2019pba,
title = {Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules},
author = {Daniel Ho and
Eric Liang and
Ion Stoica and
Pieter Abbeel and
Xi Chen
},
booktitle = {ICML},
year = {2019}
}