Griffin
Griffin :将门控线性循环与局部注意力混合以实现高效的语言模型
arXiv
模型架构
我们所有的模型都包含以下组件:(i)残差块,(ii)MLP 块,以及(iii)时间混合块。虽然 (i) 和 (ii) 在所有模型中都是相同的,但我们考虑三个时间混合块:全局多查询注意力 (MQA)、局部(滑动窗口)MQA 和我们提出的循环块。作为循环块的一部分,我们使用实门线性循环单元(RG-LRU)——一种受线性循环单元 Orvieto 等人,2023b 启发的新型循环层。
如图 2(a) 所示,残差块定义了我们模型的全局结构,其灵感来自于前范数 Transformers(Xiong 等人,2020)。嵌入输入序列后,我们将其传递给 $N$这样的块( $N$表示模型深度),然后我们应用 RMSNormZhang 和 Sennrich,2019 来生成最终的激活。为了计算 token 概率,我们应用了最后一个线性层,后跟一个 softmax。该层的权重与输入嵌入层共享。
残块
图 2:a)我们的模式架构的主要支柱是残差块,它是堆叠的 $N$次。 b) 我们使用的门控 MLP 模块。 c)我们提出的循环块作为多查询注意力(MQA)的替代方案。它使用我们提出的 RG-LRU 层,在 2.4 节中定义。
残差块包含两个按顺序应用的组件。第一个组件采用隐藏状态 $chi$并应用 RMSNorm Zhuang 和 Sennrich,2019,然后是时间混合块。然后我们将输出与跳跃连接合并 $chi$通过加法。类似地,第二个组件应用 RMSNorm,后跟 MLP 模块,然后将其输出与来自 RMSNorm 输入的跳跃连接合并。该块如图 2 (a) 所示。
MLP块
我们使用门控 MLP 模块 Dauphin et al., 2017(如图 2(b) 所示),它根据维度输入创建两个分支 $D$ 。我们应用具有输出维度的线性层 $MD$在每个分支上,其中 $M$表示扩展因子。为了简单起见,我们使用 $M=3$贯穿整个这项工作。我们在其中一个分支上应用 GeLU 非线性 Hendrycks 和 Gimpel,2016,然后通过逐元素乘法合并它们,类似于 GeGeLU Shazeer,2020。但是,在我们的 MLP 块中,我们应用了具有输出维度的最终线性层 $D$在 GeGeLU 层的输出上。
时间混合块
时间混合块是我们模型的组件,它聚合序列中不同时间位置的隐藏层激活。我们考虑三个时间混合块:全局 MQA Shazeer,2019、局部 MQA Beltagy 等人,2020 以及我们提出的 Recurrent 块。
全局多查询注意力
除非另有说明,我们使用 MQA 而不是 MHA 来提高 Transformer 基线 Shazeer,2019 的推理速度。我们使用固定的头部尺寸 $D_{头}=128$ ,我们固定注意力头的数量 $H$这样 $HD_{头}=D$ 。这需要模型尺寸 $D$是 128 的倍数。我们不使用任何绝对位置嵌入,但我们使用旋转位置嵌入 (RoPE) Su et al., 2021 作为相对位置嵌入。
局部滑动窗口注意力
使用全局注意力的主要缺点之一是其计算复杂度随序列长度呈二次方增长。为了解决这个问题,一些工作已经开始采用局部注意力 Beltagy et al., 2020,也称为滑动窗口注意力。它允许每个位置在过去只参与固定数量的代币。这不仅减少了计算的 FLOP,而且还将 KV 缓存的大小限制为窗口的大小,使其不再是序列长度的二次方。所有其他细节与全局 MQA 相同。
循环块
我们的循环块(图 2(c))类似于 GSS 块 Mehta et al., 2022 以及 Mamba Gu 和 Dao, 2023 使用的块。我们采用维度的输入 $D$并应用两个具有输出维度的线性层 $D_{RNN}$并行地创建两个分支。在第一个分支上,我们应用了一个小的可分离的 Conv1D 层,其灵感来自于 H3 Dao 等人,2022b 中的 Shift-SSM,时间滤波器维度为 4。请注意,这个 Conv1D 层非常小,只有 $4D$参数。我们在 Conv1D 层后面加上我们提出的 RG-LRU 层(定义如下)。在第二个分支上,我们应用 GeLU 非线性,然后通过逐元素乘法合并分支。然后我们应用具有输出维度的最终线性层 $D$ 。
实门线性循环单元 (RG-LRU)
我们提出的 RG-LRU 层具有受线性循环单元 (LRU) Orvieto 等人,2023b 启发的简单循环,但结合了受非线性 RNN 文献启发的门控机制,特别是 LSTM Hochreiter 和 Schmidhuber,1997 和GRUs Chung et al., 2014。描述该层的方程如下:
$$begin{align} r_t &= sigma(W_{a} x_t + b_a), & text{递归门} \ i_t &= sigma(W_{x} x_t + b_x), & text{输入门} \ a_t &= a^{cr_t}, & text{} \ h_t &= a_t odot h_{t-1} + sqrt{1 - a_t^2} odot (i_t odot x_t)。 & text{} end{对齐}$$
该层的输出为 $y_t=h_t$ ,和非线性 $西格玛$方程中是 sigmoid 函数。经常性体重 $a$等式(4)中是对角线。因此,所有操作都是逐元素的。我们参数化 $a$等式(3)中为 $a=sigma(Lambda)$ , 在哪里 $Lambda$是一个可学习的参数。这保证了 $0 <= a <= 1$ ,确保复发稳定。变量 $c$是一个标量值常量,设置为 8。为了数值稳定性,实际上我们计算 $a^{cr_t}$在日志空间中(参见附录 A)。该层在两个输入上都有门 $x$和经常性体重 $a$ 。然而,这两个门都不依赖于循环状态 $h_{t-1}$ ,这确保了计算可以在设备上高效执行。我们初始化两者 $W_{a}$和 $W_{b}$使用 LeCun init LeCun et al., 2002。我们初始化 $Lambda$这样 $a^c$均匀分布在 $0.9$和 $0.999$在训练开始时,类似于(Orvieto et al., 2023b.)。与 SSM 文献中的许多最新作品不同,RG-LRU 不使用受正交多项式理论启发的初始化 Gu 等人,2020,并且它也没有定义为底层连续系统的离散化 Gu 等人, 2021a。与原始 LRU 层不同,我们在递归中不使用复杂代数。虽然使用复杂的递归会导致更具表现力的层 Orvieto et al., 2023a,但我们发现复杂的递归对实践中的语言建模不利,正如 Gu 和 Dao, 2023 所观察到的那样。(参见附录 B)
门行为
输入门 $i_t$类似于LSTM中的,可以对输入进行过滤(或者按比例缩小) $x_t$ 。然而,据我们所知,我们的递归门 $r_t$与文献中的其他门控机制不同。例如,Mamba Gu and Dao, 2023 中提出的选择机制相当于 GRU 的更新门,它插值 $x_t$ 。它对隐藏状态的影响使其能够重置其状态并忘记过去保存的任何信息,类似于 LSTM 中的遗忘门。相比之下,我们的递归门可以在 Orvieto 等人 2023a 的标准 LRU 更新和之前的隐藏状态之间进行近似插值,这使得它能够有效地丢弃输入并保留之前历史记录中的所有信息(有关更多详细信息,请参阅附录 A) )。我们认为这个门的关键作用是通过减少无信息输入的影响来使模型实现超指数记忆。