Repo ini adalah implementasi resmi dari INTR: A Simple Interpretable Transformer for Fine-grained Image Classification and Analysis. Saat ini mencakup kode dan model untuk interpretasi data terperinci. Kami akan memberikan tautan ke proses ICLR 2024 mendatang untuk makalah ini ketika sudah tersedia secara online.
INTR adalah penggunaan baru Transformers untuk membuat klasifikasi gambar dapat diinterpretasikan. Di INTR, kami menyelidiki pendekatan proaktif terhadap klasifikasi, meminta setiap kelas untuk mencari dirinya sendiri dalam sebuah gambar. Kami mempelajari kueri khusus kelas (satu untuk setiap kelas) sebagai masukan ke decoder, memungkinkan mereka mencari keberadaannya dalam gambar melalui perhatian silang. Kami menunjukkan bahwa INTR secara intrinsik mendorong setiap kelas untuk hadir secara jelas; sehingga bobot perhatian silang memberikan interpretasi yang berarti terhadap prediksi model. Menariknya, melalui perhatian silang multi-head, INTR dapat belajar melokalisasi berbagai atribut kelas, sehingga sangat cocok untuk klasifikasi dan analisis terperinci.
Dalam model INTR, setiap query di decoder bertanggung jawab atas prediksi suatu kelas. Jadi, kueri melihat dirinya sendiri untuk menemukan fitur khusus kelas dari peta fitur. Pertama, kita memvisualisasikan peta fitur yaitu matriks nilai arsitektur transformator untuk melihat bagian-bagian penting dari objek pada gambar. Untuk menemukan fitur spesifik yang menjadi perhatian model dalam matriks nilai, kami menampilkan peta panas perhatian model. Untuk menghindari campur tangan eksternal dalam klasifikasi, kami menggunakan vektor bobot bersama untuk klasifikasi sehingga bobot perhatian menjelaskan prediksi model.
INTR pada tulang punggung DETR-R50, kinerja klasifikasi, dan model yang disempurnakan pada kumpulan data yang berbeda.
Kumpulan data | akun@1 | acc@5 | Model |
---|---|---|---|
ANAK | 71.8 | 89.3 | unduhan pos pemeriksaan |
Burung | 97.4 | 99.2 | unduhan pos pemeriksaan |
kupu-kupu | 95.0 | 98.3 | unduhan pos pemeriksaan |
Buat lingkungan python (opsional)
conda create -n intr python=3.8 -y
conda activate intr
Kloning repositori
git clone https://github.com/dipanjyoti/INTR.git
cd INTR
Instal dependensi python
pip install -r requirements.txt
Ikuti format data di bawah ini.
datasets
├── dataset_name
│ ├── train
│ │ ├── class1
│ │ │ ├── img1.jpeg
│ │ │ ├── img2.jpeg
│ │ │ └── ...
│ │ ├── class2
│ │ │ ├── img3.jpeg
│ │ │ └── ...
│ │ └── ...
│ └── val
│ ├── class1
│ │ ├── img4.jpeg
│ │ ├── img5.jpeg
│ │ └── ...
│ ├── class2
│ │ ├── img6.jpeg
│ │ └── ...
│ └── ...
Untuk mengevaluasi performa INTR pada dataset CUB , pada pengaturan multi-GPU (misalnya 4 GPU), jalankan perintah di bawah ini. Pos pemeriksaan INTR tersedia di Sempurnakan model dan hasil.
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 12345 --use_env main.py --eval --resume < path/to/intr_checkpoint_cub_detr_r50.pth > --dataset_path < path/to/datasets > --dataset_name < dataset_name >
Untuk menghasilkan representasi visual dari interpretasi INTR, jalankan perintah yang disediakan di bawah ini. Perintah ini akan menyajikan interpretasi untuk kelas tertentu dengan indeks
python -m tools.visualization --eval --resume < path/to/intr_checkpoint_cub_detr_r50.pth > --dataset_path < path/to/datasets > --dataset_name < dataset_name > --class_index < class_number >
Prediksi dan visualisasi gambar tunggal waktu inferensi: Kami juga menyediakan Notebook Jupyter, demo.ipynb, yang dirancang untuk prediksi dan visualisasi gambar tunggal selama proses inferensi. Harap dicatat bahwa demo difokuskan pada dataset CUB.
Untuk mempersiapkan INTR untuk pelatihan, gunakan model DETR-R50 yang telah dilatih sebelumnya. Untuk melatih kumpulan data tertentu, ubah '--num_queries' dengan mengaturnya ke jumlah kelas dalam kumpulan data. Dalam arsitektur INTR, setiap kueri di dekoder diberi tugas untuk menangkap fitur khusus kelas, yang berarti bahwa setiap kueri dapat diadaptasi melalui proses pembelajaran. Akibatnya, jumlah parameter model akan bertambah sebanding dengan jumlah kelas dalam kumpulan data. Untuk melatih INTR pada sistem multi-GPU, (misalnya 4 GPU), jalankan perintah di bawah ini.
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port 12345 --use_env main.py --finetune < path/to/detr-r50-e632da11.pth > --dataset_path < path/to/datasets > --dataset_name < dataset_name > --num_queries < num_of_classes >
Model kami terinspirasi oleh metode DEtection TRansformer (DETR).
Kami berterima kasih kepada penulis DETR karena telah melakukan pekerjaan luar biasa.
Jika Anda merasa pekerjaan kami bermanfaat untuk penelitian Anda, mohon pertimbangkan untuk mengutip entri BibTeX.
@inproceedings{paul2024simple,
title={A Simple Interpretable Transformer for Fine-Grained Image Classification and Analysis},
author={Paul, Dipanjyoti and Chowdhury, Arpita and Xiong, Xinqi and Chang, Feng-Ju and Carlyn, David and Stevens, Samuel and Provost, Kaiya and Karpatne, Anuj and Carstens, Bryan and Rubenstein, Daniel and Stewart, Charles and Berger-Wolf, Tanya and Su, Yu and Chao, Wei-Lun},
booktitle={International Conference on Learning Representations},
year={2024}
}