slot attention
1.4.0
Implementación de Slot Attention del artículo 'Aprendizaje centrado en objetos con Slot Attention' en Pytorch. Aquí hay un vídeo que describe lo que esta red puede hacer.
Actualización: el repositorio oficial se ha publicado aquí.
$ pip install slot_attention
import torch
from slot_attention import SlotAttention
slot_attn = SlotAttention (
num_slots = 5 ,
dim = 512 ,
iters = 3 # iterations of attention, defaults to 3
)
inputs = torch . randn ( 2 , 1024 , 512 )
slot_attn ( inputs ) # (2, 5, 512)
Después del entrenamiento, se informa que la red puede generalizarse a un número ligeramente diferente de espacios (clústeres). Puede anular el número de espacios utilizados por la palabra clave num_slots
en adelante.
slot_attn ( inputs , num_slots = 8 ) # (2, 8, 512)
Para utilizar el método de ranura adaptativa para generar una máscara activa diferenciable y determinar si se debe utilizar una ranura, simplemente haga lo siguiente
import torch
from slot_attention import MultiHeadSlotAttention , AdaptiveSlotWrapper
# define slot attention
slot_attn = MultiHeadSlotAttention (
dim = 512 ,
num_slots = 5 ,
iters = 3 ,
)
# wrap the slot attention
adaptive_slots = AdaptiveSlotWrapper (
slot_attn ,
temperature = 0.5 # gumbel softmax temperature
)
inputs = torch . randn ( 2 , 1024 , 512 )
slots , keep_slots = adaptive_slots ( inputs ) # (2, 5, 512), (2, 5)
# the auxiliary loss in the paper for minimizing number of slots used for a scene would simply be
keep_aux_loss = keep_slots . sum () # add this to your main loss with some weight
@misc { locatello2020objectcentric ,
title = { Object-Centric Learning with Slot Attention } ,
author = { Francesco Locatello and Dirk Weissenborn and Thomas Unterthiner and Aravindh Mahendran and Georg Heigold and Jakob Uszkoreit and Alexey Dosovitskiy and Thomas Kipf } ,
year = { 2020 } ,
eprint = { 2006.15055 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@article { Fan2024AdaptiveSA ,
title = { Adaptive Slot Attention: Object Discovery with Dynamic Slot Number } ,
author = { Ke Fan and Zechen Bai and Tianjun Xiao and Tong He and Max Horn and Yanwei Fu and Francesco Locatello and Zheng Zhang } ,
journal = { ArXiv } ,
year = { 2024 } ,
volume = { abs/2406.09196 } ,
url = { https://api.semanticscholar.org/CorpusID:270440447 }
}