g mlp pytorch
0.1.5
Implementação de gMLP, um substituto totalmente MLP para Transformers, em Pytorch
$ pip install g-mlp-pytorch
Para modelagem de linguagem mascarada
import torch
from torch import nn
from g_mlp_pytorch import gMLP
model = gMLP (
num_tokens = 20000 ,
dim = 512 ,
depth = 6 ,
seq_len = 256 ,
circulant_matrix = True , # use circulant weight matrix for linear increase in parameters in respect to sequence length
act = nn . Tanh () # activation for spatial gate (defaults to identity)
)
x = torch . randint ( 0 , 20000 , ( 1 , 256 ))
logits = model ( x ) # (1, 256, 20000)
Para classificação de imagens
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision (
image_size = 256 ,
patch_size = 16 ,
num_classes = 1000 ,
dim = 512 ,
depth = 6
)
img = torch . randn ( 1 , 3 , 256 , 256 )
logits = model ( img ) # (1, 1000)
Você também pode adicionar um pouco de atenção (uma cabeça) para aumentar o desempenho, conforme mencionado no artigo como aMLP
, com a adição de uma palavra-chave extra attn_dim
. Isso se aplica tanto ao gMLPVision
quanto ao gMLP
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision (
image_size = 256 ,
patch_size = 16 ,
num_classes = 1000 ,
dim = 512 ,
depth = 6 ,
attn_dim = 64
)
img = torch . randn ( 1 , 3 , 256 , 256 )
pred = model ( img ) # (1, 1000)
Imagens não quadradas e tamanhos de patch
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision (
image_size = ( 256 , 128 ),
patch_size = ( 16 , 8 ),
num_classes = 1000 ,
dim = 512 ,
depth = 6 ,
attn_dim = 64
)
img = torch . randn ( 1 , 3 , 256 , 128 )
pred = model ( img ) # (1, 1000)
Um pesquisador independente propõe o uso de uma abordagem multifacetada para gMLPs em uma postagem de blog no Zhihu. Para fazer isso, basta definir heads
para serem maiores que 1
import torch
from torch import nn
from g_mlp_pytorch import gMLP
model = gMLP (
num_tokens = 20000 ,
dim = 512 ,
depth = 6 ,
seq_len = 256 ,
causal = True ,
circulant_matrix = True ,
heads = 4 # 4 heads
)
x = torch . randint ( 0 , 20000 , ( 1 , 256 ))
logits = model ( x ) # (1, 256, 20000)
@misc { liu2021pay ,
title = { Pay Attention to MLPs } ,
author = { Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le } ,
year = { 2021 } ,
eprint = { 2105.08050 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@software { peng_bo_2021_5196578 ,
author = { PENG Bo } ,
title = { BlinkDL/RWKV-LM: 0.01 } ,
month = aug,
year = 2021 ,
publisher = { Zenodo } ,
version = { 0.01 } ,
doi = { 10.5281/zenodo.5196578 } ,
url = { https://doi.org/10.5281/zenodo.5196578%7D
}