Implementasi Pytorch resmi "Pembelajaran Konteks Tautan untuk LLM Multimodal" [CVPR 2024].
Repositori ini berisi implementasi dan dataset resmi dari makalah berikut:
Pembelajaran Link-Context untuk LLM Multimodal
https://arxiv.org/abs/2308.07891Abstrak: Kemampuan untuk belajar dari konteks dengan konsep -konsep baru, dan memberikan respons yang tepat sangat penting dalam percakapan manusia. Meskipun model bahasa besar multimodal saat ini (MLLM) dan model bahasa besar (LLM) dilatih pada dataset skala besar, mengenali gambar yang tidak terlihat atau memahami konsep-konsep baru secara bebas pelatihan tetap menjadi tantangan. Pembelajaran In-Context (ICL) mengeksplorasi pembelajaran beberapa tembakan bebas pelatihan, di mana model didorong untuk "belajar untuk belajar" dari tugas terbatas dan menggeneralisasi ke tugas yang tidak terlihat. Dalam karya ini, kami mengusulkan pembelajaran tautan-konteks (LCL), yang menekankan "penalaran dari sebab dan akibat" untuk menambah kemampuan pembelajaran MLLM. LCL melampaui ICL tradisional dengan secara eksplisit memperkuat hubungan kausal antara set dukungan dan set kueri. Dengan memberikan demonstrasi dengan tautan kausal, LCL memandu model untuk membedakan tidak hanya analogi tetapi juga hubungan kausal yang mendasari antara titik data, yang memberdayakan MLLM untuk mengenali gambar yang tidak terlihat dan memahami konsep -konsep baru secara lebih efektif. Untuk memfasilitasi evaluasi pendekatan baru ini, kami memperkenalkan dataset Isekai, yang terdiri dari pasangan label gambar yang tidak terlihat yang tidak terlihat yang dirancang untuk pembelajaran konteks tautan. Eksperimen ekstensif menunjukkan bahwa LCL-MLLM kami menunjukkan kemampuan pembelajaran konteks tautan-konteks yang kuat untuk konsep-konsep baru di atas vanilla MLLM.
conda create -n lcl python=3.10
conda activate lcl
pip install -r requirements.txt
accelerate config
Kami melatih pengaturan LCL pada set Rebuild ImageNet-900 kami, dan mengevaluasi model pada set ImageNet-100. Anda bisa mendapatkan dataset json di sini.
Kami mengevaluasi model pada isekai-10 dan isekai-pair, Anda dapat mengunduh dataset isekai di isekai-10 dan isekai-pair.
Unduh pos pemeriksaan LCL-2WAY-WEIGHT dan LCL-MIX kami di HuggingFace.
Untuk meluncurkan demo Web Gradio, gunakan perintah berikut. Harap dicatat bahwa model mengevaluasi dalam format Torch.float16, yang membutuhkan GPU dengan setidaknya 16GB memori.
python ./mllm/demo/demo.py --model_path /path/to/lcl/ckpt
Dimungkinkan juga untuk menggunakannya dalam kuantisasi 8-bit, meskipun dengan mengorbankan pengorbanan beberapa kinerja.
python ./mllm/demo/demo.py --model_path /path/to/lcl/ckpt --load_in_8bit
Setelah menyiapkan data, Anda dapat melatih model menggunakan perintah:
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_train_2way_weight.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/init/checkpoint
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_train_mix1.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/init/checkpoint
Setelah menyiapkan data, Anda dapat menyimpulkan model menggunakan perintah:
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_eval_ISEKAI_10.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
MMengine Style Args dan Huggingface: Pelatih ARGS didukung. Misalnya, Anda dapat mengubah eval Batchsize seperti ini:
# ISEKAI10
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/shikra_eval_multi_pope.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
--per_device_eval_batch_size 1
# ISEKAI-PAIR
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/shikra_eval_multi_pope.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
--per_device_eval_batch_size 1
di mana --cfg-options a=balabala b=balabala
adalah argumen gaya mmengine. Mereka akan menimpa argumen yang telah ditentukan sebelumnya dalam file config. Dan --per_device_eval_batch_size
adalah Huggingface: Argumen Pelatih.
Hasil prediksi akan disimpan di output_dir/multitest_xxxx_extra_prediction.jsonl
, yang memiliki urutan yang sama dengan dataset input.
@inproceedings { tai2023link ,
title = { Link-Context Learning for Multimodal LLMs } ,
author = { Tai, Yan and Fan, Weichen and Zhang, Zhao and Liu, Ziwei } ,
booktitle = { Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR) } ,
year = { 2024 }
}