Griffin
Griffin : Misturando recorrências lineares controladas com atenção local para modelos de linguagem eficientes
arXiv
Arquitetura do modelo
Todos os nossos modelos contêm os seguintes componentes: (i) um bloco residual, (ii) um bloco MLP e (iii) um bloco de mistura temporal. Embora (i) e (ii) sejam iguais em todos os modelos, consideramos três blocos de mistura temporal: atenção multiconsulta global (MQA), MQA local (janela deslizante) e nosso bloco recorrente proposto. Como parte do bloco recorrente, usamos a Unidade Recorrente Linear Real-Gated (RG-LRU) – uma nova camada recorrente inspirada na Unidade Recorrente Linear Orvieto et al., 2023b.
O bloco residual, conforme mostrado na Figura 2 (a), define a estrutura global de nossos modelos e é inspirado em Transformers pré-norma (Xiong et al., 2020). Depois de incorporar a sequência de entrada, passamos por ela $N$ tais blocos ( $N$ denotando a profundidade do modelo) e então aplicamos o RMSNorm Zhang e Sennrich, 2019 para produzir as ativações finais. Para calcular as probabilidades do token, aplicamos uma camada linear final seguida por um softmax. Os pesos desta camada são compartilhados com a camada de incorporação de entrada.
Bloco residual
Figura 2: a) A principal espinha dorsal da nossa arquitetura de modo é o bloco residual, que é empilhado $N$ vezes. b) O bloco MLP fechado que usamos. c) O bloco recorrente que propomos como alternativa ao Multi Query Attention (MQA). Ele usa nossa camada RG-LRU proposta, definida na Seção 2.4.
O bloco residual contém dois componentes, aplicados em ordem. O primeiro componente assume o estado oculto $chi$ e aplica um RMSNorm Zhang e Sennrich, 2019, seguido pelo bloco de mixagem temporal. Em seguida, mesclamos a saída com uma conexão de salto de $chi$ através da adição. Da mesma forma, o segundo componente aplica o RMSNorm, seguido pelo bloco MLP e então mescla sua saída com uma conexão de salto da entrada do RMSNorm. Este bloco está ilustrado na Figura 2 (a).
Bloco MLP
Usamos um bloco MLP fechado Dauphin et al., 2017 (ilustrado na Figura 2 (b)), que cria duas ramificações a partir de sua entrada de dimensão $D$ . Aplicamos uma camada linear com dimensão de saída $MD$ em cada filial, onde $M$ denota o fator de expansão. Para simplificar, usamos $M=3$ ao longo deste trabalho. Aplicamos uma não linearidade GeLU Hendrycks e Gimpel, 2016 em um dos ramos antes de mesclá-los por multiplicação elemento a elemento, semelhante a GeGeLU Shazeer, 2020. No entanto, em nosso bloco MLP, aplicamos uma camada linear final com dimensão de saída $D$ nas saídas da camada GeGeLU.
Blocos de mixagem temporal
O bloco de mistura temporal é o componente do nosso modelo que agrega ativações de camadas ocultas em diferentes locais temporais na sequência. Consideramos três blocos de mistura temporal: MQA global Shazeer, 2019, MQA local Beltagy et al., 2020 e nosso bloco Recorrente proposto.
Atenção global multi-consulta
Salvo indicação em contrário, usamos MQA em vez de MHA para melhorar as velocidades de inferência de nossas linhas de base do Transformer Shazeer, 2019. Usamos uma dimensão de cabeça fixa $D_{cabeça}=128$ , e fixamos o número de cabeças de atenção $H$ tal que $HD_{cabeça}=D$ . Isso requer a dimensão do modelo $D$ como um múltiplo de 128. Não usamos nenhuma incorporação posicional absoluta, mas usamos Rotary Position Embedding (RoPE) Su et al., 2021 como uma incorporação posicional relativa.
Atenção à janela deslizante local
Uma das principais desvantagens do uso da atenção global é que sua complexidade computacional cresce quadraticamente no comprimento da sequência. Para resolver isso, vários trabalhos começaram a adotar a atenção local Beltagy et al., 2020, também conhecida como atenção de janela deslizante. Permite que cada posição atenda apenas a um número fixo de tokens no passado. Isso não apenas reduz os FLOPs computacionais, mas também limita o tamanho do cache KV ao tamanho da janela, tornando-o não mais quadrático no comprimento da sequência. Todos os outros detalhes são iguais aos do MQA global.
Bloco recorrente
Nosso bloco recorrente (Figura 2 (c)) é semelhante ao bloco GSS Mehta et al., 2022 e ao bloco usado por Mamba Gu e Dao, 2023. Pegamos a entrada da dimensão $D$ e aplique duas camadas lineares com dimensão de saída $D_{RNN}$ em paralelo, criando duas filiais. No primeiro ramo, aplicamos uma pequena camada Conv1D separável, inspirada no Shift-SSM em H3 Dao et al., 2022b, com dimensão de filtro temporal de 4. Observe que esta camada Conv1D é muito pequena, com apenas $4D$ parâmetros. Seguimos a camada Conv1D com nossa camada RG-LRU proposta (definida abaixo). Na segunda ramificação aplicamos uma não linearidade GeLU e então mesclamos as ramificações por multiplicação elemento a elemento. Em seguida, aplicamos uma camada linear final com dimensão de saída $D$ .
Unidade Recorrente Linear Real-Gated (RG-LRU)
Nossa camada RG-LRU proposta tem uma recorrência simples inspirada na Unidade Recorrente Linear (LRU) Orvieto et al., 2023b, mas incorpora um mecanismo de gate motivado pela literatura sobre RNNs não lineares, em particular LSTMs Hochreiter e Schmidhuber, 1997 e GRUs Chung et al., 2014. As equações que descrevem a camada são as seguintes:
$$begin{align} r_t &= sigma(W_{a} x_t + b_a), & text{porta de recorrência} \ i_t &= sigma(W_{x} x_t + b_x), & text{ porta de entrada} \ a_t &= a^{cr_t}, & text{} \ h_t &= a_t odot h_{t-1} + sqrt{1 - a_t^2} odot (i_t odot x_t). & text{} end{align}$$
A saída da camada é $y_t=h_t$ , e a não linearidade $sigma$ nas equações está a função sigmóide. O peso recorrente $a$ na Equação (4) é diagonal. Portanto, todas as operações são elementares. Nós parametrizamos $a$ na Equação (3) como $a=sigma(Lambda)$ , onde $Lambda$ é um parâmetro que pode ser aprendido. Isto garante que $0 <= a <= 1$ , garantindo que a recorrência seja estável. A variável $c$ é uma constante com valor escalar definida como 8. Para estabilidade numérica, na prática calculamos $a^{cr_t}$ no espaço de log (consulte o Apêndice A). A camada possui portas tanto na entrada $x$ e o peso recorrente $a$ . No entanto, nenhuma porta depende do estado recorrente $h_{t-1}$ , o que garante que o cálculo possa ser executado com eficiência no dispositivo. Inicializamos ambos $W_{a}$ e $W_{b}$ usando LeCun init LeCun et al., 2002. Inicializamos $Lambda$ tal que $a^c$ está uniformemente distribuído entre $0,9$ e US$ 0,999$ no início do treinamento, semelhante a (Orvieto et al., 2023b.). Ao contrário de muitos trabalhos recentes na literatura SSM, o RG-LRU não usa inicialização inspirada na teoria de polinômios ortogonais Gu et al., 2020, e também não é definido como a discretização de um sistema contínuo subjacente Gu et al., 2021a. Ao contrário da camada LRU original, não usamos álgebra complexa na recorrência. Embora o uso de recorrências complexas levaria a uma camada mais expressiva, Orvieto et al., 2023a descobrimos que recorrências complexas não eram benéficas para a modelagem de linguagem na prática, como também observado por Gu e Dao, 2023. (ver Apêndice B)
Comportamento do portão
A porta de entrada $eu_t$ é semelhante ao do LSTM, que pode filtrar (ou reduzir) a entrada $x_t$ . No entanto, até onde sabemos, nossa porta de recorrência $r_t$ é diferente de outros mecanismos de controle na literatura. Por exemplo, o mecanismo de seleção proposto em Mamba Gu e Dao, 2023 é comparável ao portão de atualização de GRUs que interpola $x_t$ . Seu efeito no estado oculto permite redefinir seu estado e esquecer qualquer informação que retenha do passado, semelhante ao portão de esquecimento no LSTM. Em contraste, nossa porta de recorrência pode interpolar aproximadamente entre a atualização LRU padrão de Orvieto et al., 2023a e o estado oculto anterior, o que permite descartar efetivamente a entrada e preservar todas as informações do histórico anterior (consulte o Apêndice A para obter mais detalhes ). Acreditamos que o papel fundamental desta porta é permitir que o modelo alcance memória superexponencial, reduzindo a influência de entradas não informativas.