TabFormer
1.0.0
该存储库提供了 pytorch 源代码以及表格转换器 (TabFormer) 的数据。详细信息请参阅将在 ICASSP 2021 上发表的论文《用于多元时间序列建模的表格变换器》。
(X) 代表测试代码的版本。
这些可以通过运行以下命令使用 yaml 安装:
conda env create -f setup.yml
合成信用卡交易数据集在 ./data/credit_card 中提供。有24M条记录,12个字段。您需要 git-lfs 来访问数据。如果您遇到与 LFS 带宽相关的问题,您可以使用此直接链接来访问数据。然后,您可以通过在git clone ..
命令中添加前缀GIT_LFS_SKIP_SMUDGE=1
来忽略 git-lfs 文件。
对于 PRSA 数据集,必须从 Kaggle 下载 PRSA 数据集并将其放置在 ./data/card 目录中。
要在信用卡交易或 PRSA 数据集上训练表格 BERT 模型,请运行:
$ python main.py --do_train --mlm --field_ce --lm_type bert
--field_hs 64 --data_type [prsa/card]
--output_dir [output_dir]
要针对特定用户 ID 的信用卡交易训练表格 GPT2 模型:
$ python main.py --do_train --lm_type gpt2 --field_ce --flatten --data_type card
--data_root [path_to_data] --user_ids [user-id]
--output_dir [output_dir]
一些选项的描述(更多可以在args.py
中找到):
--data_type
对于北京PM2.5数据集和信用卡交易数据集分别选择prsa
和card
。--mlm
用于掩码语言模型; BERT 变压器训练器选项--field_hs
场级变压器的隐藏大小bert
和gpt2
的--lm_type
选择--user_ids
选项仅从特定用户 ID 中选择交易。 @inproceedings{padhi2021tabular,
title={Tabular transformers for modeling multivariate time series},
author={Padhi, Inkit and Schiff, Yair and Melnyk, Igor and Rigotti, Mattia and Mroueh, Youssef and Dognin, Pierre and Ross, Jerret and Nair, Ravi and Altman, Erik},
booktitle={ICASSP 2021-2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={3565--3569},
year={2021},
organization={IEEE},
url={https://ieeexplore.ieee.org/document/9414142}
}