Коллекция быстрых и автономных сценариев обучения для CIFAR-10.
Скрипт | Средняя точность | Время | Пфлопс |
---|---|---|---|
airbench94_compiled.py | 94,01% | 3,09 с | 0,36 |
airbench94.py | 94,01% | 3,83 с | 0,36 |
airbench95.py | 95,01% | 10,4 с | 1,4 |
airbench96.py | 96,03% | 34,7 с | 4,9 |
airbench94_muon.py | 94,01% | 2,59 с | 0,29 |
airbench96_faster.py | 96,00% | 27,3 с | 3.1 |
Для сравнения: стандартное обучение, используемое в большинстве исследований по CIFAR-10, работает намного медленнее:
Базовый уровень | Средняя точность | Время | Пфлопс |
---|---|---|---|
Стандартное обучение ResNet-18 | 96% | 7мин | 32,3 |
Все тайминги указаны на одном графическом процессоре NVIDIA A100.
Примечание: airbench96
был улучшен по сравнению с бумагой с 46 до 35. Кроме того, airbench96_faster
— это улучшенный (но более сложный) метод, использующий фильтрацию данных с помощью небольшой прокси-модели. А airbench94_muon
— это улучшенный метод, использующий вариант оптимизатора Muon.
В статье описан набор методов, используемых для получения этих скоростей обучения.
Чтобы обучить нейронную сеть с точностью 94%, запустите либо
git clone https://github.com/KellerJordan/cifar10-airbench.git
cd airbench && python airbench94.py
или
pip install airbench
python -c "import airbench; airbench.warmup94(); airbench.train94()"
Примечание. airbench94_compiled.py
и airbench94.py
эквивалентны (т. е. дают одинаковое распределение обученных сетей) и отличаются только тем, что первый использует torch.compile
для улучшения использования графического процессора. Первый предназначен для экспериментов, в которых одновременно обучается множество сетей, чтобы амортизировать единовременные затраты на компиляцию.
CIFAR-10 — один из наиболее широко используемых наборов данных в машинном обучении, позволяющий реализовывать тысячи исследовательских проектов в год. Этот репозиторий предоставляет быстрые и стабильные основы обучения CIFAR-10, чтобы помочь ускорить эти исследования. Обучение предоставляется в виде легко выполняемых сценариев PyTorch без зависимостей и может заменить классические базовые программы, такие как обучение ResNet-20 или ResNet-18.
Для написания пользовательских экспериментов или тренингов CIFAR-10 вам может оказаться полезным самостоятельно использовать загрузчик данных с ускорением на графическом процессоре.
import airbench
train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500)
test_loader = airbench.CifarLoader('/tmp/cifar10', train=False, batch_size=1000)
for epoch in range(200):
for inputs, labels in train_loader:
# outputs = model(inputs)
# loss = F.cross_entropy(outputs, labels)
...
Если вы хотите изменить данные в загрузчике, это можно сделать так:
import airbench
train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500)
mask = (train_loader.labels < 6) # (this is just an example, the mask can be anything)
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
print(len(train_loader)) # The loader now contains 30,000 images and has batch size 500, so this prints 60.
Airbench можно использовать как платформу для экспериментов по выбору данных и активному обучению. Ниже приведен пример эксперимента, демонстрирующий классический результат: примеры с низкой достоверностью дают больше обучающего сигнала, чем случайные примеры. На A100 он выполняется менее чем за 20 секунд.
import torch
from airbench import train94, infer, evaluate, CifarLoader
net = train94(label_smoothing=0) # train this network without label smoothing to get a better confidence signal
loader = CifarLoader('cifar10', train=True, batch_size=1000)
logits = infer(net, loader)
conf = logits.log_softmax(1).amax(1) # confidence
train_loader = CifarLoader('cifar10', train=True, batch_size=1024, aug=dict(flip=True, translate=2))
mask = (torch.rand(len(train_loader.labels)) < 0.6)
print('Training on %d images selected randomly' % mask.sum())
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
train94(train_loader, epochs=16) # yields around 93% accuracy
train_loader = CifarLoader('cifar10', train=True, batch_size=1024, aug=dict(flip=True, translate=2))
mask = (conf < conf.float().quantile(0.6))
print('Training on %d images selected based on minimum confidence' % mask.sum())
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
train94(train_loader, epochs=16) # yields around 94% accuracy => low-confidence sampling is better than random.
Этот проект основан на превосходном предыдущем рекорде https://github.com/tysam-code/hlb-CIFAR10 (6,3 А100-секунд до 94%).
Который сам по себе основан на потрясающей серии https://myrtle.ai/learn/how-to-train-your-resnet/ (26 100 секунд до 94 %, что >=8 100 секунд)