Griffin
Griffin : Mélanger des récurrences linéaires fermées avec une attention locale pour des modèles de langage efficaces
arXiv
Architecture du modèle
Tous nos modèles contiennent les composants suivants : (i) un bloc résiduel, (ii) un bloc MLP et (iii) un bloc de mélange temporel. Bien que (i) et (ii) soient les mêmes dans tous les modèles, nous considérons trois blocs de mélange temporel : l'attention multi-requête globale (MQA), le MQA local (à fenêtre coulissante) et notre bloc récurrent proposé. Dans le cadre du bloc récurrent, nous utilisons l'unité récurrente linéaire à portail réel (RG-LRU) – une nouvelle couche récurrente inspirée de l'unité récurrente linéaire Orvieto et al., 2023b.
Le bloc résiduel, comme le montre la figure 2(a), définit la structure globale de nos modèles et s'inspire des transformateurs pré-normes (Xiong et al., 2020). Après avoir intégré la séquence d'entrée, nous la transmettons $N$ de tels blocs ( $N$ désignant la profondeur du modèle), puis nous appliquons RMSNorm Zhang et Sennrich, 2019 pour produire les activations finales. Pour calculer les probabilités des jetons, nous appliquons une couche linéaire finale suivie d'un softmax. Les poids de cette couche sont partagés avec la couche d'intégration d'entrée.
Blocage résiduel
Figure 2 : a) L'épine dorsale principale de notre architecture de modes est le bloc résiduel, qui est empilé $N$ fois. b) Le bloc MLP fermé que nous utilisons. c) Le bloc récurrent que nous proposons comme alternative au Multi Query Attention (MQA). Il utilise notre couche RG-LRU proposée, définie dans la section 2.4.
Le bloc résiduel contient deux composants, appliqués dans l'ordre. Le premier composant prend l'état caché $chi$ et applique un RMSNorm Zhang et Sennrich, 2019, suivi du bloc de mélange temporel. Nous fusionnons ensuite la sortie avec une connexion sautée de $chi$ par addition. De même, le deuxième composant applique RMSNorm, suivi du bloc MLP, puis fusionne sa sortie avec une connexion sautée à partir de l'entrée de RMSNorm. Ce bloc est illustré sur la figure 2 (a).
Bloc MLP
Nous utilisons un bloc MLP fermé Dauphin et al., 2017 (illustré sur la figure 2 (b)), qui crée deux branches à partir de son entrée de dimension $D$ . Nous appliquons une couche linéaire avec une dimension de sortie $MD$ sur chaque branche, où M$ désigne le facteur d'expansion. Pour simplifier, nous utilisons M$=3$ tout au long de ce travail. Nous appliquons une non-linéarité GeLU Hendrycks et Gimpel, 2016 sur l'une des branches avant de les fusionner par multiplication élément par élément, similaire à GeGeLU Shazeer, 2020. Cependant, dans notre bloc MLP, nous appliquons une couche linéaire finale avec une dimension de sortie $D$ sur les sorties de la couche GeGeLU.
Blocs de mélange temporel
Le bloc de mélange temporel est le composant de notre modèle qui regroupe les activations de couches cachées à différents emplacements temporels de la séquence. Nous considérons trois blocs de mélange temporel : MQA global Shazeer, 2019, MQA local Beltagy et al., 2020 et notre bloc récurrent proposé.
Attention globale multi-requêtes
Sauf indication contraire, nous utilisons MQA plutôt que MHA pour améliorer les vitesses d'inférence de nos lignes de base Transformer Shazeer, 2019. Nous utilisons une dimension de tête fixe $D_{tête}=128$ , et on fixe le nombre de têtes d'attention $H$ tel que $HD_{head}=D$ . Cela nécessite la dimension du modèle $D$ être un multiple de 128. Nous n'utilisons aucun plongement positionnel absolu, mais nous utilisons le Rotary Position Embedding (RoPE) Su et al., 2021 comme plongement positionnel relatif.
Attention fenêtre coulissante locale
L’un des principaux inconvénients de l’utilisation de l’attention globale est que la complexité de calcul augmente quadratiquement dans la longueur de la séquence. Pour résoudre ce problème, plusieurs travaux ont commencé à adopter l’attention locale Beltagy et al., 2020, également connue sous le nom d’attention par fenêtre coulissante. Cela permet à chaque position de s'occuper uniquement d'un nombre fixe de jetons dans le passé. Cela réduit non seulement les FLOP de calcul, mais limite également la taille du cache KV à la taille de la fenêtre, ce qui le rend plus quadratique dans la longueur de la séquence. Tous les autres détails sont les mêmes que ceux du MQA global.
Blocage récurrent
Notre bloc récurrent (Figure 2(c)) est similaire au bloc GSS Mehta et al., 2022 et au bloc utilisé par Mamba Gu et Dao, 2023. Nous prenons l'entrée de dimension $D$ et appliquez deux couches linéaires avec la dimension de sortie $D_{RNN}$ en parallèle, créant deux branches. Sur la première branche, nous appliquons une petite couche Conv1D séparable, inspirée du Shift-SSM dans H3 Dao et al., 2022b, avec une dimension de filtre temporel de 4. A noter que cette couche Conv1D est très petite, avec juste $4D$ paramètres. Nous suivons la couche Conv1D avec notre couche RG-LRU proposée (définie ci-dessous.) Sur la deuxième branche, nous appliquons une non-linéarité GeLU puis fusionnons les branches par multiplication élément par élément. Nous appliquons ensuite une couche linéaire finale avec la dimension de sortie $D$ .
Unité récurrente linéaire à portail réel (RG-LRU)
Notre couche RG-LRU proposée a une récurrence simple inspirée de l'unité récurrente linéaire (LRU) Orvieto et al., 2023b, mais intègre un mécanisme de déclenchement motivé par la littérature sur les RNN non linéaires, en particulier les LSTM Hochreiter et Schmidhuber, 1997 et GRU Chung et al., 2014. Les équations décrivant la couche sont les suivantes :
$$begin{align} r_t &= sigma(W_{a} x_t + b_a), & text{porte de récurrence} \ i_t &= sigma(W_{x} x_t + b_x), & text{ porte d'entrée} \ a_t &= a^{cr_t}, & text{} \ h_t &= a_t odot h_{t-1} + sqrt{1 - a_t^2} odot (i_t odot x_t). & text{} end{align}$$
La sortie de la couche est $y_t=h_t$ , et la non-linéarité $sigma$ dans les équations se trouve la fonction sigmoïde. Le poids récurrent $un$ dans l’équation (4), est diagonale. Par conséquent, toutes les opérations sont élémentaires. Nous paramétrons $un$ dans l'équation (3) comme $a=sigma(Lambda)$ , où $Lambda$ est un paramètre apprenable. Ceci garantit que 0 $ <= un <= 1$ , garantissant que la récurrence est stable. La variable $c$ est une constante à valeur scalaire fixée à 8. Pour la stabilité numérique, en pratique nous calculons $a^{cr_t}$ dans l’espace journal (voir l’annexe A). La couche a des portes à la fois sur l'entrée $x$ et le poids récurrent $un$ . Cependant, aucune des deux portes ne dépend de l'état récurrent $h_{t-1}$ , ce qui garantit que le calcul peut être exécuté efficacement sur l'appareil. Nous initialisons les deux $W_{a}$ et $W_{b}$ en utilisant LeCun init LeCun et al., 2002. Nous initialisons $Lambda$ tel que $a^c$ est uniformément réparti entre $0,9$ et $0.999$ au début de la formation, similaire à (Orvieto et al., 2023b.). Contrairement à de nombreux travaux récents dans la littérature SSM, le RG-LRU n'utilise pas d'initialisation inspirée de la théorie des polynômes orthogonaux Gu et al., 2020, et il n'est pas non plus défini comme la discrétisation d'un système continu sous-jacent Gu et al., 2021a. Contrairement à la couche LRU originale, nous n’utilisons pas d’algèbre complexe dans la récurrence. Bien que l'utilisation de récurrences complexes conduise à une couche plus expressive Orvieto et al., 2023a, nous avons constaté que les récurrences complexes n'étaient pas bénéfiques pour la modélisation du langage dans la pratique, comme l'ont également observé Gu et Dao, 2023. (voir Annexe B)
Comportement de la porte
La porte d'entrée $i_t$ est similaire à celui de LSTM, qui peut filtrer (ou réduire) l'entrée $x_t$ . Cependant, à notre connaissance, notre porte de récurrence $r_t$ est différent des autres mécanismes de déclenchement de la littérature. Par exemple, le mécanisme de sélection proposé dans Mamba Gu et Dao, 2023 est comparable à la porte de mise à jour des GRU qui interpole $x_t$ . Son effet sur l'état caché lui permet de réinitialiser son état et d'oublier toutes les informations qu'il détient du passé, semblable à la porte d'oubli du LSTM. En revanche, notre porte de récurrence peut approximativement interpoler entre la mise à jour LRU standard d'Orvieto et al., 2023a et l'état caché précédent, ce qui lui permet d'éliminer efficacement l'entrée et de conserver toutes les informations de l'historique précédent (voir l'Annexe A pour plus de détails). ). Nous pensons que le rôle clé de cette porte est de permettre au modèle d'atteindre une mémoire super-exponentielle en réduisant l'influence des entrées non informatives.