Classifier Free Guidance - Pytorch
Implementation of Classifier Free Guidance in Pytorch, with emphasis on text conditioning, and flexibility to include multiple text embedding models, as done in eDiff-I.

It is clear now that text guidance is the ultimate interface to models. This repository will leverage some python decorator magic to make it easy to incorporate SOTA text conditioning into any model.
StabilityAI for the generous sponsorship, as well as my other sponsors out there.

🤗 Hugging Face for their amazing transformers library. The text conditioning module will use T5 embeddings, as per the latest research recommendations.

OpenCLIP for providing SOTA open sourced CLIP models. The eDiff model saw immense improvements by combining the T5 embeddings with CLIP text embeddings.
$ pip install classifier-free-guidance-pytorch
import torch
from classifier_free_guidance_pytorch import TextConditioner

text_conditioner = TextConditioner(
    model_types = 't5',
    hidden_dims = (256, 512),
    hiddens_channel_first = False,
    cond_drop_prob = 0.2  # conditional dropout 20% of the time, must be greater than 0. to unlock classifier free guidance
).cuda()

# pass in your text as a List[str], and get back a List[callable]
# each callable function receives the hiddens in the dimensions listed at init (hidden_dims)

first_condition_fn, second_condition_fn = text_conditioner(['a dog chasing after a ball'])

# these hiddens will be in the direct flow of your model, say in a unet

first_hidden = torch.randn(1, 16, 256).cuda()
second_hidden = torch.randn(1, 32, 512).cuda()

# conditioned features

first_conditioned = first_condition_fn(first_hidden)
second_conditioned = second_condition_fn(second_hidden)
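For intuition, each returned conditioning function applies something in the spirit of FiLM (feature-wise linear modulation): a scale and shift predicted from the text embedding is applied to the hiddens. Below is a minimal illustrative sketch of the idea, not this repository's actual module; the class and dimension names are made up for the example.

```python
import torch
from torch import nn

class FiLM(nn.Module):
    # predicts a per-channel scale (gamma) and shift (beta) from a pooled text embedding
    def __init__(self, text_embed_dim, hidden_dim):
        super().__init__()
        self.to_gamma_beta = nn.Linear(text_embed_dim, hidden_dim * 2)

    def forward(self, hiddens, text_embed):
        gamma, beta = self.to_gamma_beta(text_embed).chunk(2, dim = -1)
        # broadcast over the sequence dimension: (batch, 1, dim)
        # (gamma + 1) so the module starts near identity
        return hiddens * (gamma.unsqueeze(1) + 1) + beta.unsqueeze(1)

film = FiLM(text_embed_dim = 768, hidden_dim = 256)

hiddens = torch.randn(1, 16, 256)     # hiddens somewhere in your network
text_embed = torch.randn(1, 768)      # stand-in for a pooled text embedding
out = film(hiddens, text_embed)       # same shape as hiddens
```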
For use of cross attention based conditioning (where each hidden feature in your network can attend to individual subword tokens), just import AttentionTextConditioner instead. The rest is the same.
from classifier_free_guidance_pytorch import AttentionTextConditioner

text_conditioner = AttentionTextConditioner(
    model_types = ('t5', 'clip'),  # something like in eDiff paper, where they used both T5 and Clip for even better results (Balaji et al.)
    hidden_dims = (256, 512),
    cond_drop_prob = 0.2
)
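Conceptually, attention conditioning lets every hidden position attend over the sequence of text token embeddings, instead of receiving one pooled text vector. A bare-bones single-head cross attention sketch of that idea (illustrative only, not the repository's implementation; names and dimensions are assumptions):

```python
import torch
from torch import nn

class CrossAttention(nn.Module):
    # hiddens (queries) attend to text token embeddings (keys / values); single head for clarity
    def __init__(self, dim, text_dim):
        super().__init__()
        self.scale = dim ** -0.5
        self.to_q = nn.Linear(dim, dim, bias = False)
        self.to_kv = nn.Linear(text_dim, dim * 2, bias = False)

    def forward(self, hiddens, text_tokens):
        q = self.to_q(hiddens)
        k, v = self.to_kv(text_tokens).chunk(2, dim = -1)

        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        return hiddens + out  # residual, so conditioning refines rather than replaces

attend = CrossAttention(dim = 256, text_dim = 768)

hiddens = torch.randn(1, 16, 256)
text_tokens = torch.randn(1, 8, 768)  # stand-in for 8 subword token embeddings
out = attend(hiddens, text_tokens)
```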
This is a work in progress to make it as easy as possible to text condition your network.

First, let's say you have a simple two layer network
import torch
from torch import nn

class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj_in = nn.Sequential(nn.Linear(dim, dim * 2), nn.ReLU())
        self.proj_mid = nn.Sequential(nn.Linear(dim * 2, dim), nn.ReLU())
        self.proj_out = nn.Linear(dim, 1)

    def forward(self, data):
        hiddens1 = self.proj_in(data)
        hiddens2 = self.proj_mid(hiddens1)
        return self.proj_out(hiddens2)

# instantiate model and pass in some data, get (in this case) a binary prediction

model = MLP(dim = 256)

data = torch.randn(2, 256)

pred = model(data)
You would like to condition the hidden layers (hiddens1 and hiddens2) with text. Each batch element here would get its own free text conditioning.

This has been reduced to ~3 steps using this repository.
import torch
from torch import nn

from classifier_free_guidance_pytorch import classifier_free_guidance_class_decorator

@classifier_free_guidance_class_decorator
class MLP(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.proj_in = nn.Sequential(nn.Linear(dim, dim * 2), nn.ReLU())
        self.proj_mid = nn.Sequential(nn.Linear(dim * 2, dim), nn.ReLU())
        self.proj_out = nn.Linear(dim, 1)

    def forward(
        self,
        inp,
        cond_fns  # List[Callable] - (1) your forward function now receives a list of conditioning functions, which you invoke on your hidden tensors
    ):
        cond_hidden1, cond_hidden2 = cond_fns  # conditioning functions are given back in the order of the `hidden_dims` set on the text conditioner

        hiddens1 = self.proj_in(inp)
        hiddens1 = cond_hidden1(hiddens1)  # (2) condition the first hidden layer with FiLM

        hiddens2 = self.proj_mid(hiddens1)
        hiddens2 = cond_hidden2(hiddens2)  # condition the second hidden layer with FiLM

        return self.proj_out(hiddens2)

# instantiate your model - extra keyword arguments will need to be defined, prepended by `text_condition_`

model = MLP(
    dim = 256,
    text_condition_type = 'film',                 # can be film, attention, or null (none)
    text_condition_model_types = ('t5', 'clip'),  # in this example, conditioning on both T5 and OpenCLIP
    text_condition_hidden_dims = (512, 256),      # and pass in the hidden dimensions you would like to condition on. in this case there are two hidden dimensions (dim * 2 and dim, after the first and second projections)
    text_condition_cond_drop_prob = 0.25          # conditional dropout probability for classifier free guidance. can be set to 0. if you do not need it and just want the text conditioning
)

# now you have your input data as well as corresponding free text as List[str]

data = torch.randn(2, 256)
texts = ['a description', 'another description']

# (3) train your model, passing in your list of strings as 'texts'

pred = model(data, texts = texts)

# after much training, you can now do classifier free guidance by passing in a condition scale of > 1. !

model.eval()
guided_pred = model(data, texts = texts, cond_scale = 3., remove_parallel_component = True)  # cond_scale stands for conditioning scale from classifier free guidance paper
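Under the hood, classifier free guidance runs the network twice at inference, once with the text condition and once with it dropped, then extrapolates from the unconditional output toward the conditional one. The core arithmetic from the classifier free guidance paper (Ho, 2022), sketched with stand-in tensors in place of real model outputs:

```python
import torch

cond_scale = 3.

# stand-in outputs for a conditional and an unconditional (text-dropped) forward pass
cond_out = torch.ones(2, 1)
null_out = torch.zeros(2, 1)

# extrapolate from the unconditional output past the conditional one when cond_scale > 1
# cond_scale = 1 recovers the plain conditional output
guided = null_out + (cond_out - null_out) * cond_scale  # → all entries equal 3.
```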
complete film conditioning, without classifier free guidance (used here)
add classifier free guidance for film conditioning
complete full cross attention conditioning
stress test for spacetime unet in make-a-video
@article { Ho2022ClassifierFreeDG ,
title = { Classifier-Free Diffusion Guidance } ,
author = { Jonathan Ho } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2207.12598 }
}
@article { Balaji2022eDiffITD ,
title = { eDiff-I: Text-to-Image Diffusion Models with an Ensemble of Expert Denoisers } ,
author = { Yogesh Balaji and Seungjun Nah and Xun Huang and Arash Vahdat and Jiaming Song and Karsten Kreis and Miika Aittala and Timo Aila and Samuli Laine and Bryan Catanzaro and Tero Karras and Ming-Yu Liu } ,
journal = { ArXiv } ,
year = { 2022 } ,
volume = { abs/2211.01324 }
}
@inproceedings { dao2022flashattention ,
title = { Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness } ,
author = { Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher } ,
booktitle = { Advances in Neural Information Processing Systems } ,
year = { 2022 }
}
@inproceedings { Lin2023CommonDN ,
title = { Common Diffusion Noise Schedules and Sample Steps are Flawed } ,
author = { Shanchuan Lin and Bingchen Liu and Jiashi Li and Xiao Yang } ,
year = { 2023 }
}
@inproceedings { Chung2024CFGMC ,
title = { CFG++: Manifold-constrained Classifier Free Guidance for Diffusion Models } ,
author = { Hyungjin Chung and Jeongsol Kim and Geon Yeong Park and Hyelin Nam and Jong Chul Ye } ,
year = { 2024 } ,
url = { https://api.semanticscholar.org/CorpusID:270391454 }
}
@inproceedings { Sadat2024EliminatingOA ,
title = { Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models } ,
author = { Seyedmorteza Sadat and Otmar Hilliges and Romann M. Weber } ,
year = { 2024 } ,
url = { https://api.semanticscholar.org/CorpusID:273098845 }
}