slot attention
1.4.0
การใช้ Slot Attention จากรายงาน 'การเรียนรู้แบบเน้นวัตถุด้วย Slot Attention' ใน Pytorch นี่คือวิดีโอที่อธิบายว่าเครือข่ายนี้ทำอะไรได้บ้าง
อัปเดต: พื้นที่เก็บข้อมูลอย่างเป็นทางการได้รับการเผยแพร่แล้วที่นี่
$ pip install slot_attention
import torch
from slot_attention import SlotAttention
slot_attn = SlotAttention (
num_slots = 5 ,
dim = 512 ,
iters = 3 # iterations of attention, defaults to 3
)
inputs = torch . randn ( 2 , 1024 , 512 )
slot_attn ( inputs ) # (2, 5, 512)
หลังการฝึกอบรม เครือข่ายได้รับการรายงานว่าสามารถสรุปจำนวนช่อง (คลัสเตอร์) ที่แตกต่างกันเล็กน้อย คุณสามารถแทนที่จำนวนช่องที่ใช้โดยคีย์เวิร์ด num_slots
ในการส่งต่อได้
slot_attn ( inputs , num_slots = 8 ) # (2, 8, 512)
หากต้องการใช้วิธีการช่องแบบอะแดปทีฟเพื่อสร้างฮอตมาส์กที่แตกต่างกันสำหรับว่าจะใช้ช่องหรือไม่ ให้ทำดังต่อไปนี้
import torch
from slot_attention import MultiHeadSlotAttention , AdaptiveSlotWrapper
# define slot attention
slot_attn = MultiHeadSlotAttention (
dim = 512 ,
num_slots = 5 ,
iters = 3 ,
)
# wrap the slot attention
adaptive_slots = AdaptiveSlotWrapper (
slot_attn ,
temperature = 0.5 # gumbel softmax temperature
)
inputs = torch . randn ( 2 , 1024 , 512 )
slots , keep_slots = adaptive_slots ( inputs ) # (2, 5, 512), (2, 5)
# the auxiliary loss in the paper for minimizing number of slots used for a scene would simply be
keep_aux_loss = keep_slots . sum () # add this to your main loss with some weight
@misc { locatello2020objectcentric ,
title = { Object-Centric Learning with Slot Attention } ,
author = { Francesco Locatello and Dirk Weissenborn and Thomas Unterthiner and Aravindh Mahendran and Georg Heigold and Jakob Uszkoreit and Alexey Dosovitskiy and Thomas Kipf } ,
year = { 2020 } ,
eprint = { 2006.15055 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@article { Fan2024AdaptiveSA ,
title = { Adaptive Slot Attention: Object Discovery with Dynamic Slot Number } ,
author = { Ke Fan and Zechen Bai and Tianjun Xiao and Tong He and Max Horn and Yanwei Fu and Francesco Locatello and Zheng Zhang } ,
journal = { ArXiv } ,
year = { 2024 } ,
volume = { abs/2406.09196 } ,
url = { https://api.semanticscholar.org/CorpusID:270440447 }
}