SIREN in PyTorch: an implementation of SIREN, Implicit Neural Representations with Periodic Activation Functions
$ pip install siren-pytorch
A SIREN-based multilayer neural network
import torch
from torch import nn
from siren_pytorch import SirenNet
net = SirenNet(
    dim_in = 2,                        # input dimension, e.g. 2D coordinates
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, e.g. RGB value
    num_layers = 5,                    # number of layers
    final_activation = nn.Sigmoid(),   # activation of final layer (nn.Identity() for direct output)
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer; this is a hyperparameter
)
coor = torch.randn(1, 2)
net(coor) # (1, 3) <- RGB value
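SIREN relies on a specific weight initialization to keep activations well distributed through the sine layers. The sketch below computes the half-width of the uniform distribution described in the SIREN paper, assuming the common formulation with c = 6: the first layer draws from U(-1/fan_in, 1/fan_in) and later layers from U(-sqrt(c/fan_in)/w0, sqrt(c/fan_in)/w0). The helper names and the hidden-layer w0 = 1.0 are assumptions for illustration, not part of the siren_pytorch API.

```python
import math
import random


def siren_init_bound(fan_in: int, w0: float, is_first: bool, c: float = 6.0) -> float:
    """Half-width of the uniform init for a SIREN layer (hypothetical helper).

    Assumption: first layer uses 1 / fan_in; later layers use sqrt(c / fan_in) / w0.
    """
    if is_first:
        return 1.0 / fan_in
    return math.sqrt(c / fan_in) / w0


def init_weights(fan_in: int, fan_out: int, w0: float, is_first: bool) -> list[list[float]]:
    """Draw a fan_out x fan_in weight matrix from U(-bound, bound)."""
    bound = siren_init_bound(fan_in, w0, is_first)
    return [[random.uniform(-bound, bound) for _ in range(fan_in)] for _ in range(fan_out)]


# Bounds for a network shaped like the SirenNet above (dim_in = 2, dim_hidden = 256).
first_bound = siren_init_bound(2, w0=30.0, is_first=True)     # 1 / 2 = 0.5
hidden_bound = siren_init_bound(256, w0=1.0, is_first=False)  # sqrt(6 / 256), about 0.153
```

The first-layer bound ignores w0 because, in this formulation, the large first-layer frequency is applied inside the sine rather than folded into the weights.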
A SIREN layer
import torch
from siren_pytorch import Siren
neuron = Siren(
    dim_in = 3,
    dim_out = 256
)
coor = torch.randn(1, 3)
neuron(coor) # (1, 256)
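Conceptually, a single SIREN layer is an affine map followed by a sine scaled by a frequency factor w0, i.e. y = sin(w0 * (W x + b)). The plain-Python sketch below spells out that computation for one unbatched input vector; the function name and the default w0 = 1.0 are assumptions for illustration, not the library's implementation, which works on batched tensors.

```python
import math


def siren_layer(x, weight, bias, w0=1.0):
    """Sketch of one SIREN layer: y_j = sin(w0 * (sum_i W[j][i] * x[i] + b[j]))."""
    return [
        math.sin(w0 * (sum(w_ji * x_i for w_ji, x_i in zip(row, x)) + b_j))
        for row, b_j in zip(weight, bias)
    ]


# Tiny example: 2 inputs -> 2 outputs with an identity weight matrix and zero bias.
y = siren_layer([0.0, math.pi / 2], weight=[[1.0, 0.0], [0.0, 1.0]], bias=[0.0, 0.0])
```

With identity weights the layer reduces to an elementwise sine, so `y` is approximately `[0.0, 1.0]`.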
Sine activation (just a wrapper around torch.sin)
import torch
from siren_pytorch import Sine
act = Sine(1.)
coor = torch.randn(1, 2)
act(coor)
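The argument passed to Sine (1. above) is presumably the frequency factor w0, so the activation computes sin(w0 * x) elementwise and has period 2π/w0. A larger w0, such as the w0_initial = 30. used earlier, therefore lets the first layer represent higher-frequency detail. The stdlib sketch below checks that periodicity; the `sine` function is a hypothetical stand-in, not the library class.

```python
import math


def sine(xs, w0=1.0):
    """Elementwise sin(w0 * x), the operation Sine(w0) is assumed to wrap."""
    return [math.sin(w0 * x) for x in xs]


w0 = 30.0
xs = [0.1, 0.25, -0.7]
period = 2 * math.pi / w0  # larger w0 -> shorter period -> higher frequency
original = sine(xs, w0)
shifted = sine([x + period for x in xs], w0)
```

Shifting every input by one period leaves the outputs unchanged up to floating-point error.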
A wrapper for training a given SirenNet on one specific image of a specified height and width, and then generating that image afterwards.
import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper
net = SirenNet(
    dim_in = 2,                        # input dimension, e.g. 2D coordinates
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, e.g. RGB value
    num_layers = 5,                    # number of layers
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer; this is a hyperparameter
)
wrapper = SirenWrapper(
    net,
    image_width = 256,
    image_height = 256
)
img = torch.randn(1, 3, 256, 256)
loss = wrapper(img)
loss.backward()
# after much training ...
# simply invoke the wrapper without passing in anything
pred_img = wrapper() # (1, 3, 256, 256)
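One way to picture what the wrapper does: build one input coordinate per pixel, run every coordinate through the SirenNet, reshape the outputs into an image, and compare that image with the target. The stdlib sketch below covers the first and last steps, assuming coordinates evenly spaced over [-1, 1] in row-major (height, width) order and a mean-squared-error loss. These are assumptions about SirenWrapper's internals, and `make_coordinate_grid` and `mse` are hypothetical names.

```python
def linspace(start, stop, steps):
    """Evenly spaced values from start to stop inclusive (endpoints exact)."""
    if steps == 1:
        return [start]
    return [start + (stop - start) * i / (steps - 1) for i in range(steps)]


def make_coordinate_grid(height, width):
    """Row-major list of (y, x) coordinates, one per pixel, each axis in [-1, 1]."""
    ys = linspace(-1.0, 1.0, height)
    xs = linspace(-1.0, 1.0, width)
    return [(y, x) for y in ys for x in xs]


def mse(pred, target):
    """Mean squared error between two flat lists of pixel values."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


grid = make_coordinate_grid(4, 3)  # 12 coordinates for a 4 x 3 image
```

For a 256 x 256 image the grid would hold 65,536 two-dimensional coordinates, which matches `dim_in = 2` in the SirenNet configuration.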
A new paper (Mehta et al., 2021) proposes that the best way to condition a SIREN on a latent code is to pass the latent vector through a modulator feedforward network, where each layer's hidden state is multiplied elementwise with the corresponding layer of the SIREN.
You can use this simply by setting an extra keyword argument, latent_dim, on the SirenWrapper.
import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper
net = SirenNet(
    dim_in = 2,                        # input dimension, e.g. 2D coordinates
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, e.g. RGB value
    num_layers = 5,                    # number of layers
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer; this is a hyperparameter
)
wrapper = SirenWrapper(
    net,
    latent_dim = 512,
    image_width = 256,
    image_height = 256
)
latent = nn.Parameter(torch.zeros(512).normal_(0, 1e-2))
img = torch.randn(1, 3, 256, 256)
loss = wrapper(img, latent = latent)
loss.backward()
# after much training ...
# invoke the wrapper without an image, passing only the latent code
pred_img = wrapper(latent = latent) # (1, 3, 256, 256)
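The modulation described above can be illustrated for a single layer: the SIREN computes its usual sine activations, a modulator layer maps the latent code to a vector of the same width, and the two are multiplied elementwise. The sketch below assumes a Linear + ReLU modulator layer; the function names are hypothetical. In a deeper network the modulator would also carry its own hidden state from layer to layer, which this one-layer sketch omits.

```python
import math


def affine(weight, bias, x):
    """Compute W x + b for a list-of-rows weight matrix."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(v):
    return [max(0.0, a) for a in v]


def modulated_siren_layer(x, latent, siren_w, siren_b, mod_w, mod_b, w0=1.0):
    """One SIREN layer whose sine output is scaled elementwise by a modulator state.

    Assumption: the modulator layer is Linear + ReLU applied to the latent code.
    """
    hidden = [math.sin(w0 * a) for a in affine(siren_w, siren_b, x)]
    modulation = relu(affine(mod_w, mod_b, latent))
    return [h * m for h, m in zip(hidden, modulation)]


out = modulated_siren_layer(
    x=[math.pi / 2, 0.0],
    latent=[2.0, -1.0],
    siren_w=[[1.0, 0.0], [0.0, 1.0]],
    siren_b=[0.0, math.pi / 2],
    mod_w=[[1.0, 0.0], [0.0, 1.0]],
    mod_b=[0.0, 0.0],
)
```

Both SIREN units output 1.0 here, but the ReLU zeroes the second modulation value, so `out` is `[2.0, 0.0]`: the latent code can amplify or switch off individual hidden units.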
Citations
@misc{sitzmann2020implicit,
    title         = {Implicit Neural Representations with Periodic Activation Functions},
    author        = {Vincent Sitzmann and Julien N. P. Martel and Alexander W. Bergman and David B. Lindell and Gordon Wetzstein},
    year          = {2020},
    eprint        = {2006.09661},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}
@misc{mehta2021modulated,
    title         = {Modulated Periodic Activations for Generalizable Local Functional Representations},
    author        = {Ishit Mehta and Michaël Gharbi and Connelly Barnes and Eli Shechtman and Ravi Ramamoorthi and Manmohan Chandraker},
    year          = {2021},
    eprint        = {2104.03960},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}