Experimentos en torno a una idea simple para inducir múltiples modelos de codificación predictiva jerárquica dentro de un GPT. Es tan simple que puede que no funcione. Pero, de nuevo, el progreso del aprendizaje profundo se construye sobre la base de ideas simples. Vale la pena intentarlo.
Hasta ahora, la idea ha pasado la prueba de fuego de un amigo investigador. Lo completará en la próxima semana más o menos. Si no funciona, dejaré los resultados experimentales negativos, así como el repositorio, y tal vez algún estudiante de doctorado pueda aprovecharlos.
Actualización: ¿Creo que está funcionando?
StabilityAI por el patrocinio para realizar esta investigación independiente
? Huggingface por su biblioteca acelerada
$ pip install simple-hierarchical-transformer
Tres jerarquías, todas ellas con servicios que predicen el siguiente token.
import torch
from simple_hierarchical_transformer import HierarchicalTransformer
model = HierarchicalTransformer (
num_tokens = 20000 , # number of tokens
dim = 512 , # model dimensions
depth = 6 , # depth
dim_head = 64 , # dimension per attention head
heads = 8 , # attention heads
seq_len = 2048 , # sequence lengths
hierarchies = ( 1 , 2 , 8 ), # hierarchies - here we have 1x (like in a regular transformer), then 2x and 8x compressed hierarchical tokens that undergo their own transformer blocks. information is pooled into one hierarchy at each layer
window_sizes = ( 32 , 64 , None ) # local attention window sizes - the idea is that the higher hierarchies can pass distant information to the local one. None stands for full receptive field. Setting 0 would turn off attention at this hierarchy altogether (while token shift will still be in effect in each layer)
)
ids = torch . randint ( 0 , 20000 , ( 1 , 2048 ))
loss , _ = model ( ids , return_loss = True )
loss . backward ()
# after much training
logits = model ( ids )
Al no especificar hierarchies
y window_sizes
, básicamente se utiliza de forma predeterminada un transformador autorregresivo normal con atención en toda la longitud de la secuencia.
# non-hierarchical transformer
model = HierarchicalTransformer (
num_tokens = 20000 ,
dim = 512 ,
depth = 8 ,
dim_head = 64 ,
heads = 8 ,
seq_len = 2048 ,
hierarchies = 1 , # implied 1 if not set
window_sizes = None # implied None (full sequence length) if not set
)
Ahora algo más complejo. Los experimentos muestran que a medida que se comprimen las jerarquías, se necesitan mayores dimensiones del modelo para obtener la capacidad adecuada.
model = HierarchicalTransformer (
num_tokens = 256 ,
dim = ( 128 , 256 , 512 , 1024 ),
depth = 8 ,
seq_len = 1024 ,
use_flash_attn = True ,
ff_mult = ( 2 , 2 , 4 , 4 ),
dim_head = ( 16 , 32 , 64 , 64 ),
heads = ( 2 , 4 , 8 , 8 ),
hierarchies = ( 1 , 2 , 4 , 16 ),
hierarchical_stride = ( 1 , 1 , 1 , 8 ), # this would determine the stride when compressing, and when concatting the hierarchical tokens to the fine tokens, the past tokens will be repeated this amount of time. causality is not violated as using the trick from hourglass transformers where sequence is shifted by compression factor - 1. recommend sticking with 1 except for highly compressed hierarchies, as it becomes very uncompetitive with baseline and generations look off
window_sizes = ( 16 , 32 , 64 , None )
). cuda ()
# hierarchies
# 1x - dim 128 - attention (2 heads, 16 dim, receptive field 16)
# 2x - dim 256 - attention (4 heads, 32 dim, receptive field 32)
# 4x - dim 512 - attention (8 heads, 64 dim, receptive field 64)
# 8x - dim 1024 - attention (8 heads, 64 dim, receptive field of all)
divida en dos caminos paralelos, uno para tokens jerárquicos y otro para tokens simples y finos.
mostrar que la atención local en tokens finos + jerárquicos puede acercarse a la línea base de atención total
dsconv simple parece suficiente para fusionarse en 1 jerarquía
establecer automáticamente el tamaño de la ventana para que sea la mitad de la longitud máxima de la secuencia para jerarquías finas y para todas las jerarquías
descubrir los efectos de simplemente agrupar todos los tokens finos + jerárquicos antes de la pérdida de entropía cruzada: no hay mucha diferencia
Capacidad completa para agregar cualquier número de jerarquías y designar qué jerarquía agrupará la información de las demás para realizar predicciones.
dimensiones totalmente personalizables en todas las jerarquías, ya que las jerarquías más altas requieren mayores dimensiones del modelo
agregar pérdidas de profetas para ramas jerárquicas
permita repetir tokens de jerarquía para tokens finos en el futuro, ya que la posición puede importar menos a medida que uno asciende en la jerarquía. pero no es una prioridad, primero haga que todo funcione - implementado como hierarchical_stride
permitir que algunas capas dependan únicamente del cambio de token, sin atención
proyecciones aleatorias + vq, como se hizo en el documento del modelo de habla universal del cerebro - para codificación predictiva jerárquica
permitir especificar qué jerarquía recibe información de las demás durante la fusión, tal vez diseñar una atención especializada con enmascaramiento, pero es necesario tener en cuenta diferentes dimensiones del modelo entre jerarquías
Cree un bloque de atención local simple, para usar en todas las jerarquías.
agregar atención flash a la biblioteca de atención local
descubrir si la atención se puede compartir entre jerarquías
haga un informe limpio de wandb que muestre compresión 2x sin mucha pérdida para el nivel de carácter enwik8
Pruebe un compresor basado en atención personal para jerarquías 4 o superiores.
construya un pequeño codificador automático utilizando las incrustaciones de tokens como entrada, al comienzo de la red, y luego use mapas de características intermedias para cada red jerárquica paralela
La idea más cercana serían los transformadores de reloj de arena.
Y mi renovado interés en los enfoques jerárquicos surgió al leer esto.
@article { Nawrot2021HierarchicalTA ,
title = { Hierarchical Transformers Are More Efficient Language Models } ,
author = { Piotr Nawrot and Szymon Tworkowski and Michal Tyrolski and Lukasz Kaiser and Yuhuai Wu and Christian Szegedy and Henryk Michalewski } ,
journal = { ArXiv } ,
year = { 2021 } ,
volume = { abs/2110.13711 }
}
@inproceedings { dao2022flashattention ,
title = { Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness } ,
author = { Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{'e}, Christopher } ,
booktitle = { Advances in Neural Information Processing Systems } ,
year = { 2022 }
}
@misc { su2021roformer ,
title = { RoFormer: Enhanced Transformer with Rotary Position Embedding } ,
author = { Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu } ,
year = { 2021 } ,
eprint = { 2104.09864 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CL }
}
@inproceedings { Sun2022ALT ,
title = { A Length-Extrapolatable Transformer } ,
author = { Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei } ,
year = { 2022 }
}
@software { peng_bo_2021_5196578 ,
author = { PENG Bo } ,
title = { BlinkDL/RWKV-LM: 0.01 } ,
month = { aug } ,
year = { 2021 } ,
publisher = { Zenodo } ,
version = { 0.01 } ,
doi = { 10.5281/zenodo.5196578 } ,
url = { https://doi.org/10.5281/zenodo.5196578 }
}
@article { Piergiovanni2023Mirasol3BAM ,
title = { Mirasol3B: A Multimodal Autoregressive model for time-aligned and contextual modalities } ,
author = { A. J. Piergiovanni and Isaac Noble and Dahun Kim and Michael S. Ryoo and Victor Gomes and Anelia Angelova } ,
journal = { ArXiv } ,
year = { 2023 } ,
volume = { abs/2311.05698 } ,
url = { https://api.semanticscholar.org/CorpusID:265129010 }
}