siren-pytorch
0.1.7
A PyTorch implementation of SIREN - Implicit Neural Representations with Periodic Activation Functions
$ pip install siren-pytorch
A SIREN-based multi-layered neural network
import torch
from torch import nn
from siren_pytorch import SirenNet

net = SirenNet(
    dim_in = 2,                       # input dimension, ex. 2d coor
    dim_hidden = 256,                 # hidden dimension
    dim_out = 3,                      # output dimension, ex. rgb value
    num_layers = 5,                   # number of layers
    final_activation = nn.Sigmoid(),  # activation of final layer (nn.Identity() for direct output)
    w0_initial = 30.                  # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

coor = torch.randn(1, 2)
net(coor) # (1, 3) <- rgb value
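Since the network maps individual coordinates to values, rendering the learned signal amounts to querying it on a dense coordinate grid. Below is a minimal sketch of doing so, assuming the usual [-1, 1] coordinate normalization; the grid construction is illustrative and not part of the library API.

import torch
from siren_pytorch import SirenNet

net = SirenNet(dim_in = 2, dim_hidden = 256, dim_out = 3, num_layers = 5)

# build a dense (h*w, 2) grid of 2d coordinates in [-1, 1]
h = w = 64
ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing = 'ij')
coords = torch.stack((xs, ys), dim = -1).reshape(-1, 2)

rgb = net(coords)           # (h*w, 3) - one prediction per coordinate
img = rgb.reshape(h, w, 3)  # fold the predictions back into an image grid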
A single SIREN layer
import torch
from siren_pytorch import Siren

neuron = Siren(
    dim_in = 3,
    dim_out = 256
)

coor = torch.randn(1, 3)
neuron(coor) # (1, 256)
Sine activation (a wrapper around torch.sin)
import torch
from siren_pytorch import Sine

act = Sine(1.)
coor = torch.randn(1, 2)
act(coor)
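Sine(w0) computes sin(w0 * x), so composing it with a linear layer gives the core SIREN computation sin(w0 * (Wx + b)). A minimal sketch follows; note that the library's Siren layer additionally applies the paper's w0-aware weight initialization, which this sketch omits.

import torch
from torch import nn
from siren_pytorch import Sine

# sin(w0 * (Wx + b)) - the core of a SIREN layer, without the special init
w0 = 30.
linear = nn.Linear(2, 256)
act = Sine(w0)

out = act(linear(torch.randn(1, 2))) # (1, 256)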
A wrapper that trains a given SirenNet on a specific image of specified height and width, and then generates that image afterwards.
import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(
    dim_in = 2,        # input dimension, ex. 2d coor
    dim_hidden = 256,  # hidden dimension
    dim_out = 3,       # output dimension, ex. rgb value
    num_layers = 5,    # number of layers
    w0_initial = 30.   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

wrapper = SirenWrapper(
    net,
    image_width = 256,
    image_height = 256
)

img = torch.randn(1, 3, 256, 256)
loss = wrapper(img)
loss.backward()

# after much training ...
# simply invoke the wrapper without passing in anything
pred_img = wrapper() # (1, 3, 256, 256)
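Calling the wrapper with an image returns a reconstruction loss against that image, so "much training" is just an ordinary optimization loop. A minimal sketch, where the optimizer choice, learning rate, and step count are illustrative assumptions, and a random tensor stands in for a real normalized image:

import torch
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(dim_in = 2, dim_hidden = 256, dim_out = 3, num_layers = 5, w0_initial = 30.)
wrapper = SirenWrapper(net, image_width = 256, image_height = 256)

img = torch.randn(1, 3, 256, 256) # stand-in for a real image tensor
opt = torch.optim.Adam(wrapper.parameters(), lr = 1e-4)

for _ in range(500):
    opt.zero_grad()
    loss = wrapper(img) # reconstruction loss against the target image
    loss.backward()
    opt.step()

pred_img = wrapper() # (1, 3, 256, 256)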
A new paper proposes that the best way to condition a SIREN with a latent code is to pass the latent vector through a modulator feedforward network, where each layer's hidden state is multiplied elementwise with the corresponding layer of the SIREN. You can use this simply by setting one extra keyword, latent_dim, on the SirenWrapper.
import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(
    dim_in = 2,        # input dimension, ex. 2d coor
    dim_hidden = 256,  # hidden dimension
    dim_out = 3,       # output dimension, ex. rgb value
    num_layers = 5,    # number of layers
    w0_initial = 30.   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

wrapper = SirenWrapper(
    net,
    latent_dim = 512,
    image_width = 256,
    image_height = 256
)

latent = nn.Parameter(torch.zeros(512).normal_(0, 1e-2))
img = torch.randn(1, 3, 256, 256)

loss = wrapper(img, latent = latent)
loss.backward()

# after much training ...
# simply invoke the wrapper without passing in anything
pred_img = wrapper(latent = latent) # (1, 3, 256, 256)
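Since the latent code is an nn.Parameter created outside the wrapper, it must be optimized jointly with the network's weights. Continuing the example above, a sketch with assumed hyperparameters:

# optimize the network weights and the latent code together
opt = torch.optim.Adam([*wrapper.parameters(), latent], lr = 1e-4)

for _ in range(500):
    opt.zero_grad()
    loss = wrapper(img, latent = latent)
    loss.backward()
    opt.step()

pred_img = wrapper(latent = latent) # (1, 3, 256, 256)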
Citations

@misc{sitzmann2020implicit,
    title         = {Implicit Neural Representations with Periodic Activation Functions},
    author        = {Vincent Sitzmann and Julien N. P. Martel and Alexander W. Bergman and David B. Lindell and Gordon Wetzstein},
    year          = {2020},
    eprint        = {2006.09661},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}

@misc{mehta2021modulated,
    title         = {Modulated Periodic Activations for Generalizable Local Functional Representations},
    author        = {Ishit Mehta and Michaël Gharbi and Connelly Barnes and Eli Shechtman and Ravi Ramamoorthi and Manmohan Chandraker},
    year          = {2021},
    eprint        = {2104.03960},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}