Experimentos em torno de uma ideia simples para induzir vários modelos hierárquicos de codificação preditiva em um GPT. É tão simples que pode não funcionar. Mas, novamente, o progresso do aprendizado profundo é construído sobre a base de ideias simples. Vale a pena tentar.
Até agora, a ideia passou no teste decisivo de um amigo pesquisador. Será concluído na próxima semana ou depois. Se não der certo, deixarei os resultados experimentais negativos, bem como o repositório, e talvez algum estudante de doutorado possa desenvolver isso.
Atualização: acho que está funcionando?
StabilityAI pelo patrocínio para realizar esta pesquisa independente
? Huggingface por sua biblioteca acelerada
$ pip install simple-hierarchical-transformer
Três hierarquias, todas atendendo prevendo o próximo token
import torch
from simple_hierarchical_transformer import HierarchicalTransformer
model = HierarchicalTransformer (
num_tokens = 20000 , # number of tokens
dim = 512 , # model dimensions
depth = 6 , # depth
dim_head = 64 , # dimension per attention head
heads = 8 , # attention heads
seq_len = 2048 , # sequence lengths
hierarchies = ( 1 , 2 , 8 ), # hierarchies - here we have 1x (like in a regular transformer), then 2x and 8x compressed hierarchical tokens that undergo their own transformer blocks. information is pooled into one hierarchy at each layer
window_sizes = ( 32 , 64 , None ) # local attention window sizes - the idea is that the higher hierarchies can pass distant information to the local one. None stands for full receptive field. Setting 0 would turn off attention at this hierarchy altogether (while token shift will still be in effect in each layer)
)
ids = torch . randint ( 0 , 20000 , ( 1 , 2048 ))
loss , _ = model ( ids , return_loss = True )
loss . backward ()
# after much training
logits = model ( ids )
Ao não especificar hierarchies
e window_sizes
, você basicamente usa como padrão um transformador autorregressivo regular com atenção em todo o comprimento da sequência.
# non-hierarchical transformer
model = HierarchicalTransformer (
num_tokens = 20000 ,
dim = 512 ,
depth = 8 ,
dim_head = 64 ,
heads = 8 ,
seq_len = 2048 ,
hierarchies = 1 , # implied 1 if not set
window_sizes = None # implied None (full sequence length) if not set
)
Agora algo mais complexo. As experiências mostram que, à medida que as hierarquias são comprimidas, são necessárias dimensões de modelo maiores para obter a capacidade apropriada.
model = HierarchicalTransformer (
num_tokens = 256 ,
dim = ( 128 , 256 , 512 , 1024 ),
depth = 8 ,
seq_len = 1024 ,
use_flash_attn = True ,
ff_mult = ( 2 , 2 , 4 , 4 ),
dim_head = ( 16 , 32 , 64 , 64 ),
heads = ( 2 , 4 , 8 , 8 ),
hierarchies = ( 1 , 2 , 4 , 16 ),
hierarchical_stride = ( 1 , 1 , 1 , 8 ), # this would determine the stride when compressing, and when concatting the hierarchical tokens to the fine tokens, the past tokens will be repeated this amount of time. causality is not violated as using the trick from hourglass transformers where sequence is shifted by compression factor - 1. recommend sticking with 1 except for highly compressed hierarchies, as it becomes very uncompetitive with baseline and generations look off
window_sizes = ( 16 , 32 , 64 , None )
). cuda ()
# hierarchies
# 1x - dim 128 - attention (2 heads, 16 dim, receptive field 16)
# 2x - dim 256 - attention (4 heads, 32 dim, receptive field 32)
# 4x - dim 512 - attention (8 heads, 64 dim, receptive field 64)
# 8x - dim 1024 - attention (8 heads, 64 dim, receptive field of all)
ramificam-se para dois caminhos paralelos, um para tokens hierárquicos e outro para tokens finos simples.
mostram que a atenção local em tokens finos + hierárquicos pode chegar perto da linha de base de atenção total
simples dsconv parece suficiente para mesclar para 1 hierarquia
definir automaticamente o tamanho da janela para ser metade do comprimento máximo da sequência para hierarquias finas e todas as hierarquias
descobrir os efeitos de apenas agrupar todos os tokens finos + hierárquicos antes da perda de entropia cruzada - não há muita diferença
capacidade completa de adicionar qualquer número de hierarquias e designar qual hierarquia reunirá as informações das outras para previsão
dimensões totalmente personalizáveis entre hierarquias, já que hierarquias superiores exigem dimensões de modelo maiores
adicione perdas de profetas para ramos hierárquicos
permitir a repetição de tokens de hierarquia para tokens finos no futuro, já que a posição pode importar menos à medida que se sobe na hierarquia. mas não é uma prioridade, faça as coisas funcionarem primeiro - implementado como hierarchical_stride
permitir que algumas camadas dependam apenas da mudança de token, sem atenção
projeções aleatórias + vq, como foi feito no modelo de fala universal do cérebro - para codificação preditiva hierárquica
permitem especificar qual hierarquia recebe informações das outras durante a fusão, talvez projete uma atenção especializada com mascaramento, mas precisa levar em conta diferentes dimensões do modelo entre hierarquias
construir um bloco de atenção local simples, para uso em todas as hierarquias
adicionar atenção instantânea à biblioteca de atenção local
descobrir se a atenção pode ser compartilhada entre hierarquias
faça um relatório limpo do wandb mostrando compactação 2x sem muita perda para o nível do personagem enwik8
experimente um compressor baseado em autoatenção para hierarquias 4 ou superiores
construir um pequeno autoencoder usando os embeddings de token como entrada, bem no início da rede, e então usar mapas de recursos intermediários para cada rede hierárquica paralela
A ideia mais próxima seriam transformadores de ampulheta.
E meu interesse renovado em abordagens hierárquicas veio da leitura disto.
@article { Nawrot2021HierarchicalTA ,
title = { Hierarchical Transformers Are More Efficient Language Models } ,
author = { Piotr Nawrot and Szymon Tworkowski and Michal Tyrolski and Lukasz Kaiser and Yuhuai Wu and Christian Szegedy and Henryk Michalewski } ,
journal = { ArXiv } ,
year = { 2021 } ,
volume = { abs/2110.13711 }
}
@inproceedings { dao2022flashattention ,
title = { Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness } ,
author = { Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{'e}, Christopher } ,
booktitle = { Advances in Neural Information Processing Systems } ,
year = { 2022 }
}
@misc { su2021roformer ,
title = { RoFormer: Enhanced Transformer with Rotary Position Embedding } ,
author = { Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu } ,
year = { 2021 } ,
eprint = { 2104.09864 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CL }
}
@inproceedings { Sun2022ALT ,
title = { A Length-Extrapolatable Transformer } ,
author = { Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei } ,
year = { 2022 }
}
@software { peng_bo_2021_5196578 ,
author = { PENG Bo } ,
title = { BlinkDL/RWKV-LM: 0.01 } ,
month = { aug } ,
year = { 2021 } ,
publisher = { Zenodo } ,
version = { 0.01 } ,
doi = { 10.5281/zenodo.5196578 } ,
url = { https://doi.org/10.5281/zenodo.5196578 }
}
@article { Piergiovanni2023Mirasol3BAM ,
title = { Mirasol3B: A Multimodal Autoregressive model for time-aligned and contextual modalities } ,
author = { A. J. Piergiovanni and Isaac Noble and Dahun Kim and Michael S. Ryoo and Victor Gomes and Anelia Angelova } ,
journal = { ArXiv } ,
year = { 2023 } ,
volume = { abs/2311.05698 } ,
url = { https://api.semanticscholar.org/CorpusID:265129010 }
}