Implementação oficial de Pytorch de "Link-Context Learning for Multimodal LLMS" [CVPR 2024].
Este repositório contém a implementação e o conjunto de dados oficiais do artigo a seguir:
Link-Context Learning for Multimodal LLMS
https://arxiv.org/abs/2308.07891Resumo: A capacidade de aprender com o contexto com novos conceitos e fornecer respostas apropriadas é essencial nas conversas humanas. Apesar dos atuais modelos multimodais de grandes idiomas (MLLMs) e de grandes modelos de linguagem (LLMS) sendo treinados em conjuntos de dados em escala, reconhecer imagens invisíveis ou entender novos conceitos de maneira livre de treinamento continua sendo um desafio. A aprendizagem no contexto (ICL) explora o aprendizado sem treinamento de poucas fotos, onde os modelos são incentivados a "aprender a aprender" com tarefas limitadas e generalizar para tarefas invisíveis. Neste trabalho, propomos o Link-Context Learning (LCL), que enfatiza "o raciocínio por causa e efeito" para aumentar as capacidades de aprendizado do MLLMS. A LCL vai além da ICL tradicional, fortalecendo explicitamente a relação causal entre o conjunto de suporte e o conjunto de consultas. Ao fornecer demonstrações com links causais, a LCL orienta o modelo a discernir não apenas a analogia, mas também as associações causais subjacentes entre os pontos de dados, que capacitam as MLLMs a reconhecer imagens invisíveis e entender novos conceitos com mais eficiência. Para facilitar a avaliação desta nova abordagem, introduzimos o conjunto de dados ISEKAI, compreendendo exclusivamente pares de marcas de imagem geradas invisíveis projetadas para o aprendizado de contexto de ligação. Experiências extensas mostram que nosso LCL-MLLM exibe fortes recursos de aprendizado de contexto de ligação a novos conceitos sobre os MLLMs de baunilha.
conda create -n lcl python=3.10
conda activate lcl
pip install -r requirements.txt
accelerate config
Treinamos a configuração da LCL em nosso conjunto de ImageNet-900 de Rebuild e avaliamos o modelo no conjunto ImageNet-100. Você pode obter o conjunto de dados JSON aqui.
Avaliamos o modelo no ISEKAI-10 e ISEKAI-PAIR, você pode baixar o conjunto de dados ISEKAI no ISEKAI-10 e ISEKAI-PAIR.
Faça o download de nossos pontos de verificação LCL-2way-Weight e LCL-Mix no HuggingFace.
Para iniciar uma demonstração da Web Gradio, use o seguinte comando. Observe que o modelo é avaliado no formato Torch.float16, que requer uma GPU com pelo menos 16 GB de memória.
python ./mllm/demo/demo.py --model_path /path/to/lcl/ckpt
Também é possível usá-lo em quantização de 8 bits, embora às custas de sacrificar algum desempenho.
python ./mllm/demo/demo.py --model_path /path/to/lcl/ckpt --load_in_8bit
Depois de preparar dados, você pode treinar o modelo usando o comando:
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_train_2way_weight.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/init/checkpoint
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_train_mix1.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/init/checkpoint
Depois de preparar dados, você pode ingerir o modelo usando o comando:
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_eval_ISEKAI_10.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
MMENGINE STILE ARGS E HUGGINGFACE: O Treinador Args são suportados. Por exemplo, você pode alterar a avaliação do BatchSize assim:
# ISEKAI10
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/shikra_eval_multi_pope.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
--per_device_eval_batch_size 1
# ISEKAI-PAIR
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/shikra_eval_multi_pope.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
--per_device_eval_batch_size 1
Onde --cfg-options a=balabala b=balabala
é o argumento do estilo mmengine. Eles substituirão o argumento predefinido no arquivo de configuração. E --per_device_eval_batch_size
é o Huggingface: argumento do treinador.
O resultado da previsão será salvo em output_dir/multitest_xxxx_extra_prediction.jsonl
, que mantém a mesma ordem que o conjunto de dados de entrada.
@inproceedings { tai2023link ,
title = { Link-Context Learning for Multimodal LLMs } ,
author = { Tai, Yan and Fan, Weichen and Zhang, Zhao and Liu, Ziwei } ,
booktitle = { Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR) } ,
year = { 2024 }
}