Официальная реализация Pytorch "Link-Context Learning для мультимодальных LLMS" [CVPR 2024].
Этот репозиторий содержит официальную реализацию и набор данных следующей статьи:
Link-Context Learning для мультимодальной LLMS
https://arxiv.org/abs/2308.07891Аннотация: Способность учиться на контексте с новыми концепциями и предоставлять соответствующие ответы необходима в человеческих разговорах. Несмотря на текущие мультимодальные крупные языковые модели (MLLMS) и модели крупных языков (LLMS), обучаемые мегамасштабным наборам данных, распознавание невидимых изображений или понимание новых концепций без тренировки остается проблемой. В In-Context Learning (ICL) исследует обучение без обучения без обучения, где модели рекомендуются «научиться учиться» из ограниченных задач и обобщать до невидимых задач. В этой работе мы предлагаем Link-Context Learning (LCL), которое подчеркивает «рассуждение от причины и следствия», чтобы увеличить возможности обучения MLLMS. LCL выходит за рамки традиционного ICL, явно укрепляя причинно -следственную связь между набором поддержки и набором запросов. Предоставляя демонстрации с причинно -следственными связями, LCL направляет модель, чтобы различать не только аналогию, но и основную причинно -следственную связь между точками данных, что дает возможность MLLM для распознавания невидимых изображений и более эффективно понимать новые концепции. Чтобы облегчить оценку этого нового подхода, мы вводим набор данных Isekai, состоящий исключительно из невидимых сгенерированных пар изображений, предназначенных для обучения контекста. Обширные эксперименты показывают, что наша LCL-MLLM обладает сильными возможностями обучения в связи с контекстом для новых концепций по сравнению с ванильными MLLM.
conda create -n lcl python=3.10
conda activate lcl
pip install -r requirements.txt
accelerate config
Мы тренируем набор LCL на нашем наборе Rebuild ImageNet-900 и оцениваем модель на наборе ImageNet-100. Вы можете получить набор данных JSON здесь.
Мы оцениваем модель на Isekai-10 и Isekai-Pair, вы можете скачать набор данных Isekai в Isekai-10 и Isekai-Pair.
Загрузите наши контрольные точки LCL-2WAY и LCL-MIX в HuggingFace.
Чтобы запустить демонстрацию Gradio Web, используйте следующую команду. Обратите внимание, что модель оценивается в формате Torch.float16, который требует графического процессора с не менее 16 ГБ памяти.
python ./mllm/demo/demo.py --model_path /path/to/lcl/ckpt
Также возможно использовать его в 8-битной квантовании, хотя и за счет жертвы некоторых результатов.
python ./mllm/demo/demo.py --model_path /path/to/lcl/ckpt --load_in_8bit
После подготовки данных вы можете обучить модель, используя команду:
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_train_2way_weight.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/init/checkpoint
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_train_mix1.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/init/checkpoint
После подготовки данных вы можете сделать вывод модели, используя команду:
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/lcl_eval_ISEKAI_10.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
Args в стиле Mmengine и Huggingface: поддерживают тренер ARGS. Например, вы можете изменить eval Patchize, как это:
# ISEKAI10
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/shikra_eval_multi_pope.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
--per_device_eval_batch_size 1
# ISEKAI-PAIR
accelerate launch --num_processes 4
--main_process_port 23786
mllm/pipeline/finetune.py
config/shikra_eval_multi_pope.py
--cfg-options data_args.use_icl=True
--cfg-options model_args.model_name_or_path=/path/to/checkpoint
--per_device_eval_batch_size 1
Где --cfg-options a=balabala b=balabala
аргумент в стиле MMENGINE. Они будут перезаписать аргумент, предопределенный в файле конфигурации. И --per_device_eval_batch_size
-это guggingface: trainer Argiry.
Результат прогнозирования будет сохранен в output_dir/multitest_xxxx_extra_prediction.jsonl
, который содержит тот же порядок, что и набор данных ввода.
@inproceedings { tai2023link ,
title = { Link-Context Learning for Multimodal LLMs } ,
author = { Tai, Yan and Fan, Weichen and Zhang, Zhao and Liu, Ziwei } ,
booktitle = { Proceedings of the IEEE/CVF conference on computer vision and pattern recognition (CVPR) } ,
year = { 2024 }
}