bit diffusion
0.1.4
تنفيذ Bit Diffusion، وهو محاولة مجموعة هينتون في مجال نماذج الانتشار المنفصلة لإزالة الضوضاء (discrete denoising diffusion)، باستخدام PyTorch
يبدو أنهم لم يحققوا النتائج المرجوة في حالة النصوص، لكن اتجاه البحث لا يزال يبدو واعدًا. أعتقد أن وجود مستودع نظيف سيعود بفائدة كبيرة على مجتمع البحث، ولا سيما على من سيبنون أعمالهم انطلاقًا من هنا.
$ pip install bit-diffusion
from bit_diffusion import Unet , Trainer , BitDiffusion
# Build the U-Net for 3-channel images and move it to the GPU.
model = Unet(dim=32, channels=3, dim_mults=(1, 2, 4, 8)).cuda()

# Wrap the U-Net in BitDiffusion and move the wrapper to the GPU as well.
bit_diffusion = BitDiffusion(
    model,
    image_size=128,
    timesteps=100,
    # The paper found that at a lower number of timesteps, a time difference
    # greater than 0 during sampling helps FID. As the number of timesteps
    # increases, it stops helping and can be set to 0.
    time_difference=0.1,
    # Use DDIM.
    use_ddim=True,
).cuda()

# Configure the training run.
trainer = Trainer(
    bit_diffusion,
    '/path/to/your/data',            # path to your folder of images
    results_folder='./results',      # where to save results
    num_samples=16,                  # number of samples
    train_batch_size=4,              # training batch size
    gradient_accumulate_every=4,     # gradient accumulation
    train_lr=1e-4,                   # learning rate
    save_and_sample_every=1000,      # how often to save and sample
    train_num_steps=700000,          # total training steps
    ema_decay=0.995,                 # exponential moving average decay
)

trainer.train()
سيتم حفظ النتائج بشكل دوري في مجلد ./results
إذا كنت ترغب في تجربة الصنفين Unet و BitDiffusion مباشرةً خارج Trainer:
import torch
from bit_diffusion import Unet , BitDiffusion
# Build the U-Net and wrap it in BitDiffusion directly, without a Trainer.
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

bit_diffusion = BitDiffusion(
    model,
    image_size = 128,
    timesteps = 1000
)

# Images are normalized from 0 to 1, so the placeholder batch is drawn with
# torch.rand (uniform on [0, 1)) rather than torch.randn, whose samples are
# unbounded and frequently negative.
training_images = torch.rand(8, 3, 128, 128)
loss = bit_diffusion(training_images)
loss.backward()

# after a lot of training
sampled_images = bit_diffusion.sample(batch_size = 4)
sampled_images.shape # (4, 3, 128, 128)
@article { Chen2022AnalogBG ,
title = { Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning } ,
author = { Ting Chen and Ruixiang Zhang and Geoffrey E. Hinton } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2208.04202 }
}