siren pytorch
0.1.7
การใช้งาน Pytorch ของ SIREN - การแสดงระบบประสาทโดยนัยพร้อมฟังก์ชันการเปิดใช้งานเป็นระยะ
$ pip install siren-pytorch
โครงข่ายประสาทเทียมหลายชั้นที่ใช้ SIREN
import torch
from torch import nn
from siren_pytorch import SirenNet
net = SirenNet (
dim_in = 2 , # input dimension, ex. 2d coor
dim_hidden = 256 , # hidden dimension
dim_out = 3 , # output dimension, ex. rgb value
num_layers = 5 , # number of layers
final_activation = nn . Sigmoid (), # activation of final layer (nn.Identity() for direct output)
w0_initial = 30. # different signals may require different omega_0 in the first layer - this is a hyperparameter
)
coor = torch . randn ( 1 , 2 )
net ( coor ) # (1, 3) <- rgb value
ไซเรนหนึ่งชั้น
import torch
from siren_pytorch import Siren
neuron = Siren (
dim_in = 3 ,
dim_out = 256
)
coor = torch . randn ( 1 , 3 )
neuron ( coor ) # (1, 256)
การเปิดใช้งานไซน์ (เพียงเสื้อคลุมรอบ torch.sin
)
import torch
from siren_pytorch import Sine
act = Sine ( 1. )
coor = torch . randn ( 1 , 2 )
act ( coor )
Wrapper เพื่อฝึกฝนรูปภาพเฉพาะที่มีความสูงและความกว้างที่ระบุจาก SirenNet
ที่กำหนด จากนั้นจึงสร้างในภายหลัง
import torch
from torch import nn
from siren_pytorch import SirenNet , SirenWrapper
net = SirenNet (
dim_in = 2 , # input dimension, ex. 2d coor
dim_hidden = 256 , # hidden dimension
dim_out = 3 , # output dimension, ex. rgb value
num_layers = 5 , # number of layers
w0_initial = 30. # different signals may require different omega_0 in the first layer - this is a hyperparameter
)
wrapper = SirenWrapper (
net ,
image_width = 256 ,
image_height = 256
)
img = torch . randn ( 1 , 3 , 256 , 256 )
loss = wrapper ( img )
loss . backward ()
# after much training ...
# simply invoke the wrapper without passing in anything
pred_img = wrapper () # (1, 3, 256, 256)
บทความใหม่เสนอว่าวิธีที่ดีที่สุดในการกำหนดเงื่อนไขไซเรนด้วยรหัสแฝงคือการส่งเวกเตอร์แฝงผ่านเครือข่ายฟีดฟอร์เวิร์ดของโมดูเลเตอร์ โดยที่สถานะที่ซ่อนอยู่ของแต่ละเลเยอร์จะถูกคูณตามองค์ประกอบด้วยเลเยอร์ที่สอดคล้องกันของไซเรน
คุณสามารถใช้สิ่งนี้ได้ง่ายๆ โดยการตั้งค่าคำหลักเพิ่มเติม latent_dim
บน SirenWrapper
import torch
from torch import nn
from siren_pytorch import SirenNet , SirenWrapper
net = SirenNet (
dim_in = 2 , # input dimension, ex. 2d coor
dim_hidden = 256 , # hidden dimension
dim_out = 3 , # output dimension, ex. rgb value
num_layers = 5 , # number of layers
w0_initial = 30. # different signals may require different omega_0 in the first layer - this is a hyperparameter
)
wrapper = SirenWrapper (
net ,
latent_dim = 512 ,
image_width = 256 ,
image_height = 256
)
latent = nn . Parameter ( torch . zeros ( 512 ). normal_ ( 0 , 1e-2 ))
img = torch . randn ( 1 , 3 , 256 , 256 )
loss = wrapper ( img , latent = latent )
loss . backward ()
# after much training ...
# simply invoke the wrapper without passing in anything
pred_img = wrapper ( latent = latent ) # (1, 3, 256, 256)
@misc { sitzmann2020implicit ,
title = { Implicit Neural Representations with Periodic Activation Functions } ,
author = { Vincent Sitzmann and Julien N. P. Martel and Alexander W. Bergman and David B. Lindell and Gordon Wetzstein } ,
year = { 2020 } ,
eprint = { 2006.09661 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}
@misc { mehta2021modulated ,
title = { Modulated Periodic Activations for Generalizable Local Functional Representations } ,
author = { Ishit Mehta and Michaël Gharbi and Connelly Barnes and Eli Shechtman and Ravi Ramamoorthi and Manmohan Chandraker } ,
year = { 2021 } ,
eprint = { 2104.03960 } ,
archivePrefix = { arXiv } ,
primaryClass = { cs.CV }
}