官方的Pytorch實施“多模式LLMS的鏈接封閉式學習” [CVPR 2024]。
該存儲庫包含以下論文的官方實施和數據集:
多模式LLMS的鏈接封閉式學習
https://arxiv.org/abs/2308.07891摘要:通過新穎概念從上下文中學習並提供適當反應的能力在人類對話中至關重要。儘管當前的多模式大型語言模型(MLLM)和大型語言模型(LLMS)在大型數據集中訓練,但以無訓練的方式識別看不見的圖像或理解新穎概念仍然是一個挑戰。內部文化學習(ICL)探索了無培訓的幾次學習,鼓勵模型從有限任務中“學習”並推廣到看不見的任務。在這項工作中,我們提出了Link-Context學習(LCL),該學習強調了“從因果關係和效果的推理”來增強MLLM的學習能力。通過明確加強支持集與查詢集之間的因果關係,LCL超越了傳統ICL。通過提供因果關係的演示,LCL引導該模型不僅可以辨別類比,而且還可以識別數據點之間的基本因果關係,這使MLLM賦予了MLLM的能力,可以更有效地識別看不見的圖像並更有效地理解新穎的概念。為了促進對這種新穎方法的評估,我們介紹了ISEKAI數據集,該數據集由專為鏈接封閉式學習而設計的未見生成的圖像標籤對組成。廣泛的實驗表明,我們的LCL-MLLM對香草MLLM的新穎概念具有強大的鏈接性學習能力。
conda create -n lcl python=3.10
conda activate lcl
pip install -r requirements.txt
accelerate config
我們在Rebuild Imagenet-900套件上訓練LCL設置,並在Imagenet-100套件上評估模型。您可以在此處獲取數據集JSON。
我們評估Isekai-10和Isekai Pair上的模型,您可以在Isekai-10和Isekai Pair下載Isekai數據集。
在Huggingface中下載我們的LCL-2Way重量和LCL-MIX檢查點。
要啟動Gradio Web演示,請使用以下命令。請注意,該模型在Torch.float16格式中進行評估,該格式需要具有至少16GB內存的GPU。
python ./mllm/demo/demo.py --model_path /path/to/lcl/ckpt
儘管以犧牲某些績效為代價,但也可以將其用於8位量化。
python ./mllm/demo/demo.py --model_path /path/to/lcl/ckpt --load_in_8bit
準備數據後,您可以使用命令訓練模型:
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_train_2way_weight.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/init/checkpoint
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_train_mix1.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/init/checkpoint
準備數據後,您可以使用命令推導模型:
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_eval_ISEKAI_10.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
Mmengine風格的Args和HuggingFace:培訓師ARGS得到了支持。例如,您可以像這樣更改評估批處理:
# ISEKAI10
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/shikra_eval_multi_pope.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
--per_device_eval_batch_size 1
# ISEKAI-PAIR
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/shikra_eval_multi_pope.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
--per_device_eval_batch_size 1
其中--cfg-options a=balabala b=balabala
是mmengine樣式參數。他們將覆蓋配置文件中預定義的參數。 --per_device_eval_batch_size
是huggingface:trainer參數。
預測結果將保存在output_dir/multitest_xxxx_extra_prediction.jsonl
中,該訂單與輸入數據集相同。
@inproceedings { tai2023link ,
title = { Link-Context Learning for Multimodal LLMs } ,
author = { Tai, Yan and Fan, Weichen and Zhang, Zhao and Liu, Ziwei } ,
booktitle = { Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR) } ,
year = { 2024 }
}