Expérimentations autour d'une idée simple pour induire plusieurs modèles de codage prédictif hiérarchique au sein d'un GPT. C'est si simple que ça ne marchera peut-être pas. Mais là encore, les progrès du deep learning reposent sur des idées simples. Ça vaut le coup.
Jusqu’à présent, l’idée a passé le test décisif d’un ami chercheur. Le terminera dans la semaine prochaine. Si cela ne fonctionne pas, je laisserai les résultats expérimentaux négatifs ainsi que le référentiel, et peut-être qu'un doctorant pourra s'en servir.
Mise à jour : je pense que ça marche ?
StabilityAI pour le parrainage pour réaliser cette recherche indépendante
? Huggingface pour leur bibliothèque d'accélération
$ pip install simple-hierarchical-transformer
Trois hiérarchies, toutes desservant prédisant le prochain jeton
import torch
from simple_hierarchical_transformer import HierarchicalTransformer
model = HierarchicalTransformer (
num_tokens = 20000 , # number of tokens
dim = 512 , # model dimensions
depth = 6 , # depth
dim_head = 64 , # dimension per attention head
heads = 8 , # attention heads
seq_len = 2048 , # sequence lengths
hierarchies = ( 1 , 2 , 8 ), # hierarchies - here we have 1x (like in a regular transformer), then 2x and 8x compressed hierarchical tokens that undergo their own transformer blocks. information is pooled into one hierarchy at each layer
window_sizes = ( 32 , 64 , None ) # local attention window sizes - the idea is that the higher hierarchies can pass distant information to the local one. None stands for full receptive field. Setting 0 would turn off attention at this hierarchy altogether (while token shift will still be in effect in each layer)
)
ids = torch . randint ( 0 , 20000 , ( 1 , 2048 ))
loss , _ = model ( ids , return_loss = True )
loss . backward ()
# after much training
logits = model ( ids )
En ne spécifiant pas hierarchies
et window_sizes
, vous utilisez par défaut un transformateur autorégressif classique avec une attention sur toute la longueur de la séquence.
# non-hierarchical transformer
model = HierarchicalTransformer (
num_tokens = 20000 ,
dim = 512 ,
depth = 8 ,
dim_head = 64 ,
heads = 8 ,
seq_len = 2048 ,
hierarchies = 1 , # implied 1 if not set
window_sizes = None # implied None (full sequence length) if not set
)
Maintenant quelque chose de plus complexe. Les expériences montrent que lorsque vous compressez les hiérarchies, vous avez besoin de dimensions de modèle plus grandes pour une capacité appropriée.
model = HierarchicalTransformer (
num_tokens = 256 ,
dim = ( 128 , 256 , 512 , 1024 ),
depth = 8 ,
seq_len = 1024 ,
use_flash_attn = True ,
ff_mult = ( 2 , 2 , 4 , 4 ),
dim_head = ( 16 , 32 , 64 , 64 ),
heads = ( 2 , 4 , 8 , 8 ),
hierarchies = ( 1 , 2 , 4 , 16 ),
hierarchical_stride = ( 1 , 1 , 1 , 8 ), # this would determine the stride when compressing, and when concatting the hierarchical tokens to the fine tokens, the past tokens will be repeated this amount of time. causality is not violated as using the trick from hourglass transformers where sequence is shifted by compression factor - 1. recommend sticking with 1 except for highly compressed hierarchies, as it becomes very uncompetitive with baseline and generations look off
window_sizes = ( 16 , 32 , 64 , None )
). cuda ()
# hierarchies
# 1x - dim 128 - attention (2 heads, 16 dim, receptive field 16)
# 2x - dim 256 - attention (4 heads, 32 dim, receptive field 32)
# 4x - dim 512 - attention (8 heads, 64 dim, receptive field 64)
# 8x - dim 1024 - attention (8 heads, 64 dim, receptive field of all)
bifurquez vers deux chemins parallèles, un pour les jetons hiérarchiques, l'autre pour les jetons simples et fins.
montrer que l'attention locale in fine + les jetons hiérarchiques peuvent se rapprocher de la ligne de base de l'attention totale
un simple dsconv semble suffisant pour fusionner pour 1 hiérarchie
définir automatiquement la taille de la fenêtre pour qu'elle soit la moitié de la longueur maximale de la séquence pour les hiérarchies fines et toutes
comprendre les effets de la simple mise en commun de tous les jetons fins + hiérarchiques avant la perte d'entropie croisée - pas beaucoup de différence
capacité totale à ajouter un nombre illimité de hiérarchies et à désigner quelle hiérarchie regroupera les informations des autres à des fins de prédiction
dimensions entièrement personnalisables dans les hiérarchies, car les hiérarchies supérieures nécessitent des dimensions de modèle plus grandes
ajouter des pertes de prophète pour les branches hiérarchiques
autoriser la répétition des jetons de hiérarchie pour les jetons fins à l'avenir, car la position peut avoir moins d'importance à mesure que l'on monte dans la hiérarchie. mais ce n'est pas une priorité, faites fonctionner les choses en premier - implémenté en tant que hierarchical_stride
permettre à certaines couches de s'appuyer uniquement sur le changement de jeton, sans y prêter attention
projections aléatoires + vq, comme cela a été fait dans l'article sur le modèle universel de parole du cerveau - pour le codage prédictif hiérarchique
permettre de spécifier quelle hiérarchie reçoit des informations des autres lors de la fusion, peut-être concevoir une attention spécialisée avec masquage, mais doit tenir compte des différentes dimensions du modèle à travers les hiérarchies
créer un bloc d'attention local simple, à utiliser dans toutes les hiérarchies
ajouter une attention flash à la bibliothèque d'attention locale
déterminer si l'attention peut être partagée entre les hiérarchies
faire un rapport wandb propre montrant une compression 2x sans trop de perte pour le niveau des caractères frwik8
essayez un compresseur basé sur l'auto-attention pour les hiérarchies 4 ou supérieures
construire un petit auto-encodeur en utilisant les intégrations de jetons comme entrée, au tout début du réseau, puis utiliser des cartes de fonctionnalités intermédiaires pour chaque réseau hiérarchique parallèle
L'idée la plus proche serait les transformateurs en sablier.
Et mon intérêt renouvelé pour les approches hiérarchiques est venu de cette lecture.
@article { Nawrot2021HierarchicalTA ,
title = { Hierarchical Transformers Are More Efficient Language Models } ,
author = { Piotr Nawrot and Szymon Tworkowski and Michal Tyrolski and Lukasz Kaiser and Yuhuai Wu and Christian Szegedy and Henryk Michalewski } ,
journal = { ArXiv } ,
year = { 2021 } ,
volume = { abs/2110.13711 }
}
@inproceedings { dao2022flashattention ,
title = { Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness } ,
author = { Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{'e}, Christopher } ,
booktitle = { Advances in Neural Information Processing Systems } ,
year = { 2022 }
}
@misc { su2021roformer ,
title = { RoFormer: Enhanced Transformer with Rotary Position Embedding } ,
author = { Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu } ,
year = { 2021 } ,
eprint = { 2104.09864 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CL }
}
@inproceedings { Sun2022ALT ,
title = { A Length-Extrapolatable Transformer } ,
author = { Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei } ,
year = { 2022 }
}
@software { peng_bo_2021_5196578 ,
author = { PENG Bo } ,
title = { BlinkDL/RWKV-LM: 0.01 } ,
month = { aug } ,
year = { 2021 } ,
publisher = { Zenodo } ,
version = { 0.01 } ,
doi = { 10.5281/zenodo.5196578 } ,
url = { https://doi.org/10.5281/zenodo.5196578 }
}
@article { Piergiovanni2023Mirasol3BAM ,
title = { Mirasol3B: A Multimodal Autoregressive model for time-aligned and contextual modalities } ,
author = { A. J. Piergiovanni and Isaac Noble and Dahun Kim and Michael S. Ryoo and Victor Gomes and Anelia Angelova } ,
journal = { ArXiv } ,
year = { 2023 } ,
volume = { abs/2311.05698 } ,
url = { https://api.semanticscholar.org/CorpusID:265129010 }
}