Une collection de scripts de formation rapides et autonomes pour CIFAR-10.
Scénario | Précision moyenne | Temps | PFLOP |
---|---|---|---|
airbench94_compiled.py | 94,01% | 3.09s | 0,36 |
airbench94.py | 94,01% | 3,83 s | 0,36 |
airbench95.py | 95,01% | 10,4 s | 1.4 |
airbench96.py | 96,03% | 34,7 s | 4.9 |
airbench94_muon.py | 94,01% | 2,59s | 0,29 |
airbench96_faster.py | 96,00% | 27,3 s | 3.1 |
À titre de comparaison, la formation standard utilisée dans la plupart des études sur CIFAR-10 est beaucoup plus lente :
Référence | Précision moyenne | Temps | PFLOP |
---|---|---|---|
Formation standard ResNet-18 | 96% | 7min | 32.3 |
Tous les timings sont sur un seul GPU NVIDIA A100.
Remarque : airbench96
a été amélioré depuis l'article de 46s à 35s. De plus, airbench96_faster
est une méthode améliorée (mais plus compliquée) qui utilise le filtrage des données par un petit modèle proxy. Et airbench94_muon
est une méthode améliorée utilisant une variante de l'optimiseur Muon.
L'ensemble des méthodes utilisées pour obtenir ces vitesses d'entraînement sont décrites dans l'article.
Pour entraîner un réseau de neurones avec une précision de 94 %, exécutez soit
git clone https://github.com/KellerJordan/cifar10-airbench.git
cd airbench && python airbench94.py
ou
pip install airbench
python -c "import airbench; airbench.warmup94(); airbench.train94()"
Remarque : airbench94_compiled.py
et airbench94.py
sont équivalents (c'est-à-dire qu'ils produisent la même répartition des réseaux formés) et diffèrent uniquement par le fait que le premier utilise torch.compile
pour améliorer l'utilisation du GPU. Le premier est destiné aux expériences dans lesquelles de nombreux réseaux sont formés en même temps afin d'amortir le coût unique de compilation.
CIFAR-10 est l’un des ensembles de données les plus utilisés en apprentissage automatique, facilitant des milliers de projets de recherche chaque année. Ce référentiel fournit des bases de formation rapides et stables pour CIFAR-10 afin d'aider à accélérer cette recherche. Les formations sont fournies sous forme de scripts PyTorch sans dépendance et faciles à exécuter et peuvent remplacer les lignes de base classiques comme la formation ResNet-20 ou ResNet-18.
Pour écrire des expériences ou des formations CIFAR-10 personnalisées, vous trouverez peut-être utile d'utiliser indépendamment le chargeur de données accéléré par GPU.
import airbench
train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500)
test_loader = airbench.CifarLoader('/tmp/cifar10', train=False, batch_size=1000)
for epoch in range(200):
for inputs, labels in train_loader:
# outputs = model(inputs)
# loss = F.cross_entropy(outputs, labels)
...
Si vous souhaitez modifier les données dans le chargeur, cela peut être fait comme ceci :
import airbench
train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500)
mask = (train_loader.labels < 6) # (this is just an example, the mask can be anything)
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
print(len(train_loader)) # The loader now contains 30,000 images and has batch size 500, so this prints 60.
Airbench peut être utilisé comme plate-forme pour des expériences de sélection de données et d'apprentissage actif. Ce qui suit est un exemple d'expérience qui démontre le résultat classique selon lequel les exemples à faible confiance fournissent plus de signaux d'entraînement que les exemples aléatoires. Il fonctionne en moins de 20 secondes sur un A100.
import torch
from airbench import train94, infer, evaluate, CifarLoader
net = train94(label_smoothing=0) # train this network without label smoothing to get a better confidence signal
loader = CifarLoader('cifar10', train=True, batch_size=1000)
logits = infer(net, loader)
conf = logits.log_softmax(1).amax(1) # confidence
train_loader = CifarLoader('cifar10', train=True, batch_size=1024, aug=dict(flip=True, translate=2))
mask = (torch.rand(len(train_loader.labels)) < 0.6)
print('Training on %d images selected randomly' % mask.sum())
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
train94(train_loader, epochs=16) # yields around 93% accuracy
train_loader = CifarLoader('cifar10', train=True, batch_size=1024, aug=dict(flip=True, translate=2))
mask = (conf < conf.float().quantile(0.6))
print('Training on %d images selected based on minimum confidence' % mask.sum())
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
train94(train_loader, epochs=16) # yields around 94% accuracy => low-confidence sampling is better than random.
Ce projet s'appuie sur l'excellent record précédent https://github.com/tysam-code/hlb-CIFAR10 (6,3 A100 secondes à 94 %).
Qui s'appuie sur l'étonnante série https://myrtle.ai/learn/how-to-train-your-resnet/ (26 V100 secondes à 94 %, soit >=8 A100 secondes)