官方的Pytorch实施“多模式LLMS的链接封闭式学习” [CVPR 2024]。
该存储库包含以下论文的官方实施和数据集:
多模式LLMS的链接封闭式学习
https://arxiv.org/abs/2308.07891摘要:通过新颖概念从上下文中学习并提供适当反应的能力在人类对话中至关重要。尽管当前的多模式大型语言模型(MLLM)和大型语言模型(LLMS)在大型数据集中训练,但以无训练的方式识别看不见的图像或理解新颖概念仍然是一个挑战。内部文化学习(ICL)探索了无培训的几次学习,鼓励模型从有限任务中“学习”并推广到看不见的任务。在这项工作中,我们提出了Link-Context学习(LCL),该学习强调了“从因果关系和效果的推理”来增强MLLM的学习能力。通过明确加强支持集与查询集之间的因果关系,LCL超越了传统ICL。通过提供因果关系的演示,LCL引导该模型不仅可以辨别类比,而且还可以识别数据点之间的基本因果关系,这使MLLM赋予了MLLM的能力,可以更有效地识别看不见的图像并更有效地理解新颖的概念。为了促进对这种新颖方法的评估,我们介绍了ISEKAI数据集,该数据集由专为链接封闭式学习而设计的未见生成的图像标签对组成。广泛的实验表明,我们的LCL-MLLM对香草MLLM的新颖概念具有强大的链接性学习能力。
conda create -n lcl python=3.10
conda activate lcl
pip install -r requirements.txt
accelerate config
我们在Rebuild Imagenet-900套件上训练LCL设置,并在Imagenet-100套件上评估模型。您可以在此处获取数据集JSON。
我们评估Isekai-10和Isekai Pair上的模型,您可以在Isekai-10和Isekai Pair下载Isekai数据集。
在Huggingface中下载我们的LCL-2Way重量和LCL-MIX检查点。
要启动Gradio Web演示,请使用以下命令。请注意,该模型在Torch.float16格式中进行评估,该格式需要具有至少16GB内存的GPU。
python ./mllm/demo/demo.py --model_path /path/to/lcl/ckpt
尽管以牺牲某些绩效为代价,但也可以将其用于8位量化。
python ./mllm/demo/demo.py --model_path /path/to/lcl/ckpt --load_in_8bit
准备数据后,您可以使用命令训练模型:
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_train_2way_weight.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/init/checkpoint
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_train_mix1.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/init/checkpoint
准备数据后,您可以使用命令推导模型:
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_eval_ISEKAI_10.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
Mmengine风格的Args和HuggingFace:培训师ARGS得到了支持。例如,您可以像这样更改评估批处理:
# ISEKAI10
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/shikra_eval_multi_pope.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
--per_device_eval_batch_size 1
# ISEKAI-PAIR
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/shikra_eval_multi_pope.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
--per_device_eval_batch_size 1
其中--cfg-options a=balabala b=balabala
是mmengine样式参数。他们将覆盖配置文件中预定义的参数。 --per_device_eval_batch_size
是huggingface:trainer参数。
预测结果将保存在output_dir/multitest_xxxx_extra_prediction.jsonl
中,该订单与输入数据集相同。
@inproceedings { tai2023link ,
title = { Link-Context Learning for Multimodal LLMs } ,
author = { Tai, Yan and Fan, Weichen and Zhang, Zhao and Liu, Ziwei } ,
booktitle = { Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR) } ,
year = { 2024 }
}