JAX에서 BART(Bayesian Additive Regression Trees)를 구현합니다.
BART가 무엇인지 모르지만 XGBoost를 알고 있다면 BART를 일종의 베이지안 XGBoost로 간주하세요. bartz는 BART를 XGBoost만큼 빠르게 실행합니다.
BART는 비모수 베이지안 회귀 기법입니다. 주어진 훈련 예측변수
이 Python 모듈은 대규모 데이터 세트를 더 빠르게 처리하기 위해 GPU에서 실행되는 BART 구현을 제공합니다. CPU에서도 좋습니다. 대부분의 다른 BART 구현은 R용이며 CPU에서만 실행됩니다.
CPU에서 bartz는 n > 20,000인 경우 dbarts(내가 아는 가장 빠른 구현)의 속도로 실행되지만 메모리의 1/20을 사용합니다. GPU에서 속도 프리미엄은 샘플 크기에 따라 다릅니다. n > 10,000인 경우에만 CPU보다 편리합니다. 최대 속도 향상은 현재 Nvidia A100에서 최소 2,000,000개의 관찰에서 200배입니다.
이 Colab 노트북은 n = 100,000개의 관측치, p = 1000개의 예측 변수, 10,000개의 트리를 사용하여 1000회의 MCMC 반복에 대해 5분 안에 bartz를 실행합니다.
기사: Petrillo(2024), "GPU에서 매우 빠른 베이지안 가산 회귀 트리", arXiv:2410.23244.
특정 버전을 포함하여 소프트웨어를 직접 인용하려면 zenodo를 사용하십시오.