Kumpulan skrip pelatihan cepat dan mandiri untuk CIFAR-10.
Naskah | Berarti akurasi | Waktu | PFLOP |
---|---|---|---|
airbench94_compiled.py | 94,01% | 3,09 detik | 0,36 |
airbench94.py | 94,01% | 3,83 detik | 0,36 |
airbench95.py | 95,01% | 10.4 detik | 1.4 |
airbench96.py | 96,03% | 34,7 detik | 4.9 |
airbench94_muon.py | 94,01% | 2,59 detik | 0,29 |
airbench96_faster.py | 96,00% | 27.3 detik | 3.1 |
Sebagai perbandingan, pelatihan standar yang digunakan di sebagian besar studi tentang CIFAR-10 jauh lebih lambat:
Dasar | Berarti akurasi | Waktu | PFLOP |
---|---|---|---|
Pelatihan standar ResNet-18 | 96% | 7 menit | 32.3 |
Semua pengaturan waktu menggunakan satu GPU NVIDIA A100.
Catatan: airbench96
telah ditingkatkan sejak kertas dari 46s menjadi 35s. Selain itu, airbench96_faster
adalah metode yang ditingkatkan (tetapi lebih rumit) yang menggunakan pemfilteran data dengan model proksi kecil. Dan airbench94_muon
adalah metode yang ditingkatkan menggunakan varian pengoptimal Muon.
Serangkaian metode yang digunakan untuk memperoleh kecepatan pelatihan ini dijelaskan dalam makalah.
Untuk melatih jaringan saraf dengan akurasi 94%, jalankan salah satunya
git clone https://github.com/KellerJordan/cifar10-airbench.git
cd airbench && python airbench94.py
atau
pip install airbench
python -c "import airbench; airbench.warmup94(); airbench.train94()"
Catatan: airbench94_compiled.py
dan airbench94.py
setara (yaitu, menghasilkan distribusi jaringan terlatih yang sama), dan hanya berbeda karena yang pertama menggunakan torch.compile
untuk meningkatkan pemanfaatan GPU. Yang pertama dimaksudkan untuk eksperimen di mana banyak jaringan dilatih sekaligus untuk mengamortisasi biaya kompilasi satu kali.
CIFAR-10 adalah salah satu kumpulan data yang paling banyak digunakan dalam pembelajaran mesin, memfasilitasi ribuan proyek penelitian per tahun. Repo ini memberikan data dasar pelatihan yang cepat dan stabil untuk CIFAR-10 guna membantu mempercepat penelitian ini. Pelatihan ini disediakan sebagai skrip PyTorch bebas ketergantungan yang mudah dijalankan, dan dapat menggantikan garis dasar klasik seperti pelatihan ResNet-20 atau ResNet-18.
Untuk menulis eksperimen atau pelatihan CIFAR-10 khusus, Anda mungkin merasa berguna jika menggunakan pemuat data yang dipercepat GPU secara mandiri.
import airbench
train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500)
test_loader = airbench.CifarLoader('/tmp/cifar10', train=False, batch_size=1000)
for epoch in range(200):
for inputs, labels in train_loader:
# outputs = model(inputs)
# loss = F.cross_entropy(outputs, labels)
...
Jika Anda ingin mengubah data di loader, dapat dilakukan seperti ini:
import airbench
train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500)
mask = (train_loader.labels < 6) # (this is just an example, the mask can be anything)
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
print(len(train_loader)) # The loader now contains 30,000 images and has batch size 500, so this prints 60.
Airbench dapat digunakan sebagai platform eksperimen dalam pemilihan data dan pembelajaran aktif. Berikut ini adalah contoh eksperimen yang menunjukkan hasil klasik bahwa contoh dengan tingkat keyakinan rendah memberikan lebih banyak sinyal pelatihan dibandingkan contoh acak. Ini berjalan dalam <20 detik pada A100.
import torch
from airbench import train94, infer, evaluate, CifarLoader
net = train94(label_smoothing=0) # train this network without label smoothing to get a better confidence signal
loader = CifarLoader('cifar10', train=True, batch_size=1000)
logits = infer(net, loader)
conf = logits.log_softmax(1).amax(1) # confidence
train_loader = CifarLoader('cifar10', train=True, batch_size=1024, aug=dict(flip=True, translate=2))
mask = (torch.rand(len(train_loader.labels)) < 0.6)
print('Training on %d images selected randomly' % mask.sum())
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
train94(train_loader, epochs=16) # yields around 93% accuracy
train_loader = CifarLoader('cifar10', train=True, batch_size=1024, aug=dict(flip=True, translate=2))
mask = (conf < conf.float().quantile(0.6))
print('Training on %d images selected based on minimum confidence' % mask.sum())
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
train94(train_loader, epochs=16) # yields around 94% accuracy => low-confidence sampling is better than random.
Proyek ini dibangun berdasarkan rekor luar biasa sebelumnya https://github.com/tysam-code/hlb-CIFAR10 (6,3 A100-detik hingga 94%).
Yang dibangun berdasarkan seri luar biasa https://myrtle.ai/learn/how-to-train-your-resnet/ (26 V100-detik hingga 94%, yaitu >=8 A100-detik)