slot attention
1.4.0
Pytorch의 'Object-Centric Learning with Slot Attention' 논문에서 Slot Attention 구현. 다음은 이 네트워크가 수행할 수 있는 작업을 설명하는 비디오입니다.
업데이트: 공식 저장소가 여기에 공개되었습니다.
$ pip install slot_attention
import torch
from slot_attention import SlotAttention
slot_attn = SlotAttention (
num_slots = 5 ,
dim = 512 ,
iters = 3 # iterations of attention, defaults to 3
)
inputs = torch . randn ( 2 , 1024 , 512 )
slot_attn ( inputs ) # (2, 5, 512)
훈련 후 네트워크는 약간 다른 수의 슬롯(클러스터)으로 일반화할 수 있는 것으로 보고되었습니다. 앞으로 num_slots
키워드가 사용하는 슬롯 수를 재정의할 수 있습니다.
slot_attn ( inputs , num_slots = 8 ) # (2, 8, 512)
슬롯 사용 여부에 대해 미분 가능한 핫 마스크를 생성하기 위해 적응형 슬롯 방법을 사용하려면 다음을 수행하십시오.
import torch
from slot_attention import MultiHeadSlotAttention , AdaptiveSlotWrapper
# define slot attention
slot_attn = MultiHeadSlotAttention (
dim = 512 ,
num_slots = 5 ,
iters = 3 ,
)
# wrap the slot attention
adaptive_slots = AdaptiveSlotWrapper (
slot_attn ,
temperature = 0.5 # gumbel softmax temperature
)
inputs = torch . randn ( 2 , 1024 , 512 )
slots , keep_slots = adaptive_slots ( inputs ) # (2, 5, 512), (2, 5)
# the auxiliary loss in the paper for minimizing number of slots used for a scene would simply be
keep_aux_loss = keep_slots . sum () # add this to your main loss with some weight
@misc { locatello2020objectcentric ,
title = { Object-Centric Learning with Slot Attention } ,
author = { Francesco Locatello and Dirk Weissenborn and Thomas Unterthiner and Aravindh Mahendran and Georg Heigold and Jakob Uszkoreit and Alexey Dosovitskiy and Thomas Kipf } ,
year = { 2020 } ,
eprint = { 2006.15055 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@article { Fan2024AdaptiveSA ,
title = { Adaptive Slot Attention: Object Discovery with Dynamic Slot Number } ,
author = { Ke Fan and Zechen Bai and Tianjun Xiao and Tong He and Max Horn and Yanwei Fu and Francesco Locatello and Zheng Zhang } ,
journal = { ArXiv } ,
year = { 2024 } ,
volume = { abs/2406.09196 } ,
url = { https://api.semanticscholar.org/CorpusID:270440447 }
}