Ce référentiel fournit le code source de pytorch et les données pour les transformateurs tabulaires (TabFormer). Les détails sont décrits dans l'article Tabular Transformers for Modeling Multivariate Time Series, qui sera présenté à l'ICASSP 2021.
(X) représente les versions sur lesquelles le code est testé.
Ceux-ci peuvent être installés à l'aide de yaml en exécutant :
conda env create -f setup.yml
L'ensemble de données synthétiques sur les transactions par carte de crédit est fourni dans ./data/credit_card. Il existe 24 millions d'enregistrements avec 12 champs. Vous auriez besoin de git-lfs pour accéder aux données. Si vous rencontrez un problème lié à la bande passante LFS, vous pouvez utiliser ce lien direct pour accéder aux données. Vous pouvez ensuite ignorer les fichiers git-lfs en préfixant GIT_LFS_SKIP_SMUDGE=1
à la commande git clone ..
.
Pour l'ensemble de données PRSA, il faut télécharger l'ensemble de données PRSA depuis Kaggle et les placer dans le répertoire ./data/card.
Pour entraîner un modèle BERT tabulaire sur une transaction par carte de crédit ou sur l'exécution d'un ensemble de données PRSA :
$ python main.py --do_train --mlm --field_ce --lm_type bert
--field_hs 64 --data_type [prsa/card]
--output_dir [output_dir]
Pour entraîner un modèle tabulaire GPT2 sur les transactions par carte de crédit pour un identifiant utilisateur particulier :
$ python main.py --do_train --lm_type gpt2 --field_ce --flatten --data_type card
--data_root [path_to_data] --user_ids [user-id]
--output_dir [output_dir]
Description de certaines options (plus d'informations peuvent être trouvées dans args.py
) :
--data_type
sont prsa
et card
pour l'ensemble de données Beijing PM2.5 et l'ensemble de données sur les transactions par carte de crédit respectivement.--mlm
pour le modèle de langage masqué ; option pour transformateur d'entraînement pour BERT--field_hs
taille cachée pour le transformateur de niveau champ--lm_type
choix de bert
et gpt2
--user_ids
pour sélectionner uniquement les transactions à partir d'identifiants d'utilisateur particuliers. @inproceedings{padhi2021tabular,
title={Tabular transformers for modeling multivariate time series},
author={Padhi, Inkit and Schiff, Yair and Melnyk, Igor and Rigotti, Mattia and Mroueh, Youssef and Dognin, Pierre and Ross, Jerret and Nair, Ravi and Altman, Erik},
booktitle={ICASSP 2021-2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={3565--3569},
year={2021},
organization={IEEE},
url={https://ieeexplore.ieee.org/document/9414142}
}