slot attention
1.4.0
Implémentation de Slot Attention à partir de l'article « Apprentissage centré sur l'objet avec Slot Attention » dans Pytorch. Voici une vidéo qui décrit ce que ce réseau peut faire.
Mise à jour : le référentiel officiel a été publié ici
$ pip install slot_attention
import torch
from slot_attention import SlotAttention
slot_attn = SlotAttention (
num_slots = 5 ,
dim = 512 ,
iters = 3 # iterations of attention, defaults to 3
)
inputs = torch . randn ( 2 , 1024 , 512 )
slot_attn ( inputs ) # (2, 5, 512)
Après formation, le réseau serait capable de se généraliser à un nombre légèrement différent d'emplacements (clusters). Vous pouvez remplacer le nombre d'emplacements utilisés par le mot-clé num_slots
en avant.
slot_attn ( inputs , num_slots = 8 ) # (2, 8, 512)
Pour utiliser la méthode de slot adaptatif pour générer un masque chaud différentiable pour savoir s'il faut utiliser un slot, procédez simplement comme suit
import torch
from slot_attention import MultiHeadSlotAttention , AdaptiveSlotWrapper
# define slot attention
slot_attn = MultiHeadSlotAttention (
dim = 512 ,
num_slots = 5 ,
iters = 3 ,
)
# wrap the slot attention
adaptive_slots = AdaptiveSlotWrapper (
slot_attn ,
temperature = 0.5 # gumbel softmax temperature
)
inputs = torch . randn ( 2 , 1024 , 512 )
slots , keep_slots = adaptive_slots ( inputs ) # (2, 5, 512), (2, 5)
# the auxiliary loss in the paper for minimizing number of slots used for a scene would simply be
keep_aux_loss = keep_slots . sum () # add this to your main loss with some weight
@misc { locatello2020objectcentric ,
title = { Object-Centric Learning with Slot Attention } ,
author = { Francesco Locatello and Dirk Weissenborn and Thomas Unterthiner and Aravindh Mahendran and Georg Heigold and Jakob Uszkoreit and Alexey Dosovitskiy and Thomas Kipf } ,
year = { 2020 } ,
eprint = { 2006.15055 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@article { Fan2024AdaptiveSA ,
title = { Adaptive Slot Attention: Object Discovery with Dynamic Slot Number } ,
author = { Ke Fan and Zechen Bai and Tianjun Xiao and Tong He and Max Horn and Yanwei Fu and Francesco Locatello and Zheng Zhang } ,
journal = { ArXiv } ,
year = { 2024 } ,
volume = { abs/2406.09196 } ,
url = { https://api.semanticscholar.org/CorpusID:270440447 }
}