TabFormer
1.0.0
此儲存庫提供了 pytorch 原始程式碼以及表格轉換器 (TabFormer) 的資料。詳細資訊請參閱將在 ICASSP 2021 上發表的論文《用於多元時間序列建模的表格變換器》。
(X) 代表測試程式碼的版本。
這些可以透過執行以下命令使用 yaml 安裝:
conda env create -f setup.yml
合成信用卡交易資料集在 ./data/credit_card 中提供。有24M筆記錄,12個欄位。您需要 git-lfs 來存取資料。如果您遇到與 LFS 頻寬相關的問題,您可以使用此直接連結來存取資料。然後,您可以透過在git clone ..
指令中加入前綴GIT_LFS_SKIP_SMUDGE=1
來忽略 git-lfs 檔案。
對於 PRSA 資料集,必須從 Kaggle 下載 PRSA 資料集並將其放置在 ./data/card 目錄中。
若要在信用卡交易或 PRSA 資料集上訓練表格 BERT 模型,請執行:
$ python main.py --do_train --mlm --field_ce --lm_type bert
--field_hs 64 --data_type [prsa/card]
--output_dir [output_dir]
若要針對特定使用者 ID 的信用卡交易訓練表格 GPT2 模型:
$ python main.py --do_train --lm_type gpt2 --field_ce --flatten --data_type card
--data_root [path_to_data] --user_ids [user-id]
--output_dir [output_dir]
一些選項的描述(更多可以在args.py
中找到):
--data_type
對於北京PM2.5資料集和信用卡交易資料集分別選擇prsa
和card
。--mlm
用於掩碼語言模型; BERT 變壓器訓練器選項--field_hs
場級變壓器的隱藏大小bert
和gpt2
的--lm_type
選擇--user_ids
選項僅從特定使用者 ID 中選擇交易。 @inproceedings{padhi2021tabular,
title={Tabular transformers for modeling multivariate time series},
author={Padhi, Inkit and Schiff, Yair and Melnyk, Igor and Rigotti, Mattia and Mroueh, Youssef and Dognin, Pierre and Ross, Jerret and Nair, Ravi and Altman, Erik},
booktitle={ICASSP 2021-2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={3565--3569},
year={2021},
organization={IEEE},
url={https://ieeexplore.ieee.org/document/9414142}
}