Pytorch-Implementierung von SIREN – Implizite neuronale Darstellungen mit periodischer Aktivierungsfunktion
$ pip install siren-pytorch
Ein SIREN-basiertes mehrschichtiges neuronales Netzwerk
import torch
from torch import nn
from siren_pytorch import SirenNet
net = SirenNet (
dim_in = 2 , # input dimension, ex. 2d coor
dim_hidden = 256 , # hidden dimension
dim_out = 3 , # output dimension, ex. rgb value
num_layers = 5 , # number of layers
final_activation = nn . Sigmoid (), # activation of final layer (nn.Identity() for direct output)
w0_initial = 30. # different signals may require different omega_0 in the first layer - this is a hyperparameter
)
coor = torch . randn ( 1 , 2 )
net ( coor ) # (1, 3) <- rgb value
Eine SIREN-Schicht
import torch
from siren_pytorch import Siren
neuron = Siren (
dim_in = 3 ,
dim_out = 256
)
coor = torch . randn ( 1 , 3 )
neuron ( coor ) # (1, 256)
Sinus-Aktivierung (nur ein Wrapper um torch.sin
)
import torch
from siren_pytorch import Sine
act = Sine ( 1. )
coor = torch . randn ( 1 , 2 )
act ( coor )
Wrapper zum Trainieren eines bestimmten Bilds mit angegebener Höhe und Breite aus einem bestimmten SirenNet
und zum anschließenden Generieren.
import torch
from torch import nn
from siren_pytorch import SirenNet , SirenWrapper
net = SirenNet (
dim_in = 2 , # input dimension, ex. 2d coor
dim_hidden = 256 , # hidden dimension
dim_out = 3 , # output dimension, ex. rgb value
num_layers = 5 , # number of layers
w0_initial = 30. # different signals may require different omega_0 in the first layer - this is a hyperparameter
)
wrapper = SirenWrapper (
net ,
image_width = 256 ,
image_height = 256
)
img = torch . randn ( 1 , 3 , 256 , 256 )
loss = wrapper ( img )
loss . backward ()
# after much training ...
# simply invoke the wrapper without passing in anything
pred_img = wrapper () # (1, 3, 256, 256)
In einem neuen Artikel wird vorgeschlagen, dass der beste Weg, eine Sirene mit einem latenten Code zu konditionieren, darin besteht, den latenten Vektor durch ein Modulator-Feedforward-Netzwerk zu leiten, in dem der verborgene Zustand jeder Schicht elementweise mit der entsprechenden Schicht der Sirene multipliziert wird.
Sie können dies einfach nutzen, indem Sie im SirenWrapper
ein zusätzliches Schlüsselwort latent_dim
festlegen
import torch
from torch import nn
from siren_pytorch import SirenNet , SirenWrapper
net = SirenNet (
dim_in = 2 , # input dimension, ex. 2d coor
dim_hidden = 256 , # hidden dimension
dim_out = 3 , # output dimension, ex. rgb value
num_layers = 5 , # number of layers
w0_initial = 30. # different signals may require different omega_0 in the first layer - this is a hyperparameter
)
wrapper = SirenWrapper (
net ,
latent_dim = 512 ,
image_width = 256 ,
image_height = 256
)
latent = nn . Parameter ( torch . zeros ( 512 ). normal_ ( 0 , 1e-2 ))
img = torch . randn ( 1 , 3 , 256 , 256 )
loss = wrapper ( img , latent = latent )
loss . backward ()
# after much training ...
# simply invoke the wrapper without passing in anything
pred_img = wrapper ( latent = latent ) # (1, 3, 256, 256)
@misc { sitzmann2020implicit ,
title = { Implicit Neural Representations with Periodic Activation Functions } ,
author = { Vincent Sitzmann and Julien N. P. Martel and Alexander W. Bergman and David B. Lindell and Gordon Wetzstein } ,
year = { 2020 } ,
eprint = { 2006.09661 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}
@misc { mehta2021modulated ,
title = { Modulated Periodic Activations for Generalizable Local Functional Representations } ,
author = { Ishit Mehta and Michaël Gharbi and Connelly Barnes and Eli Shechtman and Ravi Ramamoorthi and Manmohan Chandraker } ,
year = { 2021 } ,
eprint = { 2104.03960 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}