重要的
查看新的表格 DL 模型:TabM
arXiv ? Python 包其他表格 DL 项目
这是论文“Revisiting Deep Learning Models for Tabular Data”的正式实现。
一句话:类似 MLP 的模型仍然是很好的基线,而 FT-Transformer 是针对表格数据问题的 Transformer 架构的一种新的强大改编。
本文重点讨论表格数据问题的架构。结果:
package/
目录中的 Python 包是在实践和未来工作中使用本文的推荐方式。
文件的其余部分:
output/
目录包含本文中使用的各种模型和数据集的大量结果和(调整的)超参数。
例如,让我们探讨一下 MLP 模型的指标。首先,让我们加载报告( stats.json
文件):
import json
from pathlib import Path
import pandas as pd
df = pd . json_normalize ([
json . loads ( x . read_text ())
for x in Path ( 'output' ). glob ( '*/mlp/tuned/*/stats.json' )
])
现在,对于每个数据集,我们计算所有随机种子的平均测试分数:
print ( df . groupby ( 'dataset' )[ 'metrics.test.score' ]. mean (). round ( 3 ))
输出与论文中的表 2 完全匹配:
dataset
adult 0.852
aloi 0.954
california_housing -0.499
covtype 0.962
epsilon 0.898
helena 0.383
higgs_small 0.723
jannis 0.719
microsoft -0.747
yahoo -0.757
year -8.853
Name: metrics.test.score, dtype: float64
上述方法还可用于探索超参数,以直观地了解不同算法的典型超参数值。例如,以下是计算 MLP 模型的中值调整学习率的方法:
笔记
对于某些算法(例如 MLP),最近的项目提供了更多可以以类似方式探索的结果。例如,请参阅 TabR 上的这篇论文。
警告
请谨慎使用此方法。研究超参数值时:
print ( df [ df [ 'config.seed' ] == 0 ][ 'config.training.lr' ]. quantile ( 0.5 ))
# Output: 0.0002161505605899536
笔记
这一段很长。在文本编辑器中使用 GitHub 上的“大纲”功能来获取本节的概述。
代码组织如下:
bin
:ensemble.py
执行集成tune.py
执行超参数调整analysis_gbdt_vs_nn.py
运行实验create_synthetic_data_plots.py
构建绘图lib
包含bin
中程序使用的常用工具output
包含配置文件( bin
中程序的输入)和结果(指标、调整配置等)package
包含本文的Python包安装康达
export PROJECT_DIR= < ABSOLUTE path to the repository root >
# example: export PROJECT_DIR=/home/myusername/repositories/revisiting-models
git clone https://github.com/yandex-research/tabular-dl-revisiting-models $PROJECT_DIR
cd $PROJECT_DIR
conda create -n revisiting-models python=3.8.8
conda activate revisiting-models
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1.243 numpy=1.19.2 -c pytorch -y
conda install cudnn=7.6.5 -c anaconda -y
pip install -r requirements.txt
conda install nodejs -y
jupyter labextension install @jupyter-widgets/jupyterlab-manager
# if the following commands do not succeed, update conda
conda env config vars set PYTHONPATH= ${PYTHONPATH} : ${PROJECT_DIR}
conda env config vars set PROJECT_DIR= ${PROJECT_DIR}
conda env config vars set LD_LIBRARY_PATH= ${CONDA_PREFIX} /lib: ${LD_LIBRARY_PATH}
conda env config vars set CUDA_HOME= ${CONDA_PREFIX}
conda env config vars set CUDA_ROOT= ${CONDA_PREFIX}
conda deactivate
conda activate revisiting-models
仅在尝试 TabNet 时才需要此环境。对于所有其他情况,请使用 PyTorch 环境。
这些说明与 PyTorch 环境相同(包括 PyTorch 的安装!),但是:
python=3.7.10
cudatoolkit=10.0
pip install -r requirements.txt
之前执行以下操作:pip install tensorflow-gpu==1.14
requirements.txt
中的tensorboard
许可证:通过下载我们的数据集,您接受其所有组件的许可证。除了这些许可证之外,我们不会施加任何新的限制。您可以在我们论文的“参考文献”部分找到来源列表。
wget https://www.dropbox.com/s/o53umyg6mn3zhxy/data.tar.gz?dl=1 -O revisiting_models_data.tar.gz
mv revisiting_models_data.tar.gz $PROJECT_DIR
cd $PROJECT_DIR
tar -xvf revisiting_models_data.tar.gz
本节仅提供具体命令,注释很少。完成本教程后,我们建议您查看下一部分,以更好地了解如何使用存储库。它还将有助于更好地理解本教程。
在本教程中,我们将在加州住房数据集上重现 MLP 的结果。我们将涵盖:
请注意,获得完全相同结果的机会相当低,但是,它们应该与我们的结果相差不大。在运行任何内容之前,请转到存储库的根目录并显式设置CUDA_VISIBLE_DEVICES
(如果您打算使用 GPU):
cd $PROJECT_DIR
export CUDA_VISIBLE_DEVICES=0
在开始之前,我们先检查一下环境是否配置成功。以下命令应在加州住房数据集上训练一个 MLP:
mkdir draft
cp output/california_housing/mlp/tuned/0.toml draft/check_environment.toml
python bin/mlp.py draft/check_environment.toml
结果应该位于目录draft/check_environment
中。目前,结果的内容并不重要。
我们在加州住房数据集上调整 MLP 的配置位于output/california_housing/mlp/tuning/0.toml
。为了重现调整,请复制我们的配置并运行您的调整:
# you can choose any other name instead of "reproduced.toml"; it is better to keep this
# name while completing the tutorial
cp output/california_housing/mlp/tuning/0.toml output/california_housing/mlp/tuning/reproduced.toml
# let's reduce the number of tuning iterations to make tuning fast (and ineffective)
python -c "
from pathlib import Path
p = Path('output/california_housing/mlp/tuning/reproduced.toml')
p.write_text(p.read_text().replace('n_trials = 100', 'n_trials = 5'))
"
python bin/tune.py output/california_housing/mlp/tuning/reproduced.toml
您的调整结果将位于output/california_housing/mlp/tuning/reproduced
,您可以将其与我们的进行比较: output/california_housing/mlp/tuning/0
。文件best.toml
包含我们将在下一节中评估的最佳配置。
现在我们必须使用 15 个不同的随机种子来评估调整后的配置。
# create a directory for evaluation
mkdir -p output/california_housing/mlp/tuned_reproduced
# clone the best config from the tuning stage with 15 different random seeds
python -c "
for seed in range(15):
open(f'output/california_housing/mlp/tuned_reproduced/{seed}.toml', 'w').write(
open('output/california_housing/mlp/tuning/reproduced/best.toml').read().replace('seed = 0', f'seed = {seed}')
)
"
# train MLP with all 15 configs
for seed in {0..14}
do
python bin/mlp.py output/california_housing/mlp/tuned_reproduced/ ${seed} .toml
done
我们的评估结果目录就位于您的目录旁边,即位于output/california_housing/mlp/tuned
。
# just run this single command
python bin/ensemble.py mlp output/california_housing/mlp/tuned_reproduced
您的结果将位于output/california_housing/mlp/tuned_reproduced_ensemble
,您可以将其与我们的结果进行比较: output/california_housing/mlp/tuned_ensemble
。
使用此处描述的方法总结所进行的实验的结果(相应地修改.glob(...)
中的路径过滤器: tuned
-> tuned_reproduced
)。
可以对所有模型和数据集执行类似的步骤。网格搜索的调整过程略有不同:您必须运行所有所需的配置,并根据验证性能手动选择最佳配置。例如,请参见output/epsilon/ft_transformer
。
您应该从存储库的根目录运行 Python 脚本。大多数程序都期望配置文件作为唯一的参数。输出将是一个与配置同名的目录,但没有扩展名。配置是用 TOML 编写的。未提供程序的可能参数列表,应从脚本中推断出来(通常,配置用脚本中的args
变量表示)。如果要使用 CUDA,则必须显式设置CUDA_VISIBLE_DEVICES
环境变量。例如:
# The result will be at "path/to/my_experiment"
CUDA_VISIBLE_DEVICES=0 python bin/mlp.py path/to/my_experiment.toml
# The following example will run WITHOUT CUDA
python bin/mlp.py path/to/my_experiment.toml
如果你打算一直使用CUDA,可以将环境变量保存在Conda环境中:
conda env config vars set CUDA_VISIBLE_DEVICES= " 0 "
-f
( --force
) 选项将删除现有结果并从头开始运行脚本:
python bin/whatever.py path/to/config.toml -f # rewrites path/to/config
bin/tune.py
支持延续:
python bin/tune.py path/to/config.toml --continue
stats.json
和其他结果对于所有脚本, stats.json
是输出中最重要的部分。内容因节目而异。它可以包含:
通常还会保存训练集、验证集和测试集的预测。
现在,您知道重现所有结果并扩展此存储库以满足您的需求所需的一切。现在教程也应该更加清晰了。请随意提出问题并提出问题。
@inproceedings{gorishniy2021revisiting,
title={Revisiting Deep Learning Models for Tabular Data},
author={Yury Gorishniy and Ivan Rubachev and Valentin Khrulkov and Artem Babenko},
booktitle={{NeurIPS}},
year={2021},
}