Novo: Visualize o PBA e os aumentos aplicados com o notebook pba.ipynb
!
Agora com suporte para Python 3.
Population Based Augmentation (PBA) é um algoritmo que aprende de forma rápida e eficiente funções de aumento de dados para treinamento de redes neurais. O PBA combina resultados de última geração no CIFAR com mil vezes menos computação, permitindo que pesquisadores e profissionais aprendam efetivamente novas políticas de aumento usando uma única GPU de estação de trabalho.
Este repositório contém código para o trabalho "Aumento Baseado em População: Aprendizagem Eficiente de Cronogramas de Aumento" (http://arxiv.org/abs/1905.05393) em TensorFlow e Python. Inclui o treinamento de modelos com os cronogramas de aumento relatados e a descoberta de novos cronogramas de políticas de aumento.
Veja abaixo uma visualização de nossa estratégia de aumento.
O código suporta Python 2 e 3.
pip install -r requirements.txt
bash datasets/cifar10.sh
bash datasets/cifar100.sh
Conjunto de dados | Modelo | Erro de teste (%) |
---|---|---|
CIFAR-10 | Wide-ResNet-28-10 | 2,58 |
Agite-Shake (26 2x32d) | 2,54 | |
Agite-Shake (26 2x96d) | 2.03 | |
Agite-Shake (26 2x112d) | 2.03 | |
PyramidNet+ShakeDrop | 1,46 | |
CIFAR-10 reduzido | Wide-ResNet-28-10 | 12,82 |
Agite-Shake (26 2x96d) | 10,64 | |
CIFAR-100 | Wide-ResNet-28-10 | 16,73 |
Agite-Shake (26 2x96d) | 15h31 | |
PyramidNet+ShakeDrop | 10,94 | |
SVHN | Wide-ResNet-28-10 | 1.18 |
Agite-Shake (26 2x96d) | 1.13 | |
SVHN reduzido | Wide-ResNet-28-10 | 7,83 |
Agite-Shake (26 2x96d) | 6,46 |
Os scripts para reproduzir resultados estão localizados em scripts/table_*.sh
. Um argumento, o nome do modelo, é necessário para todos os scripts. As opções disponíveis são aquelas relatadas para cada conjunto de dados na Tabela 2 do artigo, entre as opções: wrn_28_10, ss_32, ss_96, ss_112, pyramid_net
. Os hiperparâmetros também estão localizados dentro de cada arquivo de script.
Por exemplo, para reproduzir resultados CIFAR-10 em Wide-ResNet-28-10:
bash scripts/table_1_cifar10.sh wrn_28_10
Para reproduzir resultados de SVHN reduzido em Shake-Shake (26 2x96d):
bash scripts/table_4_svhn.sh rsvhn_ss_96
Um bom lugar para começar é o SVHN reduzido em Wide-ResNet-28-10, que pode ser concluído em menos de 10 minutos em uma GPU Titan XP, atingindo 91%+ de precisão de teste.
A execução de modelos maiores em épocas de 1.800 pode exigir vários dias de treinamento. Por exemplo, CIFAR-10 PyramidNet+ShakeDrop leva cerca de 9 dias em uma GPU Tesla V100.
Execute a pesquisa PBA em Wide-ResNet-40-2 com o arquivo scripts/search.sh
. Um argumento, o nome do conjunto de dados, é obrigatório. As opções são rsvhn
ou rcifar10
.
Um tamanho parcial da GPU é especificado para iniciar vários testes na mesma GPU. O SVHN reduzido leva cerca de uma hora em uma GPU Titan XP, e o CIFAR-10 reduzido leva cerca de 5 horas.
CUDA_VISIBLE_DEVICES=0 bash scripts/search.sh rsvhn
Os agendamentos resultantes usados na pesquisa podem ser recuperados do diretório de resultados do Ray e os arquivos de log podem ser convertidos em agendamentos de política com a função parse_log()
em pba/utils.py
. Por exemplo, o cronograma de política aprendido no CIFAR-10 reduzido ao longo de 200 épocas é dividido em valores de hiperparâmetros de probabilidade e magnitude (os dois valores para cada operação de aumento são mesclados) e visualizado abaixo:
Hiperparâmetros de probabilidade ao longo do tempo | Hiperparâmetros de magnitude ao longo do tempo |
---|---|
Se você usa PBA em sua pesquisa, cite:
@inproceedings{ho2019pba,
title = {Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules},
author = {Daniel Ho and
Eric Liang and
Ion Stoica and
Pieter Abbeel and
Xi Chen
},
booktitle = {ICML},
year = {2019}
}