このリポジトリには、Generative Pretrained Transformer (GPT) のバリアントである LLaMA 2 (Large Language Model Meta AI) モデルの実装が含まれています。実装では、モデルのアーキテクチャと推論プロセスに焦点を当てます。アーキテクチャの重要な部分を理解しやすくするために、コードが再構成され、大量のコメントが追加されています。
RMS 正規化: RMSNorm は、元のレイヤー正規化 (LayerNorm) を簡略化したものです。 LayerNorm は、層の活性化を安定させ、モデルの収束を向上させるために内部共変量シフトの問題を処理できる正則化手法です。 LLaMA 2 では非常に成功していることが証明されています。
アクティベーション機能: LLaMA 2 は ReLU の代わりに SwiGLU アクティベーション機能を使用し、トレーニング パフォーマンスの向上につながります。
回転位置エンベディング (RoPE): GPT-Neo-X プロジェクトからインスピレーションを得た LLaMA 2 には、各層に回転位置エンベディングが組み込まれており、モデルの位置の理解を強化します。
コンテキスト長とグループ化クエリ アテンション (GQA) の増加: LLaMA 2 モデルにはコンテキスト ウィンドウが 2 倍 (2048 から 4096 トークン) あり、グループ化クエリ アテンションが採用されています。これにより、長い文書、チャット履歴、要約タスクの処理が向上します。
KV キャッシュは、言語モデル (LM) デコードの推論プロセスを高速化するためにこの実装で採用されている重要な最適化手法です。各トークンが前のトークンに基づいて予測される自己回帰デコード中、モデル内の自己注意には因果関係があります。これは、トークンの表現が、それ自体と、将来のトークンではなく、以前のトークンにのみ基づいて計算されることを意味します。
セルフアテンションでは、入力シーケンスはキー、値、およびクエリ射影を使用して射影されます。 KV キャッシュはキーと値の投影の結果を効率的に保存し、将来のデコード反復での冗長な計算の必要性を排除します。その結果、自己回帰デコード中に固定されたままのトークンの表現をキャッシュから取得できるようになり、推論速度が大幅に向上します。
この KV キャッシュ技術は、デコード中の LLaMA モデルの効率と速度を向上させる重要なアーキテクチャ機能です。
LLaMA 2 モデルには、Shazeer (2019) によって提案された、マルチヘッド アテンション (MHA) アルゴリズムの改良版であるマルチクエリ アテンション (MQA) の概念のバリエーションが組み込まれています。 MQA は、精度の低下を最小限に抑えながら、アテンション メカニズムの効率を高めます。
従来のマルチヘッド アテンションでは、アテンションの計算全体が h 回複製されます。ここで、h はアテンション ヘッドの数です。ただし、GQA は、K 値と V 値からヘッド寸法 (h) を削除または大幅に削減することにより、計算の冗長性を削減します。 MQA では、クエリ値 (Q) の各「ヘッド」に同じ K 変換と V 変換が行われ、アテンションの計算が最適化されます。
この改良により、MHA と同様の計算パフォーマンスが得られますが、メモリから読み書きされるデータ量が大幅に削減されます。結果として、GQA はパフォーマンス (演算強度の増加による) とメモリ空間効率 (保存される KV キャッシュ データ量の減少による) の両方を向上させ、LLaMA アーキテクチャへの価値ある追加となります。
LLaMA 2 モデルでは、位置情報をトークン表現に組み込むことにより、Rotary Positional Embeddings (RoPE) が注意メカニズムを強化する上で重要な役割を果たします。 「注意」の概念は強力ですが、計算された注意が意味のあるものであることを保証するには、トークンに位置の概念が必要です。
位置埋め込みには、絶対と相対の 2 つの主なタイプがあります。絶対位置埋め込みは入力フレーズ内の単語の絶対位置をエンコードし、相対位置埋め込みは 2 つの単語間の相対位置をエンコードします。これらの埋め込みは、トークンがシーケンス内のコンテキストを理解するのに役立つ重要な位置情報を提供します。
回転位置埋め込みは、回転行列を利用して位置情報を埋め込むという独自のアプローチを採用しています。目標は、位置 m および n におけるベクトル q と k の内積が q、k、およびそれらの相対距離 (m — n) にのみ依存することを保証することです。角度がベクトルの位置である回転行列は、この基準に合わせて行列の乗算によって元のベクトルに埋め込まれます。
位置情報を組み込むこの革新的なアプローチにより、トークンの関係とコンテキストを理解するモデルの能力が強化され、注意メカニズムの改善に貢献します。
model.py
: LLaMA トランスフォーマー モデルの実装と、各コンポーネントと機能を説明する詳細なコメントが含まれています。
inference.py
: トレーニングされた LLaMA モデルを推論に使用する方法を示し、入出力処理についての洞察を提供します。
自由にコードを探索し、間違いがあれば修正し、LLaMA 2 モデルを試してみてください。