CIFAR-10 的快速且獨立的訓練腳本集合。
腳本 | 平均準確度 | 時間 | 浮點運算數 |
---|---|---|---|
airbench94_compiled.py | 94.01% | 3.09秒 | 0.36 |
airbench94.py | 94.01% | 3.83秒 | 0.36 |
airbench95.py | 95.01% | 10.4秒 | 1.4 |
airbench96.py | 96.03% | 34.7秒 | 4.9 |
airbench94_muon.py | 94.01% | 2.59秒 | 0.29 |
airbench96_faster.py | 96.00% | 27.3秒 | 3.1 |
作為比較,大多數 CIFAR-10 研究中使用的標準訓練要慢得多:
基線 | 平均準確度 | 時間 | 浮點運算數 |
---|---|---|---|
標準 ResNet-18 訓練 | 96% | 7分鐘 | 32.3 |
所有計時均在單一 NVIDIA A100 GPU 上進行。
註: airbench96
自論文發表以來已從 46s 改進為 35s。此外, airbench96_faster
是一種改進的(但更複雜)方法,它使用小型代理模型進行資料過濾。 airbench94_muon
是一種使用 Muon 優化器變體的改進方法。
論文中描述了用於獲得這些訓練速度的一組方法。
要以 94% 的準確率訓練神經網絡,請運行
git clone https://github.com/KellerJordan/cifar10-airbench.git
cd airbench && python airbench94.py
或者
pip install airbench
python -c "import airbench; airbench.warmup94(); airbench.train94()"
注意: airbench94_compiled.py
和airbench94.py
是等效的(即,產生相同的訓練網路分佈),不同之處僅在於前者使用torch.compile
來提高 GPU 利用率。前者用於同時訓練許多網路的實驗,以分攤一次性編譯成本。
CIFAR-10 是機器學習領域使用最廣泛的資料集之一,每年促進數千個研究計畫。此儲存庫為 CIFAR-10 提供快速穩定的訓練基線,以幫助加速這項研究。訓練以可輕鬆運行、無依賴的 PyTorch 腳本的形式提供,並且可以替代訓練 ResNet-20 或 ResNet-18 等經典基線。
為了編寫自訂 CIFAR-10 實驗或訓練,您可能會發現獨立使用 GPU 加速資料載入器很有用。
import airbench
train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500)
test_loader = airbench.CifarLoader('/tmp/cifar10', train=False, batch_size=1000)
for epoch in range(200):
for inputs, labels in train_loader:
# outputs = model(inputs)
# loss = F.cross_entropy(outputs, labels)
...
如果你想修改載入器中的數據,可以這樣做:
import airbench
train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500)
mask = (train_loader.labels < 6) # (this is just an example, the mask can be anything)
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
print(len(train_loader)) # The loader now contains 30,000 images and has batch size 500, so this prints 60.
Airbench 可用作數據選擇和主動學習實驗的平台。下面是一個範例實驗,它示範了低置信度範例比隨機範例提供更多訓練訊號的經典結果。在 A100 上運行時間不到 20 秒。
import torch
from airbench import train94, infer, evaluate, CifarLoader
net = train94(label_smoothing=0) # train this network without label smoothing to get a better confidence signal
loader = CifarLoader('cifar10', train=True, batch_size=1000)
logits = infer(net, loader)
conf = logits.log_softmax(1).amax(1) # confidence
train_loader = CifarLoader('cifar10', train=True, batch_size=1024, aug=dict(flip=True, translate=2))
mask = (torch.rand(len(train_loader.labels)) < 0.6)
print('Training on %d images selected randomly' % mask.sum())
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
train94(train_loader, epochs=16) # yields around 93% accuracy
train_loader = CifarLoader('cifar10', train=True, batch_size=1024, aug=dict(flip=True, translate=2))
mask = (conf < conf.float().quantile(0.6))
print('Training on %d images selected based on minimum confidence' % mask.sum())
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
train94(train_loader, epochs=16) # yields around 94% accuracy => low-confidence sampling is better than random.
此專案建立在先前的優異記錄 https://github.com/tysam-code/hlb-CIFAR10(6.3 A100 秒達到 94%)的基礎上。
它本身建立在令人驚嘆的系列 https://myrtle.ai/learn/how-to-train-your-resnet/ 之上(26 V100 秒到 94%,即 >=8 A100 秒)