g mlp pytorch
0.1.5
Implementierung von gMLP, einem reinen MLP-Ersatz für Transformers, in Pytorch
$ pip install g-mlp-pytorch
Zur maskierten Sprachmodellierung
import torch
from torch import nn
from g_mlp_pytorch import gMLP
model = gMLP (
num_tokens = 20000 ,
dim = 512 ,
depth = 6 ,
seq_len = 256 ,
circulant_matrix = True , # use circulant weight matrix for linear increase in parameters in respect to sequence length
act = nn . Tanh () # activation for spatial gate (defaults to identity)
)
x = torch . randint ( 0 , 20000 , ( 1 , 256 ))
logits = model ( x ) # (1, 256, 20000)
Zur Bildklassifizierung
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision (
image_size = 256 ,
patch_size = 16 ,
num_classes = 1000 ,
dim = 512 ,
depth = 6
)
img = torch . randn ( 1 , 3 , 256 , 256 )
logits = model ( img ) # (1, 1000)
Sie können auch ein kleines Maß an Aufmerksamkeit (einköpfig) hinzufügen, um die Leistung zu steigern, wie im Artikel als aMLP
erwähnt, indem Sie ein zusätzliches Schlüsselwort attn_dim
hinzufügen. Dies gilt sowohl für gMLPVision
als auch für gMLP
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision (
image_size = 256 ,
patch_size = 16 ,
num_classes = 1000 ,
dim = 512 ,
depth = 6 ,
attn_dim = 64
)
img = torch . randn ( 1 , 3 , 256 , 256 )
pred = model ( img ) # (1, 1000)
Nicht quadratische Bilder und Patchgrößen
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision (
image_size = ( 256 , 128 ),
patch_size = ( 16 , 8 ),
num_classes = 1000 ,
dim = 512 ,
depth = 6 ,
attn_dim = 64
)
img = torch . randn ( 1 , 3 , 256 , 128 )
pred = model ( img ) # (1, 1000)
Ein unabhängiger Forscher schlägt in einem Blogbeitrag auf Zhihu die Verwendung eines mehrköpfigen Ansatzes für gMLPs vor. Stellen Sie dazu einfach heads
auf einen Wert größer als 1
ein
import torch
from torch import nn
from g_mlp_pytorch import gMLP
model = gMLP (
num_tokens = 20000 ,
dim = 512 ,
depth = 6 ,
seq_len = 256 ,
causal = True ,
circulant_matrix = True ,
heads = 4 # 4 heads
)
x = torch . randint ( 0 , 20000 , ( 1 , 256 ))
logits = model ( x ) # (1, 256, 20000)
@misc { liu2021pay ,
title = { Pay Attention to MLPs } ,
author = { Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le } ,
year = { 2021 } ,
eprint = { 2105.08050 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@software { peng_bo_2021_5196578 ,
author = { PENG Bo } ,
title = { BlinkDL/RWKV-LM: 0.01 } ,
month = aug,
year = 2021 ,
publisher = { Zenodo } ,
version = { 0.01 } ,
doi = { 10.5281/zenodo.5196578 } ,
url = { https://doi.org/10.5281/zenodo.5196578%7D
}