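Practical implementation of an astoundingly simple method for self-supervised learning that achieves a new state of the art (surpassing SimCLR) without contrastive learning and without having to designate negative pairs.
This repository offers a module with which you can easily wrap any image-based neural network (residual network, discriminator, policy network) to immediately start benefiting from unlabelled image data.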
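The absence of negative pairs is easiest to see in the loss itself. In the BYOL paper, the online network's prediction for one augmented view is regressed onto the target network's projection of the other view after both are L2-normalised, which reduces to 2 - 2 * cosine similarity. Below is a minimal NumPy sketch of that normalised MSE, symmetrised over the two views. The function names are illustrative and not part of byol-pytorch, and because NumPy has no autograd the stop-gradient on the target branch is only noted in a comment.

```python
import numpy as np


def normalized_mse(p: np.ndarray, z: np.ndarray) -> np.ndarray:
    """Per-sample BYOL regression loss between predictions p and targets z.

    Both inputs have shape (batch, dim). After L2-normalising each row,
    ||p_hat - z_hat||^2 simplifies to 2 - 2 * cos(p, z).
    """
    p_hat = p / np.linalg.norm(p, axis=-1, keepdims=True)
    z_hat = z / np.linalg.norm(z, axis=-1, keepdims=True)
    return 2.0 - 2.0 * np.sum(p_hat * z_hat, axis=-1)


def symmetric_byol_loss(p1, p2, z1, z2) -> float:
    """Online prediction of each view vs. target projection of the other view.

    In a real framework z1 and z2 come from the target network and receive
    no gradient (stop-gradient); only p1 and p2 are trained through.
    """
    return float(np.mean(normalized_mse(p1, z2) + normalized_mse(p2, z1)))
```

Vectors pointing the same way give a loss of 0, orthogonal ones give 2 and opposite ones give 4; nothing in the objective pushes different images apart.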
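Update 1: There is now new evidence that batch normalization is key to making this technique work well
Update 2: A new paper has successfully replaced batch norm with group norm + weight standardization, refuting the idea that BYOL needs batch statistics to work
Update 3: Finally, we have some analysis of why this works
Yannic Kilcher's excellent explanation
Now go save your organization from having to pay for labels :)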
$ pip install byol-pytorch
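Simply plug in your neural network, specifying (1) the image dimensions and (2) the name (or index) of the hidden layer whose output is used as the latent representation for self-supervised training.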
import torch
from byol_pytorch import BYOL
from torchvision import models
resnet = models.resnet50(pretrained=True)
learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)
for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder
# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')
That's pretty much it. After much training, the residual network should now perform better on its downstream supervised tasks.
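The hidden_layer argument accepts either a module name such as 'avgpool' or an integer index such as -2. This sketch assumes a string is looked up among named_modules() and an integer indexes the network's top-level children(), and it captures that layer's output with a PyTorch forward hook. find_layer and capture_hidden are hypothetical helpers for illustration, not byol-pytorch's own code, and a tiny stand-in network replaces ResNet-50 so nothing needs downloading.

```python
from collections import OrderedDict

import torch
from torch import nn


def find_layer(net: nn.Module, layer):
    """Resolve `layer` by module name (str) or by index into net.children() (int)."""
    if isinstance(layer, str):
        return dict(net.named_modules()).get(layer)
    if isinstance(layer, int):
        return list(net.children())[layer]
    return None


def capture_hidden(net: nn.Module, layer, x: torch.Tensor) -> torch.Tensor:
    """Run `net` on `x` and return the chosen layer's output, flattened per sample."""
    module = find_layer(net, layer)
    if module is None:
        raise ValueError(f"hidden layer {layer!r} not found")
    captured = {}

    def hook(_module, _inputs, output):
        captured["out"] = output  # stash the layer output during the forward pass

    handle = module.register_forward_hook(hook)
    try:
        net(x)
    finally:
        handle.remove()  # detach the hook so the network is left unchanged
    return captured["out"].flatten(start_dim=1)


# Tiny stand-in for a ResNet: conv -> global average pool -> classifier head
toy = nn.Sequential(OrderedDict([
    ("conv", nn.Conv2d(3, 8, kernel_size=3, padding=1)),
    ("avgpool", nn.AdaptiveAvgPool2d(1)),
    ("fc", nn.Sequential(nn.Flatten(), nn.Linear(8, 10))),
]))
```

Here capture_hidden(toy, 'avgpool', x) and capture_hidden(toy, -2, x) select the same layer and return a (batch, 8) tensor. In a torchvision ResNet-50, avgpool is likewise the second-to-last child, just before fc.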
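A new paper from Kaiming He suggests that BYOL does not even need the target encoder to be an exponential moving average of the online encoder. I've built in this option so that you can easily train with that variant: simply set the use_momentum flag to False. If you go this route, as in the example below, you no longer need to call update_moving_average.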
import torch
from byol_pytorch import BYOL
from torchvision import models
resnet = models.resnet50(pretrained=True)
learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    use_momentum = False # turn off momentum in the target encoder
)
opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)
for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')
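For reference, the momentum that use_momentum = False switches off is an exponential moving average: after each step, every target-encoder parameter becomes decay * target + (1 - decay) * online, where decay corresponds to moving_average_decay. A minimal sketch over plain dicts of NumPy arrays; ema_update is a hypothetical helper, not the library's update_moving_average.

```python
import numpy as np


def ema_update(target: dict, online: dict, decay: float = 0.99) -> None:
    """Update `target` in place: target <- decay * target + (1 - decay) * online.

    Both dicts map parameter names to arrays of matching shape.
    """
    for name, online_param in online.items():
        target[name] = decay * target[name] + (1.0 - decay) * online_param


online = {"w": np.array([1.0, 1.0])}
target = {"w": np.array([0.0, 0.0])}
for _ in range(3):
    ema_update(target, online, decay=0.5)
# target["w"] moves 0 -> 0.5 -> 0.75 -> 0.875 towards the online weights
```

With decay = 0 the target is simply a copy of the online encoder, which is the parameter-level picture of training without momentum.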
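Although the hyperparameters are already set to what the paper found optimal, you can change them with extra keyword arguments to the base wrapper class.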
learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    projection_size = 256,           # the projection size
    projection_hidden_size = 4096,   # the hidden dimension of the MLP for both the projection and prediction
    moving_average_decay = 0.99      # the moving average decay factor for the target encoder, already set at what paper recommends
)
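projection_size and projection_hidden_size size the MLP heads. The BYOL paper describes both the projector and the predictor as a linear layer to a 4096-wide hidden layer, batch normalisation, ReLU, and a final linear layer to 256 outputs. A PyTorch sketch of that structure follows. It assumes the 2048-dimensional avgpool output of ResNet-50 as input, and the mlp helper is illustrative; the library's own class may differ in detail.

```python
import torch
from torch import nn


def mlp(dim: int, projection_size: int = 256, hidden_size: int = 4096) -> nn.Sequential:
    """Linear -> BatchNorm1d -> ReLU -> Linear, as described in the BYOL paper."""
    return nn.Sequential(
        nn.Linear(dim, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size),
    )


# Projector: 2048-d ResNet-50 representation -> 256-d projection.
# Predictor (online branch only): 256-d projection -> 256-d prediction.
projector = mlp(2048)
predictor = mlp(256)

representation = torch.randn(8, 2048)  # a batch of 8 pooled features
projection = projector(representation)
prediction = predictor(projection)
```

The predictor's input and output both equal projection_size, so its prediction can be compared directly with the target network's projection.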
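By default, this library uses the augmentations from the SimCLR paper (which the BYOL paper also uses). However, if you want to specify your own augmentation pipeline, simply pass in your own custom augmentation function with the augment_fn keyword.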
from torch import nn
import kornia
# resnet and BYOL as in the first example
augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)
learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn
)
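augment_fn does not have to be a kornia pipeline. Because the examples here pass an nn.Sequential that operates on whole batches, this sketch assumes any callable that maps a (batch, channels, height, width) tensor to one of the same shape will do. RandomFlipAndBrightness is a hypothetical pure-PyTorch module that flips each image horizontally with probability p and rescales its brightness by a random factor.

```python
import torch
from torch import nn


class RandomFlipAndBrightness(nn.Module):
    """Hypothetical batch augmentation: per-sample horizontal flip with
    probability p, then a per-sample brightness scale in [1 - b, 1 + b]."""

    def __init__(self, p: float = 0.5, brightness: float = 0.2):
        super().__init__()
        self.p = p
        self.brightness = brightness

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        batch = images.shape[0]
        # Boolean mask of shape (batch, 1, 1, 1) selecting which samples to flip
        flip = torch.rand(batch, 1, 1, 1, device=images.device) < self.p
        images = torch.where(flip, images.flip(-1), images)
        # One brightness factor per sample, broadcast over channels, height, width
        noise = torch.rand(batch, 1, 1, 1, device=images.device) * 2 - 1
        return images * (1 + noise * self.brightness)


augment_fn = RandomFlipAndBrightness(p=0.5, brightness=0.2)
```

It would then be passed exactly like the kornia version, e.g. BYOL(resnet, image_size = 256, hidden_layer = -2, augment_fn = augment_fn).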
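In the paper, they seem to ensure that one of the augmentations has a higher Gaussian blur probability than the other. You can also adjust this to your liking.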
from torch import nn
import kornia
augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)
augment_fn2 = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip(),
    kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5))
)
learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn,
    augment_fn2 = augment_fn2,
)
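The example above blurs the second view every time and the first view never. To express the probabilities the BYOL paper reports (blur with probability 1.0 for one view and 0.1 for the other), the blur can be wrapped in a per-sample random-apply module. A minimal PyTorch sketch; RandomApply and BoxBlur are hypothetical helpers, and BoxBlur is a plain 3x3 mean filter standing in for a Gaussian blur (kornia.filters.GaussianBlur2d from the example above could be wrapped instead).

```python
import torch
import torch.nn.functional as F
from torch import nn


class RandomApply(nn.Module):
    """Apply `fn` to each sample of a batch independently with probability p."""

    def __init__(self, fn: nn.Module, p: float):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        apply = torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.p
        return torch.where(apply, self.fn(x), x)


class BoxBlur(nn.Module):
    """3x3 mean filter, a simple stand-in for a Gaussian blur."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)


augment_fn = RandomApply(BoxBlur(), p=1.0)   # view 1: always blurred
augment_fn2 = RandomApply(BoxBlur(), p=0.1)  # view 2: blurred about 10% of the time
```

Drawing the decision per sample, rather than once per batch, means both blurred and unblurred images appear within a single batch of the second view.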
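To fetch the embeddings or the projections, simply pass the return_embedding = True flag to the BYOL learner instance; it then returns the projection and the embedding, in that order.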
import torch
from byol_pytorch import BYOL
from torchvision import models
resnet = models.resnet50(pretrained=True)
learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)
imgs = torch.randn(2, 3, 256, 256)
projection, embedding = learner(imgs, return_embedding = True)
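One common use for the fetched embeddings is nearest-neighbour retrieval by cosine similarity. A small sketch; nearest_neighbours is a hypothetical helper, and queries and gallery would be, for example, the embedding outputs of learner(imgs, return_embedding = True) for two batches of images.

```python
import torch
import torch.nn.functional as F


def nearest_neighbours(queries: torch.Tensor, gallery: torch.Tensor, k: int = 1) -> torch.Tensor:
    """Return indices (num_queries, k) of the most cosine-similar gallery rows."""
    q = F.normalize(queries, dim=-1)   # unit-length rows
    g = F.normalize(gallery, dim=-1)
    similarity = q @ g.T               # (num_queries, num_gallery) cosine similarities
    return similarity.topk(k, dim=-1).indices
```

Normalising first makes the dot product a cosine similarity, so the ranking ignores differences in embedding magnitude.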
The repository now offers distributed training with Hugging Face Accelerate. You just need to pass your own Dataset into the imported BYOLTrainer.
First, set up the configuration for distributed training by invoking the accelerate CLI
$ accelerate config
Then craft your training script as shown below, say in ./train.py
from torchvision import models
from byol_pytorch import (
    BYOL,
    BYOLTrainer,
    MockDataset
)
resnet = models.resnet50(pretrained=True)
dataset = MockDataset(256, 10000)
trainer = BYOLTrainer(
    resnet,
    dataset = dataset,
    image_size = 256,
    hidden_layer = 'avgpool',
    learning_rate = 3e-4,
    num_train_steps = 100_000,
    batch_size = 16,
    checkpoint_every = 1000 # improved model will be saved periodically to ./checkpoints folder
)
trainer()
Then use the accelerate CLI again to launch the script
$ accelerate launch ./train.py
If your downstream task involves segmentation, please look at the following repository, which extends BYOL to "pixel"-level learning.
https://github.com/lucidrains/pixel-level-contrastive-learning
@misc{grill2020bootstrap,
    title         = {Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning},
    author        = {Jean-Bastien Grill and Florian Strub and Florent Altché and Corentin Tallec and Pierre H. Richemond and Elena Buchatskaya and Carl Doersch and Bernardo Avila Pires and Zhaohan Daniel Guo and Mohammad Gheshlaghi Azar and Bilal Piot and Koray Kavukcuoglu and Rémi Munos and Michal Valko},
    year          = {2020},
    eprint        = {2006.07733},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}
@misc{chen2020exploring,
    title         = {Exploring Simple Siamese Representation Learning},
    author        = {Xinlei Chen and Kaiming He},
    year          = {2020},
    eprint        = {2011.10566},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}