Griffin
Griffin : 効率的な言語モデルのためのローカル アテンションとゲート線形再帰の混合
arXiv
モデルのアーキテクチャ
すべてのモデルには、(i) 残差ブロック、(ii) MLP ブロック、(iii) 時間混合ブロックのコンポーネントが含まれています。 (i) と (ii) はすべてのモデルで同じですが、グローバル マルチクエリ アテンション (MQA)、ローカル (スライディング ウィンドウ) MQA、および提案するリカレント ブロックという 3 つの時間的混合ブロックを考慮します。リカレント ブロックの一部として、Real-Gated Linear Recurrent Unit (RG-LRU) を使用します。これは、Linear Recurrent Unit Orvieto et al., 2023b からインスピレーションを得た新しいリカレント レイヤーです。
図 2(a) に示す残差ブロックは、モデルのグローバル構造を定義し、プレノルム トランスフォーマーからインスピレーションを得ています (Xiong et al., 2020)。入力シーケンスを埋め込んだ後、それを渡します $N$そのようなブロック( $N$モデルの深さを示します)、RMSNorm Zhang and Sennrich, 2019 を適用して最終的なアクティベーションを生成します。トークンの確率を計算するには、最後の線形層とそれに続くソフトマックスを適用します。この層の重みは、入力埋め込み層と共有されます。
残留ブロック
図 2: a) モード アーキテクチャの主なバックボーンは、積み上げられた残差ブロックです。 $N$回。 b) 使用するゲート MLP ブロック。 c) マルチクエリーアテンション (MQA) の代替として提案するリカレントブロック。これは、セクション 2.4 で定義されている、私たちが提案する RG-LRU レイヤーを使用します。
残差ブロックには、順番に適用される 2 つのコンポーネントが含まれています。最初のコンポーネントは非表示状態になります $chi$そして、RMSNorm Zhang and Sennrich、2019 を適用し、続いて時間混合ブロックを適用します。次に、出力をスキップ接続とマージします。 $chi$加算を通じて。同様に、2 番目のコンポーネントは RMSNorm を適用し、続いて MLP ブロックを適用し、その出力を RMSNorm の入力からのスキップ接続とマージします。このブロックを図 2 (a) に示します。
MLP ブロック
ゲート MLP ブロック Dauphin et al., 2017 (図 2(b) に示す) を使用し、次元の入力から 2 つの分岐を作成します。 $D$ 。出力次元を持つ線形レイヤーを適用します $MD$各ブランチで、 $M$は膨張率を表します。簡単にするために、次を使用します $M=3$この作品全体を通して。 GeGeLU Shazeer, 2020 と同様に、要素ごとの乗算によってそれらをマージする前に、ブランチの 1 つに GeLU 非線形性 Hendrycks and Gimpel, 2016 を適用します。ただし、MLP ブロックでは、出力次元を持つ最終線形層を適用します。 $D$ GeGeLU 層の出力について。
時間的混合ブロック
時間混合ブロックは、シーケンス内のさまざまな時間位置での隠れ層のアクティベーションを集約するモデルのコンポーネントです。グローバル MQA Shazeer, 2019、ローカル MQA Beltagy et al., 2020、および私たちが提案する Recurrent ブロックの 3 つの時間混合ブロックを検討します。
グローバルマルチクエリアテンション
特に明記されていない限り、Transformer ベースライン Shazeer、2019 の推論速度を向上させるために、MHA ではなく MQA を使用します。固定ヘッド寸法を使用します。 $D_{頭}=128$ 、そしてアテンションヘッドの数を修正します $H$そのような $HD_{頭}=D$ 。これにはモデルの寸法が必要です $D$絶対位置埋め込みは使用しませんが、相対位置埋め込みとして Rotary Position Embedding (RoPE) Su et al., 2021 を使用します。
ローカルスライディングウィンドウの注意
グローバル アテンションを使用することの主な欠点の 1 つは、計算の複雑さがシーケンスの長さに応じて二次関数的に増大することです。これに対処するために、いくつかの研究では、スライディング ウィンドウ アテンションとしても知られるローカル アテンション (Beltagy et al., 2020) を採用し始めています。これにより、各ポジションは過去の一定数のトークンのみに参加することができます。これにより、計算の FLOP が削減されるだけでなく、KV キャッシュのサイズがウィンドウのサイズに制限され、シーケンスの長さが 2 次ではなくなります。その他の詳細はすべてグローバル MQA と同じです。
反復ブロック
私たちのリカレント ブロック (図 2(c)) は、Mehta et al., 2022 の GSS ブロックや、Mamba Gu と Dao, 2023 によって使用されているブロックに似ています。次元の入力を受け取ります。 $D$出力次元を持つ 2 つの線形レイヤーを適用します $D_{RNN}$並行して、2 つのブランチを作成します。最初のブランチでは、H3 Dao et al., 2022b の Shift-SSM に触発された小さな分離可能な Conv1D レイヤーを、時間フィルター次元 4 で適用します。この Conv1D レイヤーは非常に小さく、 $4D$パラメータ。 Conv1D 層に続いて、提案する RG-LRU 層 (以下に定義) を適用します。2 番目のブランチでは、GeLU 非線形性を適用し、要素ごとの乗算によってブランチをマージします。次に、出力次元を持つ最終的な線形レイヤーを適用します。 $D$ 。
リアルゲート線形リカレントユニット (RG-LRU)
私たちが提案する RG-LRU 層は、線形再帰ユニット (LRU) Orvieto et al., 2023b に触発された単純な再帰を持っていますが、非線形 RNN、特に LSTM Hochreiter と Schmidhuber, 1997 に関する文献によって動機付けられたゲート機構が組み込まれています。 GRUs Chung et al.、2014。層を説明する方程式は次のとおりです。
$$begin{align} r_t &= sigma(W_{a} x_t + b_a), & text{リカレンス ゲート} \ i_t &= sigma(W_{x} x_t + b_x), & text{入力ゲート} \ a_t &= a^{cr_t}, & text{} \ h_t &= a_t odot h_{t-1} + sqrt{1 - a_t^2} odot (i_t odot x_t)。 & text{} end{align}$$
レイヤーの出力は次のとおりです。 $y_t=h_t$ 、および非線形性 $シグマ$方程式中の はシグモイド関数です。反復体重 $a$式 (4) の は対角です。したがって、すべての操作は要素ごとに行われます。パラメータ化します $a$式(3)では次のようになります。 $a=シグマ(ラムダ)$ 、 どこ $ラムダ$学習可能なパラメータです。これにより、次のことが保証されます $0 <= a <= 1$ 、再発が安定していることを保証します。変数 $c$は 8 に設定されたスカラー値の定数です。数値を安定させるために、実際には次のように計算します。 $a^{cr_t}$ログスペースに保存されます (付録 A を参照)。レイヤーには両方の入力にゲートがあります $x$そして再発体重 $a$ 。ただし、どちらのゲートも再発状態に依存しません。 $h_{t-1}$これにより、デバイス上で計算を効率的に実行できるようになります。両方を初期化します $W_{a}$そして $W_{b}$ LeCun init を使用 LeCun et al.、2002。 $ラムダ$そのような $a^c$の間で均一に分布しています $0.9$そして $0.999$ (Orvieto et al., 2023b.) と同様、トレーニングの開始時。 SSM 文献の最近の多くの研究とは異なり、RG-LRU は直交多項式の理論に触発された初期化を使用しません (Gu et al., 2020)。また、基礎となる連続システムの離散化として定義されていません (Gu et al., 2020)。 2021a。元の LRU 層とは異なり、漸化式では複雑な代数を使用しません。複雑な再帰を使用すると、より表現力豊かな層が得られますが、Orvieto et al., 2023a では、Gu and Dao, 2023 によっても観察されているように、複雑な再帰は実際の言語モデリングには有益ではないことがわかりました。(付録 B を参照)
ゲートの動作
入力ゲート $i_t$ LSTM のものと似ており、入力をフィルタリング (またはスケールダウン) できます。 $x_t$ 。しかし、私たちの知る限りでは、再発ゲートは $r_t$文献にある他のゲート機構とは異なります。たとえば、Mamba Gu and Dao, 2023 で提案されている選択メカニズムは、補間する GRU の更新ゲートに相当します。 $x_t$ 。隠し状態に対するその効果により、LSTM の忘却ゲートと同様に、状態をリセットし、過去に保持している情報を忘れることができます。対照的に、私たちの再発ゲートは、Orvieto et al., 2023a の標準 LRU 更新と以前の隠れ状態の間を近似的に補間することができ、これにより入力を効果的に破棄し、以前の履歴からのすべての情報を保存することができます (詳細については付録 A を参照) )。このゲートの重要な役割は、有益でない入力の影響を軽減することでモデルが超指数関数的な記憶を達成できるようにすることであると考えています。