g-mlp-pytorch 0.1.5
Implementation of gMLP, an all-MLP replacement for Transformers, in PyTorch
$ pip install g-mlp-pytorch
For masked language modeling
import torch
from torch import nn
from g_mlp_pytorch import gMLP

model = gMLP(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 256,
    circulant_matrix = True,  # use circulant weight matrix for linear growth in parameters with respect to sequence length
    act = nn.Tanh()           # activation for spatial gate (defaults to identity)
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x)  # (1, 256, 20000)
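A note on the circulant_matrix flag: a dense spatial projection needs seq_len² weights, whereas a circulant matrix is generated from a single vector, each row being a cyclic shift of it. The sketch below only illustrates that parameter saving; it is not necessarily how the library builds the weight internally.

import torch

seq_len = 256

# dense spatial weight: one parameter per (i, j) position pair -> O(n^2)
dense = torch.randn(seq_len, seq_len)

# circulant weight: every row is a cyclic shift of one shared vector -> O(n)
vec = torch.randn(seq_len)
circulant = torch.stack([torch.roll(vec, shifts = i) for i in range(seq_len)])

print(dense.numel())  # 65536 parameters
print(vec.numel())    # 256 parameters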
For image classification
import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6
)

img = torch.randn(1, 3, 256, 256)
logits = model(img)  # (1, 1000)
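The logits can be fed straight into a standard classification loss. A minimal training-step sketch, continuing from the model above with a random label purely for illustration:

import torch.nn.functional as F

labels = torch.randint(0, 1000, (1,))   # dummy target class, illustration only
loss = F.cross_entropy(logits, labels)  # logits: (1, 1000)
loss.backward()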
You can also add a small amount of attention (single-headed) to boost performance, referred to in the paper as aMLP, by passing one extra keyword, attn_dim. This applies to both gMLPVision and gMLP
import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    attn_dim = 64
)

img = torch.randn(1, 3, 256, 256)
pred = model(img)  # (1, 1000)
Non-square images and patch sizes
import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = (256, 128),
    patch_size = (16, 8),
    num_classes = 1000,
    dim = 512,
    depth = 6,
    attn_dim = 64
)

img = torch.randn(1, 3, 256, 128)
pred = model(img)  # (1, 1000)
An independent researcher has proposed a multi-headed approach for gMLP in a blog post on Zhihu. To use it, just set heads to a value greater than 1
import torch
from g_mlp_pytorch import gMLP

model = gMLP(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 256,
    causal = True,
    circulant_matrix = True,
    heads = 4  # 4 heads
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x)  # (1, 256, 20000)
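Since causal = True makes the model autoregressive, the logits at each position predict the following token. A minimal next-token training-step sketch, reusing the model above with random data purely for illustration:

import torch
import torch.nn.functional as F

tokens = torch.randint(0, 20000, (1, 257))   # random token ids, illustration only
inp, target = tokens[:, :-1], tokens[:, 1:]  # shift inputs by one to form next-token targets

logits = model(inp)                          # (1, 256, 20000)
loss = F.cross_entropy(logits.transpose(1, 2), target)
loss.backward()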
Citations

@misc{liu2021pay,
    title         = {Pay Attention to MLPs},
    author        = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year          = {2021},
    eprint        = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}

@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {aug},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}