g mlp pytorch
0.1.5
Implementasi gMLP, pengganti Transformers yang semuanya MLP, di Pytorch
$ pip install g-mlp-pytorch
Untuk pemodelan bahasa bertopeng
import torch
from torch import nn
from g_mlp_pytorch import gMLP
model = gMLP (
num_tokens = 20000 ,
dim = 512 ,
depth = 6 ,
seq_len = 256 ,
circulant_matrix = True , # use circulant weight matrix for linear increase in parameters in respect to sequence length
act = nn . Tanh () # activation for spatial gate (defaults to identity)
)
x = torch . randint ( 0 , 20000 , ( 1 , 256 ))
logits = model ( x ) # (1, 256, 20000)
Untuk klasifikasi gambar
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision (
image_size = 256 ,
patch_size = 16 ,
num_classes = 1000 ,
dim = 512 ,
depth = 6
)
img = torch . randn ( 1 , 3 , 256 , 256 )
logits = model ( img ) # (1, 1000)
Anda juga dapat menambahkan sedikit perhatian (berkepala satu) untuk meningkatkan kinerja, seperti yang disebutkan dalam makalah sebagai aMLP
, dengan tambahan satu kata kunci tambahan attn_dim
. Hal ini berlaku untuk gMLPVision
dan gMLP
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision (
image_size = 256 ,
patch_size = 16 ,
num_classes = 1000 ,
dim = 512 ,
depth = 6 ,
attn_dim = 64
)
img = torch . randn ( 1 , 3 , 256 , 256 )
pred = model ( img ) # (1, 1000)
Gambar non-persegi dan ukuran tambalan
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision (
image_size = ( 256 , 128 ),
patch_size = ( 16 , 8 ),
num_classes = 1000 ,
dim = 512 ,
depth = 6 ,
attn_dim = 64
)
img = torch . randn ( 1 , 3 , 256 , 128 )
pred = model ( img ) # (1, 1000)
Seorang peneliti independen mengusulkan penggunaan pendekatan multi-kepala untuk gMLP dalam postingan blog di Zhihu. Untuk melakukannya, cukup atur heads
menjadi lebih besar dari 1
import torch
from torch import nn
from g_mlp_pytorch import gMLP
model = gMLP (
num_tokens = 20000 ,
dim = 512 ,
depth = 6 ,
seq_len = 256 ,
causal = True ,
circulant_matrix = True ,
heads = 4 # 4 heads
)
x = torch . randint ( 0 , 20000 , ( 1 , 256 ))
logits = model ( x ) # (1, 256, 20000)
@misc { liu2021pay ,
title = { Pay Attention to MLPs } ,
author = { Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le } ,
year = { 2021 } ,
eprint = { 2105.08050 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@software { peng_bo_2021_5196578 ,
author = { PENG Bo } ,
title = { BlinkDL/RWKV-LM: 0.01 } ,
month = aug,
year = 2021 ,
publisher = { Zenodo } ,
version = { 0.01 } ,
doi = { 10.5281/zenodo.5196578 } ,
url = { https://doi.org/10.5281/zenodo.5196578%7D
}