bartz
v chain samples we made along the way
JAX 中贝叶斯加法回归树 (BART) 的实现。
如果您不知道 BART 是什么,但知道 XGBoost,请将 BART 视为一种贝叶斯 XGBoost。 bartz 使 BART 的运行速度与 XGBoost 一样快。
BART 是一种非参数贝叶斯回归技术。给定训练预测变量
该 Python 模块提供了在 GPU 上运行的 BART 实现,可以更快地处理大型数据集。 CPU 上也有不错的表现。 BART 的大多数其他实现都是针对 R 的,并且仅在 CPU 上运行。
在 CPU 上,如果 n > 20,000,bartz 将以 dbarts(我所知道的最快实现)的速度运行,但使用 1/20 的内存。在 GPU 上,速度溢价取决于样本大小;仅当 n > 10,000 时才比 CPU 方便。目前,在 Nvidia A100 上且至少有 2,000,000 个观测值时,最大加速率为 200 倍。
这个 Colab 笔记本在 5 分钟内运行 bartz,其中 n = 100,000 个观测值、p = 1000 个预测变量、10,000 棵树,进行 1000 次 MCMC 迭代。
文章:Petrillo (2024),“GPU 上的极快贝叶斯加性回归树”,arXiv:2410.23244。
要直接引用软件,包括具体版本,请使用zenodo。