トランスフォーマーのモデル並列処理に JAX のxmap
/ pjit
演算子を使用する Haiku ライブラリ。
並列処理スキームはオリジナルの Megatron-LM に似ており、高速 2D メッシュ ネットワークにより TPU 上で効率的です。 ZeRo スタイルのシャーディングを実装した実験モデル バージョンもあります。
このライブラリは、TPUv3 上で最大約 40B パラメータまでのスケーラビリティを考慮して設計されており、それを超えると、別の並列処理戦略を使用する必要があります。これについては、GPT-NeoX や DeepSpeed などの他の実装を参照してください。
今後の研究の方向性の 1 つは、このコードベースを swarm-jax と統合して、パイプライン並列処理によるさらなるスケーラビリティを実現することです。
12-07-21 : 微調整のためのガイドを追加
The Pile でトレーニングされた 60 億のパラメーター、自己回帰テキスト生成モデル。
スリム ウェイトをダウンロード (BF16 ウェイトのみ、推論用、9GB)
フルウェイトをダウンロード (オプティマイザーパラメータを含む、61GB)
部分的にトレーニングされたチェックポイント
コラボデモ
ウェブデモ
アランのブログ投稿
このプロジェクトは、EleutherAI の支援を受けて TPU Research Cloud によって惜しみなく提供されたコンピューティングがなければ不可能でした。
Cloud TPU VM アルファ (現在公開中) への早期アクセスを提供してくださった Google の Cloud TPU チームに感謝します。
何らかの形でご協力いただいた皆様に感謝します (アルファベット順に記載):
GPT-J-6B の重みは、Apache ライセンスのバージョン 2.0 に基づいてライセンスされています。
ハイパーパラメータ | 価値 |
---|---|
n_パラメータ | 6,053,381,344 |
n_layers | 28* |
d_model | 4,096 |
d_ff | 16,384 |
n_heads | 16 |
d_head | 256 |
n_ctx | 2,048 |
n_vocab | 50,257 (GPT-2/3 と同じトークナイザー) |
位置エンコーディング | 回転位置エンコーディング (RoPE) |
RoPE の寸法 | 64 |
*
各レイヤーは 1 つのフィードフォワード ブロックと 1 つのセルフ アテンション ブロックで構成されます
モデルは、モデル次元 4096、フィードフォワード次元 16384 の 28 層で構成されています。モデル次元は、それぞれ 256 次元の 16 個のヘッドに分割されています。回転位置エンコーディング (RoPE) が各ヘッドの 64 次元に適用されました。 。モデルは、GPT-2/GPT-3 と同じ BPE セットを使用して、50257 のトークン化語彙でトレーニングされます。
モデルはパフォーマンス別に、または入手できない場合は FLOP 別に大まかに並べ替えられます。
モデル | 重み | トレーニング FLOP | ランバダPPL ↓ | ランバダ ACC ↑ | ウィノグランデ ↑ | ヘラスワグ ↑ | ピカ ↑ | データセットのサイズ (GB) |
---|---|---|---|---|---|---|---|---|
チャンス | ✔ | 0 | ~たくさん | ~0% | 50% | 25% | 25% | 0 |
GPT-3-エイダ‡ | ✘ | ----- | 9.95 | 51.6% | 52.9% | 43.4% | 70.5% | ----- |
GPT-2-1.5B | ✔ | ----- | 10.63 | 51.21% | 59.4% | 50.9% | 70.8% | 40 |
GPTNeo-1.3B‡ | ✔ | 3.0e21 | 7.50 | 57.2% | 55.0% | 48.9% | 71.1% | 825 |
メガトロン-2.5B* | ✘ | 2.4e21 | ----- | 61.7% | ----- | ----- | ----- | 174 |
GPTNeo-2.7B‡ | ✔ | 6.8e21 | 5.63 | 62.2% | 56.5% | 55.8% | 73.0% | 825 |
GPT-3-1.3B*‡ | ✘ | 2.4e21 | 5.44 | 63.6% | 58.7% | 54.7% | 75.1% | ~800 |
GPT-3-バベッジ‡ | ✘ | ----- | 5.58 | 62.4% | 59.0% | 54.5% | 75.5% | ----- |
メガトロン-8.3B* | ✘ | 7.8e21 | ----- | 66.5% | ----- | ----- | ----- | 174 |
GPT-3-2.7B*‡ | ✘ | 4.8e21 | 4.60 | 67.1% | 62.3% | 62.8% | 75.6% | ~800 |
メガトロン-11B† | ✔ | 1.0e22 | ----- | ----- | ----- | ----- | ----- | 161 |
GPT-J-6B ‡ | ✔ | 1.5e22 | 3.99 | 69.7% | 65.3% | 66.1% | 76.5% | 825 |
GPT-3-6.7B*‡ | ✘ | 1.2e22 | 4.00 | 70.3% | 64.5% | 67.4% | 78.0% | ~800 |
GPT-3-キュリー‡ | ✘ | ----- | 4.00 | 69.3% | 65.6% | 68.5% | 77.9% | ----- |
GPT-3-13B*‡ | ✘ | 2.3e22 | 3.56 | 72.5% | 67.9% | 70.9% | 78.5% | ~800 |
GPT-3-175B*‡ | ✘ | 3.1e23 | 3.00 | 76.2% | 70.2% | 78.9% | 81.0% | ~800 |
GPT-3-ダヴィンチ‡ | ✘ | ----- | 3.0 | 75% | 72% | 78% | 80% | ----- |
ゴーファー 230B* | ✘ | 6.31E+23 | ----- | 74.50% | 70.10% | 79.20% | 81.80% | 1344 |
MT-NLG 530B*‡ | ✘ | ----- | ----- | 76.6% | 73.0% | 80.2% | 82.0% | ----- |
*
それぞれの作成者によって報告された評価番号を表し、他のすべての数値は、リリースされた重みまたは API アクセスを使用して lm-evaluation-harness を実行することによって提供されます。実装の微妙な違いやゼロ ショット タスクのフレーム構成の違いにより、これらは直接比較できない場合があります。詳細については、このブログ投稿を参照してください。
†
Megatron-11B モデルには同等のメトリクスが提供されておらず、リリースされた重みを使用するいくつかの実装では生成品質と評価が再現されません。 (1 2 3 参照) したがって、評価は試みられませんでした。
‡
これらのモデルは、テストセットの汚染の可能性を含むデータを使用してトレーニングされています。 OpenAI GPT-3 モデルは、特定のテスト セットのトレーニング データの重複排除に失敗しましたが、GPT-Neo モデルとこのモデルは、どのテスト セットに対しても重複排除されていない The Pile でトレーニングされています。
このリポジトリ内のほとんどのスクリプトは、TPU 上で実行されるように設計されています。TPU は、TPU-VM アーキテクチャの下では任意のコードを実行できる仮想マシンです。ほとんどのスクリプトは、TPU を起動し、それに SSH 接続して依存関係を設定し、ローカル ディレクトリからコードをコピーしてから、RPC 呼び出しを受け入れることができる Ray ワーカーを起動するように設計されています。
TPUVM は、モデルのトレーニング ステップと評価の実行、チェックポイントの保存とロードを処理し、ドライバー Python プログラムはデータのロードと一般的なオーケストレーション (チェックポイントをいつ保存するかなど) を処理します。
これは、RPC レイテンシとデータ転送コストを最小限に抑えるために、ほとんどのスクリプト ( train.py
、 eval_harness.py
など) が TPU と同じリージョンの GCE 仮想マシン上で実行されることを想定していることを意味します。他のスクリプト (通常、 device_sample.py
、 device_serve.py
、 device_train.py
など、 --tpu
引数をとらないスクリプト) は、TPUVM 上で直接実行されることが想定されています。 device_* スクリプトはv3-8 でのみ機能し、それより大きなポッドでは機能しません。
さらに、GPU で実行する場合などに、提供されたチェックポイント (GPT-J-6B の場合は 8 シャード) をより小さい数に変換する方法の例 ( resharding_example.py
) があります。
モデルを微調整するには、TPU VM でdevice_train.py
実行します。 TPU v3-8 を使用すると、約 5000 トークン/秒の速度で微調整できます。これは小規模から中規模のデータセットには十分です。
徹底的な微調整手順については、ステップバイステップのガイドをお読みください。
このライブラリには、JAX バージョンに固有の要件がいくつかあることに注意してください。具体的には、v1 モデル(GPT-J 6B を含む)を使用するには、 jax==0.2.12
が必要です。これは、 jaxlib==0.1.68
に依存します。これを行わないと、不可解な xmap エラーが発生します。
ただし、v2 モデル コード (公開されていない重み) を使用するには、最新の JAX バージョンを使用できます。
このリポジトリを引用するには:
@misc{mesh-transformer-jax,
author = {Wang, Ben},
title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
GPT-J-6B の重量を引用するには:
@misc{gpt-j,
author = {Wang, Ben and Komatsuzaki, Aran},
title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
howpublished = {url{https://github.com/kingoflolz/mesh-transformer-jax}},
year = 2021,
month = May
}
このリポジトリまたは事前トレーニングされた重みのいずれかを使用して何か素晴らしいことを行う場合は、ぜひお知らせください。お気軽に github の問題を開いていただくか、電子メール (プロフィール内) でお問い合わせください。