该存储库包含为本文开发的软件,
机器阅读理解的协作自我训练,罗华,李世文,高明,于书,Glass J.,NAACL 2022。
尝试我们的带有中等长度段落的现场演示(长文档版本即将推出)。
我们使用以下软件包运行该软件,
预训练模型可通过此 Google Drive 链接获得。如果您的计算机无法从 Huggingface hub 下载模型,请下载模型并将其移至model_file/
目录下。
model_file/ext_sqv2.pt
:在 SQuAD v2.0 上预训练的 ELECTRA-large 问答模型。model_file/ques_gen_squad.pt
:在 SQuAD v2.0 上预训练的 BART 大型问题生成模型。model_file/electra-tokenize.pt
:Huggingface 提供的 Electra-large 分词器。model_file/bart-tokenizer.pt
:Huggingface 提供的 BART-large 分词器。通过运行以下命令,在我们在data/squad/doc_data_0.json
提供的示例 SQuAD 段落上生成问答对,
python RGX _doc.py
--dataset_name squad
--data_split 0
--output_dir tmp/ RGX
--version_2_with_negative
生成的数据将存储在data_gen/squad
下,包括RGX _0.json
和qa_train_corpus_0.json
。我们为分布式数据生成提供$DATA_SPLIT
选项,例如使用 Slurm。如果仅使用一个进程生成 QA 对,只需使用--data_split 0
。
所有数据都存储在data/
和data_gen/
目录中。
data/{$DATASET_NAME}/doc_data_{$DATA_SPLIT}.json
:目标数据集的未标记文档。data_gen/{$DATASET_NAME}/ RGX _{$DATA_SPLIT}.json
:生成的 QA 数据与相应数据集中的每个文档对齐。data_gen/{$DATASET_NAME}/qa_train_corpus_{$DATA_SPLIT}.json
:给定数据集生成的 QA 训练集。训练示例遵循 SQuAD 数据格式并随机打乱。 doc_data_{$DATA_SPLIT}.json
的格式是字典列表,如下所示 [
{"context": INPUT_DOC_TXT__0},
{"context": INPUT_DOC_TXT__1},
...,
{"context": INPUT_DOC_TXT__N},
]
qa_train_corpus_{$DATA_SPLIT}.json
的格式是字典列表,如下所示 [
{
"context": INPUT_DOC_TXT_0,
"question": GEN_QUESTION_TXT_0,
"answers": {
"text": [ ANSWER_TXT ], # only one answer per question
"answer_start": [ ANSWER_ST_CHAR ]
# index of the starting character of the answer txt
}
},
{
...
},
]
RGX _{$DATA_SPLIT}.json
是文档-QA 映射的列表, [
[
$DOCUMENT_i,
$ANS2ITEM_LIST_i,
$GEN_QA_LIST_i
],
...
]
$DOCUMENT_i
与输入文件具有相同的格式。 $ANS2ITEM_LIST_i
是所有已识别答案和生成问题的元数据。请注意,一个答案可以有多个问题,并且问题可以正确也可以不正确。模型的最终输出是$GEN_QA_LIST_i
,它是基于输入文档生成的 QA 对的字典列表,
[
{
"question": GEN_QUESTION_TXT_0,
"answers": {
"text": [ ANSWER_TXT ],
"answer_start": [ ANSWER_ST_CHAR ]
}
}
]
data/
和data_gen/
目录下创建目录, bash new_dataset.sh $NEW_DATASET_NAME
将包含目标文档的输入文件移动为data/$NEW_DATASET_NAME/doc_data_0.json
。该格式已在上一节中描述。
运行以下命令
python RGX _doc.py
--dataset_name $NEW_DATASET_NAME
--data_split 0
--output_dir tmp/ RGX
--version_2_with_negative
生成的文件将存储在data_gen/{$NEW_DATASET_NAME}/
。
我们建议使用两种方法对生成的 QA 对进行 QA 微调。
mix_mode.py
脚本对两个模型的所有权重进行平均 python mix_model.py $MIX_RATE $SQUAD_MODEL_PATH $ RGX _MODEL_PATH
例如,
python mix_model.py 0.5 model_ft_file/ext_sq.pt model_ft_file/ext_ RGX .pt
输出模型将存储为model_ft_file/ext_mixed.pt
。
如有任何问题,请联系第一作者罗宏银(hyluo at mit dot edu)。如果我们的系统应用在您的工作中,请引用我们的论文
@article{luo2021cooperative,
title={Cooperative self-training machine reading comprehension},
author={Luo, Hongyin and Li, Shang-Wen and Mingye Gao, and Yu, Seunghak and Glass, James},
journal={arXiv preprint arXiv:2103.07449},
year={2021}
}