Une implémentation de Glom, la nouvelle idée de Geoffrey Hinton qui intègre les concepts des champs neuronaux, du traitement descendant et ascendant et de l'attention (consensus entre les colonnes) pour apprendre les hiérarchies partie-tout émergentes à partir des données.
La vidéo de Yannic Kilcher m'a beaucoup aidé à comprendre cet article
$ pip install glom-pytorch
import torch
from glom_pytorch import Glom
model = Glom (
dim = 512 , # dimension
levels = 6 , # number of levels
image_size = 224 , # image size
patch_size = 14 # patch size
)
img = torch . randn ( 1 , 3 , 224 , 224 )
levels = model ( img , iters = 12 ) # (1, 256, 6, 512) - (batch - patches - levels - dimension)
Passez l'argument de mot-clé return_all = True
en avant, et vous recevrez tous les états de colonne et de niveau par itération (y compris l'état initial, le nombre d'itérations + 1). Vous pouvez ensuite l'utiliser pour associer les pertes à n'importe quel niveau de sortie à tout moment.
Il vous donne également accès à toutes les données de niveau à travers les itérations pour le clustering, à partir desquelles on peut inspecter les îlots théorisés dans l'article.
import torch
from glom_pytorch import Glom
model = Glom (
dim = 512 , # dimension
levels = 6 , # number of levels
image_size = 224 , # image size
patch_size = 14 # patch size
)
img = torch . randn ( 1 , 3 , 224 , 224 )
all_levels = model ( img , iters = 12 , return_all = True ) # (13, 1, 256, 6, 512) - (time, batch, patches, levels, dimension)
# get the top level outputs after iteration 6
top_level_output = all_levels [ 7 , :, :, - 1 ] # (1, 256, 512) - (batch, patches, dimension)
Débruitage de l'apprentissage auto-supervisé pour encourager l'émergence, comme décrit par Hinton
import torch
import torch . nn . functional as F
from torch import nn
from einops . layers . torch import Rearrange
from glom_pytorch import Glom
model = Glom (
dim = 512 , # dimension
levels = 6 , # number of levels
image_size = 224 , # image size
patch_size = 14 # patch size
)
img = torch . randn ( 1 , 3 , 224 , 224 )
noised_img = img + torch . randn_like ( img )
all_levels = model ( noised_img , return_all = True )
patches_to_images = nn . Sequential (
nn . Linear ( 512 , 14 * 14 * 3 ),
Rearrange ( 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)' , p1 = 14 , p2 = 14 , h = ( 224 // 14 ))
)
top_level = all_levels [ 7 , :, :, - 1 ] # get the top level embeddings after iteration 6
recon_img = patches_to_images ( top_level )
# do self-supervised learning by denoising
loss = F . mse_loss ( img , recon_img )
loss . backward ()
Vous pouvez transmettre l'état de la colonne et des niveaux dans le modèle pour continuer là où vous vous êtes arrêté (peut-être si vous traitez des images consécutives d'une vidéo lente, comme mentionné dans l'article)
import torch
from glom_pytorch import Glom
model = Glom (
dim = 512 ,
levels = 6 ,
image_size = 224 ,
patch_size = 14
)
img1 = torch . randn ( 1 , 3 , 224 , 224 )
img2 = torch . randn ( 1 , 3 , 224 , 224 )
img3 = torch . randn ( 1 , 3 , 224 , 224 )
levels1 = model ( img1 , iters = 12 ) # image 1 for 12 iterations
levels2 = model ( img2 , levels = levels1 , iters = 10 ) # image 2 for 10 iteratoins
levels3 = model ( img3 , levels = levels2 , iters = 6 ) # image 3 for 6 iterations
Merci à Cfoster0 pour avoir révisé le code
@misc { hinton2021represent ,
title = { How to represent part-whole hierarchies in a neural network } ,
author = { Geoffrey Hinton } ,
year = { 2021 } ,
eprint = { 2102.12627 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}