g mlp pytorch
0.1.5
Pytorch에서 Transformers의 모든 MLP를 대체하는 gMLP 구현
$ pip install g-mlp-pytorch
마스크된 언어 모델링의 경우
import torch
from torch import nn
from g_mlp_pytorch import gMLP
model = gMLP (
num_tokens = 20000 ,
dim = 512 ,
depth = 6 ,
seq_len = 256 ,
circulant_matrix = True , # use circulant weight matrix for linear increase in parameters in respect to sequence length
act = nn . Tanh () # activation for spatial gate (defaults to identity)
)
x = torch . randint ( 0 , 20000 , ( 1 , 256 ))
logits = model ( x ) # (1, 256, 20000)
이미지 분류용
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision (
image_size = 256 ,
patch_size = 16 ,
num_classes = 1000 ,
dim = 512 ,
depth = 6
)
img = torch . randn ( 1 , 3 , 256 , 256 )
logits = model ( img ) # (1, 1000)
또한 논문에서 aMLP
로 언급한 것처럼 attn_dim
이라는 하나의 추가 키워드를 추가하여 성능을 향상시키기 위해 약간의 주의(단방향)를 추가할 수도 있습니다. 이는 gMLPVision
과 gMLP
모두에 적용됩니다.
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision (
image_size = 256 ,
patch_size = 16 ,
num_classes = 1000 ,
dim = 512 ,
depth = 6 ,
attn_dim = 64
)
img = torch . randn ( 1 , 3 , 256 , 256 )
pred = model ( img ) # (1, 1000)
정사각형이 아닌 이미지 및 패치 크기
import torch
from g_mlp_pytorch import gMLPVision
model = gMLPVision (
image_size = ( 256 , 128 ),
patch_size = ( 16 , 8 ),
num_classes = 1000 ,
dim = 512 ,
depth = 6 ,
attn_dim = 64
)
img = torch . randn ( 1 , 3 , 256 , 128 )
pred = model ( img ) # (1, 1000)
독립적인 연구원은 Zhihu의 블로그 게시물에서 gMLP에 대한 다중 접근 방식을 사용할 것을 제안합니다. 그렇게 하려면 heads
1
보다 크게 설정하세요.
import torch
from torch import nn
from g_mlp_pytorch import gMLP
model = gMLP (
num_tokens = 20000 ,
dim = 512 ,
depth = 6 ,
seq_len = 256 ,
causal = True ,
circulant_matrix = True ,
heads = 4 # 4 heads
)
x = torch . randint ( 0 , 20000 , ( 1 , 256 ))
logits = model ( x ) # (1, 256, 20000)
@misc { liu2021pay ,
title = { Pay Attention to MLPs } ,
author = { Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le } ,
year = { 2021 } ,
eprint = { 2105.08050 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.LG }
}
@software { peng_bo_2021_5196578 ,
author = { PENG Bo } ,
title = { BlinkDL/RWKV-LM: 0.01 } ,
month = aug,
year = 2021 ,
publisher = { Zenodo } ,
version = { 0.01 } ,
doi = { 10.5281/zenodo.5196578 } ,
url = { https://doi.org/10.5281/zenodo.5196578%7D
}