glom pytorch
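An implementation of Glom, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down processing, and attention (consensus between columns) for learning emergent part-whole hierarchies from data.
Yannic Kilcher's video was instrumental in helping me understand this paper.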
$ pip install glom-pytorch
import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
levels = model(img, iters = 12) # (1, 256, 6, 512) - (batch - patches - levels - dimension)
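The output shape in the comment above follows from the patch grid: a 224-pixel image cut into 14-pixel patches gives 16 × 16 = 256 patches, one column per patch. A minimal sketch of that arithmetic, where `expected_levels_shape` is a hypothetical helper for illustration and not part of glom-pytorch:

```python
def expected_levels_shape(batch, image_size, patch_size, levels, dim):
    """Shape of the returned state: (batch, patches, levels, dimension).

    Hypothetical helper, assuming a square image that is split into
    non-overlapping square patches.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    return (batch, num_patches, levels, dim)


shape = expected_levels_shape(batch=1, image_size=224, patch_size=14, levels=6, dim=512)
print(shape)  # (1, 256, 6, 512)
```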
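Pass the `return_all = True` keyword argument on the forward pass, and you will be returned all the column and level states for every iteration (including the initial state, so the number of iterations + 1). You can then use this to attach any loss to any level output at any time step.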
import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
all_levels = model(img, iters = 12, return_all = True) # (13, 1, 256, 6, 512) - (time, batch, patches, levels, dimension)

# get the top level outputs after iteration 6
# (index 0 along the time axis holds the initial state)
top_level_output = all_levels[7, :, :, -1] # (1, 256, 512) - (batch, patches, dimension)
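As one illustration of attaching losses at chosen time steps, here is a minimal sketch that supervises the top level at two points in time. It uses a random stand-in tensor with the same axis layout as `all_levels` so that it runs without glom-pytorch installed. The classification head, the labels, the mean pooling over patches, and the chosen time steps are all assumptions made for the example.

```python
import torch
import torch.nn.functional as F
from torch import nn

# Stand-in for model(img, iters=4, return_all=True), with small sizes.
# Axis layout: (time, batch, patches, levels, dimension).
iters, batch, patches, levels, dim = 4, 2, 16, 3, 32
all_levels = torch.randn(iters + 1, batch, patches, levels, dim, requires_grad=True)

# Hypothetical classification head shared by every supervised time step.
num_classes = 10
head = nn.Linear(dim, num_classes)
labels = torch.randint(0, num_classes, (batch,))

# Supervise the top level at two time steps (index 0 is the initial state).
loss = torch.zeros(())
for t in (2, iters):
    top_level = all_levels[t, :, :, -1]  # (batch, patches, dim)
    pooled = top_level.mean(dim=1)       # (batch, dim), averaged over patches
    loss = loss + F.cross_entropy(head(pooled), labels)

loss.backward()
```

Only the selected time steps and the top level receive gradient from this loss directly. With the real model, gradients would also flow back through the earlier iterations that produced those states.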
Denoising self-supervised learning for encouraging emergence, as described by Hinton
import torch
import torch.nn.functional as F
from torch import nn
from einops.layers.torch import Rearrange

from glom_pytorch import Glom

model = Glom(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
noised_img = img + torch.randn_like(img)

all_levels = model(noised_img, return_all = True)

# project each patch embedding back to pixels, then tile the patches into an image
patches_to_images = nn.Sequential(
    nn.Linear(512, 14 * 14 * 3),
    Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 14, p2 = 14, h = (224 // 14))
)

top_level = all_levels[7, :, :, -1] # get the top level embeddings after iteration 6
recon_img = patches_to_images(top_level)

# do self-supervised learning by denoising
loss = F.mse_loss(img, recon_img)
loss.backward()
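The `Rearrange` pattern in the reconstruction head above is what turns 256 per-patch vectors of length 14 * 14 * 3 back into a 3 × 224 × 224 image. For readers unfamiliar with einops, here is a rough plain-PyTorch sketch of the same layout, together with its inverse. Both function names are hypothetical and the sketch assumes a square patch grid; it is meant to illustrate the axis order, not to replace the einops layer.

```python
import math
import torch


def patches_to_image(patches, patch_size, channels=3):
    """Layout of 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)' for a square grid."""
    b, n, _ = patches.shape
    h = w = math.isqrt(n)
    x = patches.reshape(b, h, w, patch_size, patch_size, channels)
    x = x.permute(0, 5, 1, 3, 2, 4)  # -> b, c, h, p1, w, p2
    return x.reshape(b, channels, h * patch_size, w * patch_size)


def image_to_patches(img, patch_size):
    """Inverse layout: 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)'."""
    b, c, height, width = img.shape
    h, w = height // patch_size, width // patch_size
    x = img.reshape(b, c, h, patch_size, w, patch_size)
    x = x.permute(0, 2, 4, 3, 5, 1)  # -> b, h, w, p1, p2, c
    return x.reshape(b, h * w, patch_size * patch_size * c)


img = torch.randn(1, 3, 224, 224)
patches = image_to_patches(img, patch_size=14)    # (1, 256, 588)
recon = patches_to_image(patches, patch_size=14)  # (1, 3, 224, 224)
```

Cutting an image into patches and tiling them back reproduces the original exactly, which is a quick way to check that the axis order is right.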
import torch
from glom_pytorch import Glom

model = Glom(
    dim = 512,
    levels = 6,
    image_size = 224,
    patch_size = 14
)

img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
img3 = torch.randn(1, 3, 224, 224)

# pass the returned levels back in to continue from the previous state
levels1 = model(img1, iters = 12)                   # image 1 for 12 iterations
levels2 = model(img2, levels = levels1, iters = 10) # image 2 for 10 iterations
levels3 = model(img3, levels = levels2, iters = 6)  # image 3 for 6 iterations
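The three chained calls above generalize to a loop over any sequence of images. A minimal sketch, assuming only that the model accepts `levels` and `iters` keyword arguments as in the calls above. `run_on_frames`, its default iteration counts, and `CountingModel` are hypothetical names introduced for illustration, and the stand-in model exists only so the sketch runs without glom-pytorch installed.

```python
def run_on_frames(model, frames, first_iters=12, next_iters=6):
    """Run `model` over consecutive frames, feeding each returned state back in.

    Hypothetical helper: the first frame starts from the model's own initial
    state, later frames continue from the state returned for the previous one.
    """
    levels = None
    states = []
    for index, frame in enumerate(frames):
        iters = first_iters if index == 0 else next_iters
        if levels is None:
            levels = model(frame, iters=iters)
        else:
            levels = model(frame, levels=levels, iters=iters)
        states.append(levels)
    return states


class CountingModel:
    """Stand-in model: its "state" is just the total iterations run so far."""

    def __call__(self, img, levels=None, iters=1):
        previous = 0 if levels is None else levels
        return previous + iters


states = run_on_frames(CountingModel(), frames=["frame1", "frame2", "frame3"])
print(states)  # [12, 18, 24]
```

With a real `Glom` instance and image tensors in place of the stand-ins, each entry of `states` would be the levels tensor returned for the corresponding frame.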
Thanks to Cfoster0 for reviewing the code
@misc{hinton2021represent,
    title   = {How to represent part-whole hierarchies in a neural network},
    author  = {Geoffrey Hinton},
    year    = {2021},
    eprint  = {2102.12627},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}