一個俳句函式庫,使用 JAX 中的xmap
/ pjit
運算子來實現變壓器的模型平行性。
並行方案類似於原始的 Megatron-LM,由於高速 2d 網格網絡,它在 TPU 上非常有效率。還有一個實現 ZeRo 風格分片的實驗模型版本。
此函式庫設計用於在 TPUv3 上擴展至約 40B 參數,超出此範圍應使用不同的平行策略。請參閱其他實現,例如 GPT-NeoX 或 DeepSpeed。
未來的研究方向之一是將這個程式碼庫與 swarm-jax 集成,以透過管道並行性實現進一步的可擴展性。
12-07-21 :新增了微調指南
在 The Pile 上訓練的 60 億個參數的自回歸文字生成模型。
下載 slim 權重(僅限 BF16 權重,用於推理,9GB)
下載完整權重(包括優化器參數,61GB)
部分訓練的檢查站
Colab 演示
網路示範
阿蘭的部落格文章
如果沒有 TPU 研究雲端在 EleutherAI 的協助下慷慨提供的運算能力,這個專案就不可能實現。
感謝 Google Cloud TPU 團隊提供對 Cloud TPU VM alpha 的早期存取(現已公開發布!)
感謝所有以這種或那種方式提供幫助的人(按字母順序列出):
GPT-J-6B 的權重根據 Apache 許可證 2.0 版本獲得許可。
超參數 | 價值 |
---|---|
n_參數 | 6,053,381,344 |
n_層數 | 28* |
d_模型 | 4,096 |
d_ff | 16,384 |
n_頭 | 16 |
d_頭 | 256 |
n_ctx | 2,048 |
n_詞彙 | 50,257(與 GPT-2/3 相同的標記器) |
位置編碼 | 旋轉位置編碼 (RoPE) |
繩索尺寸 | 64 |
*
每層由一個前饋塊和一個自註意力塊組成
此模型由 28 層組成,模型維度為 4096,前饋維度為 16384。該模型使用與 GPT-2/GPT-3 相同的 BPE 集,使用 50257 的標記化詞彙進行訓練。
模型大致按效能排序,如果不可用,則按失敗次數排序。
模型 | 重量 | 訓練失敗次數 | 蘭巴達 PPL ↓ | 蘭巴達加速器 ↑ | 溫諾格蘭德 ↑ | 海拉斯瓦格 ↑ | PIQA↑ | 資料集大小 (GB) |
---|---|---|---|---|---|---|---|---|
機會 | ✔ | 0 | 〜很多 | ~0% | 50% | 25% | 25% | 0 |
GPT-3-Ada‡ | ✘ | ----- | 9.95 | 51.6% | 52.9% | 43.4% | 70.5% | ----- |
GPT-2-1.5B | ✔ | ----- | 10.63 | 51.21% | 59.4% | 50.9% | 70.8% | 40 |
GPTNeo-1.3B‡ | ✔ | 3.0e21 | 7.50 | 57.2% | 55.0% | 48.9% | 71.1% | 第825章 |
威震天-2.5B* | ✘ | 2.4e21 | ----- | 61.7% | ----- | ----- | ----- | 174 |
GPTNeo-2.7B‡ | ✔ | 6.8e21 | 5.63 | 62.2% | 56.5% | 55.8% | 73.0% | 第825章 |
GPT-3-1.3B*‡ | ✘ | 2.4e21 | 5.44 | 63.6% | 58.7% | 54.7% | 75.1% | ~800 |
GPT-3-巴貝奇‡ | ✘ | ----- | 5.58 | 62.4% | 59.0% | 54.5% | 75.5% | ----- |
威震天-8.3B* | ✘ | 7.8e21 | ----- | 66.5% | ----- | ----- | ----- | 174 |
GPT-3-2.7B*‡ | ✘ | 4.8e21 | 4.60 | 67.1% | 62.3% | 62.8% | 75.6% | ~800 |
威震天 11B† | ✔ | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 |
GPT-J-6B ‡ | ✔ | 1.5e22 | 3.99 | 69.7% | 65.3% | 66.1% | 76.5% | 第825章 |
GPT-3-6.7B*‡ | ✘ | 1.2e22 | 4.00 | 70.3% | 64.5% | 67.4% | 78.0% | ~800 |
GPT-3-居禮‡ | ✘ | ----- | 4.00 | 69.3% | 65.6% | 68.5% | 77.9% | ----- |
GPT-3-13B*‡ | ✘ | 2.3e22 | 3.56 | 72.5% | 67.9% | 70.9% | 78.5% | ~800 |
GPT-3-175B*‡ | ✘ | 3.1e23 | 3.00 | 76.2% | 70.2% | 78.9% | 81.0% | ~800 |
GPT-3-達文西‡ | ✘ | ----- | 3.0 | 75% | 72% | 78% | 80% | ----- |
地鼠230B* | ✘ | 6.31E+23 | ----- | 74.50% | 70.10% | 79.20% | 81.80% | 第1344章 |
MT-NLG 530B*‡ | ✘ | ----- | ----- | 76.6% | 73.0% | 80.2% | 82.0% | ----- |
*
代表各自作者報告的評估數字,所有其他數字均透過使用發布的權重或 API 存取運行 lm-evaluation-harness 來提供。由於細微的實現差異以及不同的零樣本任務框架,這些可能無法直接比較。請參閱此部落格文章以了解更多詳細資訊。
†
Megatron-11B 模型沒有提供可比較的指標,並且使用發布的權重的幾種實現無法重現生成品質和評估。 (參見 1 2 3)因此,未嘗試進行評估。
‡
這些模型已經使用包含可能的測試集污染的資料進行了訓練。 OpenAI GPT-3 模型未能對某些測試集的訓練資料進行重複資料刪除,而 GPT-Neo 模型以及該模型是在 The Pile 上進行訓練的,而 The Pile 尚未針對任何測試集進行重複資料刪除。
該儲存庫中的大多數腳本都設計為在 TPU 上運行,TPU-VM 架構下的 TPU 是可以運行任意程式碼的虛擬機器。大多數腳本旨在啟動 TPU、透過 SSH 連接到其中以設定依賴項並從本地目錄複製程式碼,然後啟動可以接受 RPC 呼叫的 Ray 工作執行緒。
TPUVM 處理運行模型訓練步驟和評估、檢查點保存和加載,而驅動程式 python 程式處理資料加載和一般編排(例如何時保存檢查點等)。
這意味著大多數腳本( train.py
、 eval_harness.py
等)希望在與 TPU 相同區域的 GCE 虛擬機器上運行,以最大限度地減少 RPC 延遲和資料傳輸成本。其他腳本(通常是不含--tpu
參數的腳本,例如device_sample.py
、 device_serve.py
或device_train.py
)期望直接在 TPUVM 上執行。 device_* 腳本僅適用於 v3-8 ,不適用於較大的 pod。
此外,還有一個範例 ( resharding_example.py
),說明如何將提供的檢查點(在 GPT-J-6B 的情況下有 8 個分片)轉換為較小的數量,例如在 GPU 上執行時。
若要微調模型,請在 TPU VM 上執行device_train.py
。使用 TPU v3-8,您可以以約 5000 個令牌/秒的速率進行微調,這對於中小型資料集來說應該足夠了。
請閱讀逐步指南以獲取完整的微調說明。
請注意,該庫對 JAX 版本有一些特定要求。具體來說,要使用 v1 模型(包括 GPT-J 6B),需要jax==0.2.12
。這又取決於jaxlib==0.1.68
。如果不這樣做,您將收到神秘的 xmap 錯誤
但是,要使用 v2 模型程式碼(沒有公開發布的權重),可以使用最新的 JAX 版本。
引用這個儲存庫:
@misc{mesh-transformer-jax,
author = {Wang, Ben},
title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
引用GPT-J-6B的重量:
@misc{gpt-j,
author = {Wang, Ben and Komatsuzaki, Aran},
title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
如果您使用此儲存庫或任何預先訓練的權重來做一些很酷的事情,我們很樂意聽到它。請隨意打開 github 問題或透過電子郵件聯繫(在個人資料中)。