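A practical implementation of an astoundingly simple method for self-supervised learning that achieves a new state of the art (surpassing SimCLR) without contrastive learning and without having to designate negative pairs.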
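Update 1: There is now new evidence that batch normalization is key to making this technique work well.
Update 2: A new paper has successfully replaced batch norm with group norm + weight standardization, refuting the claim that BYOL needs batch statistics to work.
Update 3: Finally, we have some analysis of why it works.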
Yannic Kilcher's excellent explanation
$ pip install byol-pytorch
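Simply plug in your neural network, specifying (1) the image dimensions and (2) the name (or index) of the hidden layer whose output is used as the latent representation for self-supervised training.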
import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')
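For intuition, the update_moving_average() call in the loop above keeps the target encoder as an exponential moving average of the online encoder. The following is a minimal sketch of such an update, assuming the standard rule target = decay * target + (1 - decay) * online. The helper name ema_update and the toy nn.Linear networks are hypothetical, and this is not the library's internal code.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def ema_update(target: nn.Module, online: nn.Module, decay: float = 0.99) -> None:
    """Hypothetical helper: move each target parameter a small step toward the online one."""
    for t, o in zip(target.parameters(), online.parameters()):
        # t <- decay * t + (1 - decay) * o, done in place
        t.mul_(decay).add_(o, alpha=1.0 - decay)


# Toy stand-ins for the online and target encoders (assumption: same architecture).
online = nn.Linear(4, 2)
target = copy.deepcopy(online)
for p in target.parameters():
    p.requires_grad_(False)  # the target is never trained by gradient descent

# Pretend an optimizer step shifted every online weight by +1.
with torch.no_grad():
    for p in online.parameters():
        p.add_(1.0)

before = target.weight.clone()
ema_update(target, online, decay=0.99)
# Each target weight has now moved 1% of the way toward the online weight.
```

A decay close to 1 makes the target change slowly, which is why the default shown later in this README is 0.99.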
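A new paper from Kaiming He suggests that BYOL does not even need the target encoder to be an exponential moving average of the online encoder. I've decided to build in this option so that you can easily train with that variant, simply by setting use_momentum to False. As the example below shows, you then no longer need to call update_moving_average.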
import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    use_momentum = False # turn off momentum in the target encoder
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')
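Without momentum, what keeps the two branches from being identical is the extra predictor on the online side and the stop-gradient on the target side. Below is a minimal sketch of the regression loss described in the BYOL paper (a normalized mean squared error, which equals 2 - 2 * cosine similarity) with the target detached. The function name byol_loss and the random tensors are hypothetical stand-ins, not the library's own code.

```python
import torch
import torch.nn.functional as F


def byol_loss(online_pred: torch.Tensor, target_proj: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: per-sample 2 - 2 * cos(online_pred, target_proj)."""
    target_proj = target_proj.detach()      # stop-gradient: no gradient reaches the target branch
    p = F.normalize(online_pred, dim=-1)    # unit-length prediction
    z = F.normalize(target_proj, dim=-1)    # unit-length target projection
    return 2 - 2 * (p * z).sum(dim=-1)


# Random stand-ins for the predictor output and the target projection.
pred = torch.randn(8, 256, requires_grad=True)
proj = torch.randn(8, 256, requires_grad=True)

loss = byol_loss(pred, proj).mean()
loss.backward()

# Only the online side receives a gradient; the target side is left untouched.
identical = byol_loss(pred.detach(), pred.detach())  # loss is ~0 when both views agree
```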
learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    projection_size = 256,           # the projection size
    projection_hidden_size = 4096,   # the hidden dimension of the MLP for both the projection and prediction
    moving_average_decay = 0.99      # the moving average decay factor for the target encoder, already set at what paper recommends
)
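By default, this library uses the augmentations from the SimCLR paper (which are also used in the BYOL paper). However, if you would like to specify your own augmentation pipeline, simply pass it in with the augment_fn keyword.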
import kornia
from torch import nn

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn
)
augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

# optional second pipeline, applied to the second augmented view
augment_fn2 = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip(),
    kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5))
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn,
    augment_fn2 = augment_fn2,
)
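The pipelines above are built from kornia modules, but a hand-written module should serve the same role. The sketch below assumes, based only on the examples above rather than on a documented contract, that an augmentation is any callable that takes a batch of images shaped (batch, channels, height, width) and returns a batch of the same shape. The class name BatchRandomFlip is hypothetical.

```python
import torch
from torch import nn


class BatchRandomFlip(nn.Module):
    """Hypothetical augmentation: flips each image left-right independently with probability p."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # images: (batch, channels, height, width)
        flip = torch.rand(images.shape[0], device=images.device) < self.p
        flipped = images.flip(-1)  # reverse the width axis
        # choose the flipped or original image per sample
        return torch.where(flip[:, None, None, None], flipped, images)


custom_augment_fn = nn.Sequential(BatchRandomFlip(p=0.5))

batch = torch.randn(4, 3, 256, 256)
augmented = custom_augment_fn(batch)  # same shape as the input batch
```

Such a module could then be passed as augment_fn in the same way as the kornia pipeline.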
To fetch the embeddings or the projections, simply pass return_embedding = True when calling the BYOL learner instance.
import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

imgs = torch.randn(2, 3, 256, 256)
projection, embedding = learner(imgs, return_embedding = True)
First, set up the configuration for distributed training by invoking the accelerate CLI
$ accelerate config
from torchvision import models

from byol_pytorch import (
    BYOLTrainer,
    MockDataset
)

resnet = models.resnet50(pretrained=True)

dataset = MockDataset(256, 10000)

trainer = BYOLTrainer(
    resnet,
    dataset = dataset,
    image_size = 256,
    hidden_layer = 'avgpool',
    learning_rate = 3e-4,
    num_train_steps = 100_000,
    batch_size = 16,
    checkpoint_every = 1000 # improved model will be saved periodically to ./checkpoints folder
)

trainer()
Then use the accelerate CLI again to launch the script
$ accelerate launch ./train.py
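MockDataset in the training script above only stands in for real data. A sketch of a replacement follows, under the assumption (inferred from how MockDataset is used, not from documented behaviour) that the trainer expects each item to be a single unlabelled float image tensor of shape (3, image_size, image_size). The class name TensorImageDataset and the random images tensor are hypothetical.

```python
import torch
from torch.utils.data import Dataset


class TensorImageDataset(Dataset):
    """Hypothetical dataset: serves unlabelled images from one preloaded tensor."""

    def __init__(self, images: torch.Tensor):
        # expected layout: (num_images, 3, height, width)
        if images.ndim != 4 or images.shape[1] != 3:
            raise ValueError("images must have shape (N, 3, H, W)")
        self.images = images

    def __len__(self) -> int:
        return self.images.shape[0]

    def __getitem__(self, idx: int) -> torch.Tensor:
        # no label is returned, since the training is self-supervised
        return self.images[idx]


# Random values stand in for real, already resized images.
images = torch.rand(32, 3, 256, 256)
dataset = TensorImageDataset(images)
```

If the assumption holds, an instance like this could be passed as dataset = dataset in place of MockDataset(256, 10000).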
If your downstream task involves segmentation, please look at the following repository, which extends BYOL to "pixel"-level learning.
