TabFormer
1.0.0
このリポジトリは、pytorch ソース コードと表形式トランスフォーマー (TabFormer) のデータを提供します。詳細は、ICASSP 2021 で発表される論文「Tabular Transformers for Modeling Multivariate Time Series」で説明されています。
(X) は、コードがテストされるバージョンを表します。
これらは、yaml を使用して次を実行してインストールできます。
conda env create -f setup.yml
合成クレジット カード トランザクション データセットは ./data/credit_card で提供されます。 12 フィールドを持つ 2,400 万のレコードがあります。データにアクセスするには git-lfs が必要です。 LFS 帯域幅に関連する問題に直面している場合は、この直接リンクを使用してデータにアクセスできます。 git clone ..
コマンドの先頭にGIT_LFS_SKIP_SMUDGE=1
付けることで、git-lfs ファイルを無視できます。
PRSA データセットの場合、Kaggle から PRSA データセットをダウンロードし、./data/card ディレクトリに配置する必要があります。
クレジット カード トランザクションまたは PRSA データセットで表形式の BERT モデルをトレーニングするには、次を実行します。
$ python main.py --do_train --mlm --field_ce --lm_type bert
--field_hs 64 --data_type [prsa/card]
--output_dir [output_dir]
特定のuser-idのクレジット カード トランザクションで表形式の GPT2 モデルをトレーニングするには、次のようにします。
$ python main.py --do_train --lm_type gpt2 --field_ce --flatten --data_type card
--data_root [path_to_data] --user_ids [user-id]
--output_dir [output_dir]
いくつかのオプションの説明 (詳細はargs.py
にあります):
--data_type
選択肢は、北京 PM2.5 データセットとクレジット カード トランザクション データセットの場合、それぞれprsa
とcard
です。--mlm
マスクされた言語モデルの場合。 BERT用変圧器トレーナーのオプション--field_hs
フィールドレベルトランスフォーマーの非表示サイズ--lm_type
bert
とgpt2
から選択--user_ids
オプションを使用すると、特定のユーザー ID からのトランザクションのみを選択できます。 @inproceedings{padhi2021tabular,
title={Tabular transformers for modeling multivariate time series},
author={Padhi, Inkit and Schiff, Yair and Melnyk, Igor and Rigotti, Mattia and Mroueh, Youssef and Dognin, Pierre and Ross, Jerret and Nair, Ravi and Altman, Erik},
booktitle={ICASSP 2021-2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={3565--3569},
year={2021},
organization={IEEE},
url={https://ieeexplore.ieee.org/document/9414142}
}