siren pytorch
0.1.7
Pytorch implementation of SIREN - Implicit Neural Representations with Periodic Activation Functions
$ pip install siren-pytorch
A SIREN-based multi-layered neural network
import torch
from torch import nn
from siren_pytorch import SirenNet
net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    final_activation = nn.Sigmoid(),   # activation of final layer (nn.Identity() for direct output)
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

coor = torch.randn(1, 2)
net(coor) # (1, 3) <- rgb value
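To render a full image with the network, evaluate it on a dense grid of 2d coordinates. A minimal sketch, continuing from the net above - the grid construction and the [-1, 1] normalization here are illustrative choices, not part of the library's API:

import torch
from siren_pytorch import SirenNet

net = SirenNet(dim_in = 2, dim_hidden = 256, dim_out = 3, num_layers = 5, w0_initial = 30.)

# build a (256 * 256, 2) grid of 2d coordinates, normalized to [-1, 1]
axis = torch.linspace(-1., 1., 256)
grid = torch.stack(torch.meshgrid(axis, axis, indexing = 'ij'), dim = -1)
coords = grid.reshape(-1, 2)

rgb = net(coords)              # (256 * 256, 3)
img = rgb.reshape(256, 256, 3) # back into image layout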
One SIREN layer
import torch
from siren_pytorch import Siren
neuron = Siren(
    dim_in = 3,
    dim_out = 256
)

coor = torch.randn(1, 3)
neuron(coor) # (1, 256)
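Individual Siren layers can also be composed by hand into a custom network. A minimal sketch - note that this naive stack applies a sine activation on the final layer too, which you may not want for e.g. raw rgb output:

import torch
from torch import nn
from siren_pytorch import Siren

net = nn.Sequential(
    Siren(dim_in = 2, dim_out = 256),
    Siren(dim_in = 256, dim_out = 256),
    Siren(dim_in = 256, dim_out = 3)
)

coor = torch.randn(1, 2)
net(coor) # (1, 3)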
Sine activation (just a wrapper around torch.sin)
import torch
from siren_pytorch import Sine
act = Sine(1.)
coor = torch.randn(1, 2)
act(coor)
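As a sanity check that Sine(w0) really is just torch.sin with a frequency scale - assuming the sin(w0 * x) convention from the SIREN paper - the two should agree exactly:

import torch
from siren_pytorch import Sine

act = Sine(30.)
x = torch.randn(1, 2)
assert torch.allclose(act(x), torch.sin(30. * x)) # Sine(w0) computes sin(w0 * x)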
Wrapper to train on a specific image of specified height and width from a given SirenNet, and then to subsequently generate.
import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper
net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

wrapper = SirenWrapper(
    net,
    image_width = 256,
    image_height = 256
)

img = torch.randn(1, 3, 256, 256)
loss = wrapper(img)
loss.backward()

# after much training ...
# simply invoke the wrapper without passing in anything
pred_img = wrapper() # (1, 3, 256, 256)
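The "after much training" step above is an ordinary optimization loop: the wrapper returns a reconstruction loss against the target image, so any optimizer will drive it. A minimal sketch - the optimizer, learning rate, and step count here are illustrative choices:

import torch
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(dim_in = 2, dim_hidden = 256, dim_out = 3, num_layers = 5, w0_initial = 30.)
wrapper = SirenWrapper(net, image_width = 256, image_height = 256)

img = torch.randn(1, 3, 256, 256) # stand-in for your target image
opt = torch.optim.Adam(wrapper.parameters(), lr = 1e-4)

for _ in range(1000):
    opt.zero_grad()
    loss = wrapper(img)
    loss.backward()
    opt.step()

pred_img = wrapper() # (1, 3, 256, 256)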
A new paper proposes that the best way to condition a Siren with a latent code is to pass the latent vector through a modulator feed-forward network, where each layer's hidden state is elementwise multiplied with the corresponding layer of the Siren.
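In rough pseudocode, the idea looks like the following. This is a conceptual sketch of the mechanism described above, not the library's actual internals; the modulator architecture and the ReLU are illustrative choices:

import torch
from torch import nn

latent_dim = 512
dim_hidden = 256
num_layers = 4

# hypothetical modulator: one linear layer per siren layer
modulator = nn.ModuleList([nn.Linear(latent_dim if i == 0 else dim_hidden, dim_hidden) for i in range(num_layers)])

def modulated_forward(siren_layers, latent, x):
    h = latent
    for siren_layer, mod_layer in zip(siren_layers, modulator):
        h = torch.relu(mod_layer(h))  # modulator hidden state
        x = siren_layer(x) * h        # elementwise multiply with the siren layer's hidden state
    return x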
You can use this simply by setting one extra keyword, latent_dim, on the SirenWrapper.
import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper
net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

wrapper = SirenWrapper(
    net,
    latent_dim = 512,
    image_width = 256,
    image_height = 256
)

latent = nn.Parameter(torch.zeros(512).normal_(0, 1e-2))
img = torch.randn(1, 3, 256, 256)

loss = wrapper(img, latent = latent)
loss.backward()

# after much training ...
# simply invoke the wrapper without passing in anything
pred_img = wrapper(latent = latent) # (1, 3, 256, 256)
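Since the latent is just an nn.Parameter, it can be optimized jointly with the network, one latent per image, auto-decoder style. A minimal sketch - the two random stand-in images, optimizer, and hyperparameters are illustrative:

import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(dim_in = 2, dim_hidden = 256, dim_out = 3, num_layers = 5, w0_initial = 30.)
wrapper = SirenWrapper(net, latent_dim = 512, image_width = 256, image_height = 256)

imgs = [torch.randn(1, 3, 256, 256) for _ in range(2)] # stand-ins for target images
latents = nn.ParameterList([nn.Parameter(torch.zeros(512).normal_(0, 1e-2)) for _ in range(2)])

opt = torch.optim.Adam([*wrapper.parameters(), *latents], lr = 1e-4)

for _ in range(1000):
    for img, latent in zip(imgs, latents):
        opt.zero_grad()
        loss = wrapper(img, latent = latent)
        loss.backward()
        opt.step()

pred_img = wrapper(latent = latents[0]) # (1, 3, 256, 256), reconstruction of the first image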
Citations

@misc{sitzmann2020implicit,
    title         = {Implicit Neural Representations with Periodic Activation Functions},
    author        = {Vincent Sitzmann and Julien N. P. Martel and Alexander W. Bergman and David B. Lindell and Gordon Wetzstein},
    year          = {2020},
    eprint        = {2006.09661},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}

@misc{mehta2021modulated,
    title         = {Modulated Periodic Activations for Generalizable Local Functional Representations},
    author        = {Ishit Mehta and Michaël Gharbi and Connelly Barnes and Eli Shechtman and Ravi Ramamoorthi and Manmohan Chandraker},
    year          = {2021},
    eprint        = {2104.03960},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}