Новое: визуализируйте PBA и применяемые дополнения с помощью записной книжки pba.ipynb
!
Теперь с поддержкой Python 3.
Популяционное увеличение (PBA) — это алгоритм, который быстро и эффективно изучает функции увеличения данных для обучения нейронных сетей. PBA соответствует самым современным результатам CIFAR, используя в тысячу раз меньше вычислительных ресурсов, что позволяет исследователям и практикам эффективно изучать новые политики расширения, используя один графический процессор рабочей станции.
Этот репозиторий содержит код для работы «Популяционное увеличение: эффективное изучение графиков расширения» (http://arxiv.org/abs/1905.05393) в TensorFlow и Python. Он включает в себя обучение моделей с использованием сообщаемых графиков расширения и открытие новых графиков политики расширения.
Ниже представлена визуализация нашей стратегии расширения.
Код поддерживает Python 2 и 3.
pip install -r requirements.txt
bash datasets/cifar10.sh
bash datasets/cifar100.sh
Набор данных | Модель | Ошибка теста (%) |
---|---|---|
СИФАР-10 | Широкий-ResNet-28-10 | 2,58 |
Встряхните-встряхните (26 2x32d) | 2,54 | |
Встряхните-встряхните (26 2x96d) | 2.03 | |
Встряхните-встряхните (26 2x112d) | 2.03 | |
PyramidNet+ShakeDrop | 1,46 | |
Пониженный CIFAR-10 | Широкий-ResNet-28-10 | 12.82 |
Встряхните-встряхните (26 2x96d) | 10.64 | |
СИФАР-100 | Широкий-ResNet-28-10 | 16,73 |
Встряхните-встряхните (26 2x96d) | 15.31 | |
PyramidNet+ShakeDrop | 10.94 | |
СВХН | Широкий-ResNet-28-10 | 1.18 |
Встряхните-встряхните (26 2x96d) | 1.13 | |
Уменьшенный СВХН | Широкий-ResNet-28-10 | 7,83 |
Встряхните-встряхните (26 2x96d) | 6.46 |
Скрипты для воспроизведения результатов находятся в scripts/table_*.sh
. Для всех сценариев требуется один аргумент — имя модели. Доступные варианты — это те, которые указаны для каждого набора данных в таблице 2 документа, среди вариантов: wrn_28_10, ss_32, ss_96, ss_112, pyramid_net
. Гиперпараметры также находятся внутри каждого файла сценария.
Например, чтобы воспроизвести результаты CIFAR-10 на Wide-ResNet-28-10:
bash scripts/table_1_cifar10.sh wrn_28_10
Чтобы воспроизвести результаты уменьшенного SVHN на Shake-Shake (26 2x96d):
bash scripts/table_4_svhn.sh rsvhn_ss_96
Хорошим местом для начала является Уменьшенный SVHN на Wide-ResNet-28-10, который можно выполнить менее чем за 10 минут на графическом процессоре Titan XP, достигая точности теста 91%+.
Для запуска более крупных моделей в 1800 эпохах может потребоваться несколько дней обучения. Например, CIFAR-10 PyramidNet+ShakeDrop занимает около 9 дней на графическом процессоре Tesla V100.
Запустите поиск PBA на Wide-ResNet-40-2 с помощью файла scripts/search.sh
. Требуется один аргумент — имя набора данных. Варианты: rsvhn
или rcifar10
.
Частичный размер графического процессора указан для запуска нескольких пробных версий на одном графическом процессоре. Уменьшенный SVHN занимает около часа на графическом процессоре Titan XP, а уменьшенный CIFAR-10 — около 5 часов.
CUDA_VISIBLE_DEVICES=0 bash scripts/search.sh rsvhn
Результирующие расписания, используемые при поиске, можно получить из каталога результатов Ray, а файлы журналов можно преобразовать в расписания политик с помощью функции parse_log()
в pba/utils.py
. Например, график политики, полученный по сокращенному CIFAR-10 за 200 эпох, разделен на значения гиперпараметров вероятности и величины (два значения для каждой операции увеличения объединяются) и визуализируются ниже:
Вероятностные гиперпараметры во времени | Гиперпараметры величины во времени |
---|---|
Если вы используете PBA в своих исследованиях, укажите:
@inproceedings{ho2019pba,
title = {Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules},
author = {Daniel Ho and
Eric Liang and
Ion Stoica and
Pieter Abbeel and
Xi Chen
},
booktitle = {ICML},
year = {2019}
}