CIFAR-10 的快速且独立的培训脚本集合。
脚本 | 平均准确度 | 时间 | 浮点运算次数 |
---|---|---|---|
airbench94_compiled.py | 94.01% | 3.09秒 | 0.36 |
airbench94.py | 94.01% | 3.83秒 | 0.36 |
airbench95.py | 95.01% | 10.4秒 | 1.4 |
airbench96.py | 96.03% | 34.7秒 | 4.9 |
airbench94_muon.py | 94.01% | 2.59秒 | 0.29 |
airbench96_faster.py | 96.00% | 27.3秒 | 3.1 |
作为比较,大多数 CIFAR-10 研究中使用的标准训练要慢得多:
基线 | 平均准确度 | 时间 | 浮点运算次数 |
---|---|---|---|
标准 ResNet-18 训练 | 96% | 7分钟 | 32.3 |
所有计时均在单个 NVIDIA A100 GPU 上进行。
注: airbench96
自论文发表以来已从 46s 改进为 35s。此外, airbench96_faster
是一种改进的(但更复杂)方法,它使用小型代理模型进行数据过滤。 airbench94_muon
是一种使用 Muon 优化器变体的改进方法。
论文中描述了用于获得这些训练速度的一组方法。
要以 94% 的准确率训练神经网络,请运行
git clone https://github.com/KellerJordan/cifar10-airbench.git
cd airbench && python airbench94.py
或者
pip install airbench
python -c "import airbench; airbench.warmup94(); airbench.train94()"
注意: airbench94_compiled.py
和airbench94.py
是等效的(即,产生相同的训练网络分布),不同之处仅在于前者使用torch.compile
来提高 GPU 利用率。前者用于同时训练许多网络的实验,以分摊一次性编译成本。
CIFAR-10 是机器学习领域使用最广泛的数据集之一,每年促进数千个研究项目。该存储库为 CIFAR-10 提供快速稳定的训练基线,以帮助加速这项研究。训练以可轻松运行、无依赖的 PyTorch 脚本的形式提供,并且可以替换训练 ResNet-20 或 ResNet-18 等经典基线。
为了编写自定义 CIFAR-10 实验或训练,您可能会发现独立使用 GPU 加速数据加载器很有用。
import airbench
train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500)
test_loader = airbench.CifarLoader('/tmp/cifar10', train=False, batch_size=1000)
for epoch in range(200):
for inputs, labels in train_loader:
# outputs = model(inputs)
# loss = F.cross_entropy(outputs, labels)
...
如果你想修改加载器中的数据,可以这样做:
import airbench
train_loader = airbench.CifarLoader('/tmp/cifar10', train=True, aug=dict(flip=True, translate=4, cutout=16), batch_size=500)
mask = (train_loader.labels < 6) # (this is just an example, the mask can be anything)
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
print(len(train_loader)) # The loader now contains 30,000 images and has batch size 500, so this prints 60.
Airbench 可用作数据选择和主动学习实验的平台。下面是一个示例实验,它演示了低置信度示例比随机示例提供更多训练信号的经典结果。在 A100 上运行时间不到 20 秒。
import torch
from airbench import train94, infer, evaluate, CifarLoader
net = train94(label_smoothing=0) # train this network without label smoothing to get a better confidence signal
loader = CifarLoader('cifar10', train=True, batch_size=1000)
logits = infer(net, loader)
conf = logits.log_softmax(1).amax(1) # confidence
train_loader = CifarLoader('cifar10', train=True, batch_size=1024, aug=dict(flip=True, translate=2))
mask = (torch.rand(len(train_loader.labels)) < 0.6)
print('Training on %d images selected randomly' % mask.sum())
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
train94(train_loader, epochs=16) # yields around 93% accuracy
train_loader = CifarLoader('cifar10', train=True, batch_size=1024, aug=dict(flip=True, translate=2))
mask = (conf < conf.float().quantile(0.6))
print('Training on %d images selected based on minimum confidence' % mask.sum())
train_loader.images = train_loader.images[mask]
train_loader.labels = train_loader.labels[mask]
train94(train_loader, epochs=16) # yields around 94% accuracy => low-confidence sampling is better than random.
该项目建立在之前的优异记录 https://github.com/tysam-code/hlb-CIFAR10(6.3 A100 秒达到 94%)的基础上。
它本身建立在令人惊叹的系列 https://myrtle.ai/learn/how-to-train-your-resnet/(26 V100 秒到 94%,即 >=8 A100 秒)