Une implémentation d'arbres de régression bayésiens additifs (BART) dans JAX.
Si vous ne savez pas ce qu'est BART, mais connaissez XGBoost, considérez BART comme une sorte de XGBoost bayésien. bartz fait fonctionner BART aussi vite que XGBoost.
BART est une technique de régression bayésienne non paramétrique. Compte tenu des prédicteurs de formation
Ce module Python fournit une implémentation de BART qui s'exécute sur GPU, pour traiter plus rapidement de grands ensembles de données. C'est également bon sur le CPU. La plupart des autres implémentations de BART sont destinées à R et fonctionnent uniquement sur CPU.
Sur CPU, bartz fonctionne à la vitesse de dbarts (l'implémentation la plus rapide que je connaisse) si n > 20 000, mais en utilisant 1/20 de la mémoire. Sur GPU, la prime de vitesse dépend de la taille de l'échantillon ; c'est pratique sur CPU uniquement pour n > 10 000. L'accélération maximale est actuellement de 200x, sur un Nvidia A100 et avec au moins 2 000 000 d'observations.
Ce notebook Colab exécute bartz avec n = 100 000 observations, p = 1 000 prédicteurs, 10 000 arbres, pour 1 000 itérations MCMC, en 5 minutes.
Article : Petrillo (2024), « Arbres de régression additive bayésienne très rapides sur GPU », arXiv :2410.23244.
Pour citer directement le logiciel, y compris la version spécifique, utilisez zenodo.