bartz
v chain samples we made along the way
JAX 中貝葉斯加法回歸樹 (BART) 的實現。
如果您不知道 BART 是什麼,但知道 XGBoost,請將 BART 視為一種貝葉斯 XGBoost。 bartz 讓 BART 的運行速度與 XGBoost 一樣快。
BART 是一種非參數貝葉斯迴歸技術。給定訓練預測變量
這個 Python 模組提供了在 GPU 上運行的 BART 實現,可以更快地處理大型資料集。 CPU 上也有不錯的表現。 BART 的大多數其他實作都是針對 R 的,並且僅在 CPU 上運行。
在 CPU 上,如果 n > 20,000,bartz 將以 dbarts(我所知道的最快實現)的速度運行,但使用 1/20 的記憶體。在 GPU 上,速度溢價取決於樣本大小;僅當 n > 10,000 時才比 CPU 方便。目前,在 Nvidia A100 上且至少有 2,000,000 個觀測值時,最大加速率為 200 倍。
這個 Colab 筆記本在 5 分鐘內運行 bartz,其中 n = 100,000 個觀測值、p = 1000 個預測變數、10,000 棵樹,進行 1000 次 MCMC 迭代。
文章:Petrillo (2024),“GPU 上的極快貝葉斯加性回歸樹”,arXiv:2410.23244。
若要直接引用軟體,包括具體版本,請使用zenodo。