OpenAI 的 Sora、Stability AI 的 Stable Video Diffusion 以及许多其他已经出现或未来将出现的文本到视频模型都是继大型语言模型(LLM)之后的 2024 年最流行的人工智能趋势之一。在本博客中,我们将从头开始构建一个小型文本到视频模型。我们将输入一个文本提示,我们训练的模型将根据该提示生成一个视频。该博客将涵盖从理解理论概念到编码整个架构并生成最终结果的所有内容。
由于我没有高级的 GPU,因此我编写了小型架构。以下是在不同处理器上训练模型所需时间的比较:
培训视频 | 纪元 | 中央处理器 | 显卡A10 | 显卡T4 |
---|---|---|---|---|
10K | 30 | 超过3小时 | 1小时 | 1小时42m |
30K | 30 | 超过6小时 | 1小时30分 | 2小时30分 |
10万 | 30 | - | 3-4小时 | 5-6小时 |
在 CPU 上运行显然需要更长的时间来训练模型。如果您需要快速测试代码中的更改并查看结果,CPU 并不是最佳选择。我建议使用 Colab 或 Kaggle 的 T4 GPU 来实现更高效、更快的训练。
以下博客链接指导您如何从头开始创建稳定扩散:从头开始编码稳定扩散
我们将遵循与传统机器学习或深度学习模型类似的方法,在数据集上进行训练,然后在未见过的数据上进行测试。在文本到视频的背景下,假设我们有一个包含 10 万个狗捡球和猫追老鼠视频的训练数据集。我们将训练我们的模型来生成猫捡球或狗追老鼠的视频。
尽管此类训练数据集很容易在互联网上获得,但所需的计算能力非常高。因此,我们将使用由 Python 代码生成的移动对象的视频数据集。
我们将使用 GAN(生成对抗网络)架构来创建我们的模型,而不是 OpenAI Sora 使用的扩散模型。我尝试使用扩散模型,但由于内存要求而崩溃,这超出了我的能力。另一方面,GAN 的训练和测试更容易、更快捷。
我们将使用 OOP(面向对象编程),因此您必须对它和神经网络有基本的了解。 GAN(生成对抗网络)的知识不是强制性的,因为我们将在这里介绍它们的架构。
话题 | 关联 |
---|---|
面向对象编程 | 视频链接 |
神经网络理论 | 视频链接 |
生成式对抗网络架构 | 视频链接 |
Python 基础知识 | 视频链接 |
理解 GAN 架构很重要,因为我们的大部分架构都依赖于它。让我们来探讨一下它是什么、它的组件等等。
生成对抗网络 (GAN) 是一种深度学习模型,其中两个神经网络相互竞争:一个从给定的数据集中创建新数据(例如图像或音乐),另一个尝试判断数据是真实的还是虚假的。这个过程一直持续到生成的数据与原始数据无法区分为止。
生成图像:GAN 根据文本提示创建逼真的图像或修改现有图像,例如增强分辨率或为黑白照片添加颜色。
数据增强:它们生成合成数据来训练其他机器学习模型,例如为欺诈检测系统创建欺诈交易数据。
完整缺失的信息:GAN 可以填充缺失的数据,例如从地形图生成地下图像以用于能源应用。
生成 3D 模型:它们将 2D 图像转换为 3D 模型,这在医疗保健等领域非常有用,可以为手术规划创建逼真的器官图像。
它由两个深度神经网络组成:生成器和鉴别器。这些网络在对抗性设置中一起训练,其中一个网络生成新数据,另一个网络评估数据是真实的还是虚假的。
以下是 GAN 工作原理的简单概述:
训练集分析:生成器分析训练集以识别数据属性,而鉴别器独立分析相同的数据以学习其属性。
数据修改:生成器向数据的某些属性添加噪声(随机变化)。
数据传递:修改后的数据然后被传递到鉴别器。
概率计算:判别器计算生成的数据来自原始数据集的概率。
反馈循环:鉴别器向生成器提供反馈,指导生成器减少下一个周期的随机噪声。
对抗性训练:生成器试图最大化判别器的错误,而判别器则试图最小化自己的错误。通过多次训练迭代,两个网络都得到改进和发展。
平衡状态:训练继续,直到判别器无法再区分真实数据和合成数据,表明生成器已成功学会生成真实数据。至此,训练过程就完成了。
图片来自aws指南
让我们以图像到图像转换的示例来解释 GAN 模型,重点是修改人脸。
输入图像:输入是人脸的真实图像。
属性修改:生成器修改脸部的属性,例如为眼睛添加太阳镜。
生成的图像:生成器创建一组添加了太阳镜的图像。
鉴别器的任务:鉴别器接收真实图像(戴太阳镜的人)和生成图像(添加太阳镜的人脸)的混合。
评估:鉴别器尝试区分真实图像和生成图像。
反馈循环:如果鉴别器正确识别出假图像,则生成器会调整其参数以产生更令人信服的图像。如果生成器成功欺骗了鉴别器,鉴别器就会更新其参数以改进其检测。
通过这个对抗过程,两个网络都在不断改进。生成器在创建真实图像方面变得更好,鉴别器在识别赝品方面也变得更好,直到达到平衡,鉴别器不再能够区分真实图像和生成图像之间的差异。至此,GAN 已经成功学会了产生现实的修改。
安装所需的库是构建文本到视频模型的第一步。
pip install -r requirements.txt
我们将使用一系列 Python 库,让我们导入它们。
# Operating System module for interacting with the operating system
import os
# Module for generating random numbers
import random
# Module for numerical operations
import numpy as np
# OpenCV library for image processing
import cv2
# Python Imaging Library for image processing
from PIL import Image , ImageDraw , ImageFont
# PyTorch library for deep learning
import torch
# Dataset class for creating custom datasets in PyTorch
from torch . utils . data import Dataset
# Module for image transformations
import torchvision . transforms as transforms
# Neural network module in PyTorch
import torch . nn as nn
# Optimization algorithms in PyTorch
import torch . optim as optim
# Function for padding sequences in PyTorch
from torch . nn . utils . rnn import pad_sequence
# Function for saving images in PyTorch
from torchvision . utils import save_image
# Module for plotting graphs and images
import matplotlib . pyplot as plt
# Module for displaying rich content in IPython environments
from IPython . display import clear_output , display , HTML
# Module for encoding and decoding binary data to text
import base64
现在我们已经导入了所有库,下一步是定义我们将用于训练 GAN 架构的训练数据。
我们需要至少 10,000 个视频作为训练数据。为什么?嗯,因为我用较小的数字进行了测试,结果很差,几乎没有什么可看的。下一个大问题是:这些视频是关于什么的?我们的训练视频数据集由一个以不同运动向不同方向移动的圆圈组成。那么,让我们对其进行编码并生成 10,000 个视频来看看它是什么样子。
# Create a directory named 'training_dataset'
os . makedirs ( 'training_dataset' , exist_ok = True )
# Define the number of videos to generate for the dataset
num_videos = 10000
# Define the number of frames per video (1 Second Video)
frames_per_video = 10
# Define the size of each image in the dataset
img_size = ( 64 , 64 )
# Define the size of the shapes (Circle)
shape_size = 10
设置一些基本参数后,接下来我们需要定义训练数据集的文本提示,根据该文本提示将生成训练视频。
# Define text prompts and corresponding movements for circles
prompts_and_movements = [
( "circle moving down" , "circle" , "down" ), # Move circle downward
( "circle moving left" , "circle" , "left" ), # Move circle leftward
( "circle moving right" , "circle" , "right" ), # Move circle rightward
( "circle moving diagonally up-right" , "circle" , "diagonal_up_right" ), # Move circle diagonally up-right
( "circle moving diagonally down-left" , "circle" , "diagonal_down_left" ), # Move circle diagonally down-left
( "circle moving diagonally up-left" , "circle" , "diagonal_up_left" ), # Move circle diagonally up-left
( "circle moving diagonally down-right" , "circle" , "diagonal_down_right" ), # Move circle diagonally down-right
( "circle rotating clockwise" , "circle" , "rotate_clockwise" ), # Rotate circle clockwise
( "circle rotating counter-clockwise" , "circle" , "rotate_counter_clockwise" ), # Rotate circle counter-clockwise
( "circle shrinking" , "circle" , "shrink" ), # Shrink circle
( "circle expanding" , "circle" , "expand" ), # Expand circle
( "circle bouncing vertically" , "circle" , "bounce_vertical" ), # Bounce circle vertically
( "circle bouncing horizontally" , "circle" , "bounce_horizontal" ), # Bounce circle horizontally
( "circle zigzagging vertically" , "circle" , "zigzag_vertical" ), # Zigzag circle vertically
( "circle zigzagging horizontally" , "circle" , "zigzag_horizontal" ), # Zigzag circle horizontally
( "circle moving up-left" , "circle" , "up_left" ), # Move circle up-left
( "circle moving down-right" , "circle" , "down_right" ), # Move circle down-right
( "circle moving down-left" , "circle" , "down_left" ), # Move circle down-left
]
我们已经使用这些提示定义了圆的几种运动。现在,我们需要编写一些数学方程来根据提示移动该圆圈。
# defining function to create image with moving shape
def create_image_with_moving_shape ( size , frame_num , shape , direction ):
# Create a new RGB image with specified size and white background
img = Image . new ( 'RGB' , size , color = ( 255 , 255 , 255 ))
# Create a drawing context for the image
draw = ImageDraw . Draw ( img )
# Calculate the center coordinates of the image
center_x , center_y = size [ 0 ] // 2 , size [ 1 ] // 2
# Initialize position with center for all movements
position = ( center_x , center_y )
# Define a dictionary mapping directions to their respective position adjustments or image transformations
direction_map = {
# Adjust position downwards based on frame number
"down" : ( 0 , frame_num * 5 % size [ 1 ]),
# Adjust position to the left based on frame number
"left" : ( - frame_num * 5 % size [ 0 ], 0 ),
# Adjust position to the right based on frame number
"right" : ( frame_num * 5 % size [ 0 ], 0 ),
# Adjust position diagonally up and to the right
"diagonal_up_right" : ( frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position diagonally down and to the left
"diagonal_down_left" : ( - frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Adjust position diagonally up and to the left
"diagonal_up_left" : ( - frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position diagonally down and to the right
"diagonal_down_right" : ( frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Rotate the image clockwise based on frame number
"rotate_clockwise" : img . rotate ( frame_num * 10 % 360 , center = ( center_x , center_y ), fillcolor = ( 255 , 255 , 255 )),
# Rotate the image counter-clockwise based on frame number
"rotate_counter_clockwise" : img . rotate ( - frame_num * 10 % 360 , center = ( center_x , center_y ), fillcolor = ( 255 , 255 , 255 )),
# Adjust position for a bouncing effect vertically
"bounce_vertical" : ( 0 , center_y - abs ( frame_num * 5 % size [ 1 ] - center_y )),
# Adjust position for a bouncing effect horizontally
"bounce_horizontal" : ( center_x - abs ( frame_num * 5 % size [ 0 ] - center_x ), 0 ),
# Adjust position for a zigzag effect vertically
"zigzag_vertical" : ( 0 , center_y - frame_num * 5 % size [ 1 ]) if frame_num % 2 == 0 else ( 0 , center_y + frame_num * 5 % size [ 1 ]),
# Adjust position for a zigzag effect horizontally
"zigzag_horizontal" : ( center_x - frame_num * 5 % size [ 0 ], center_y ) if frame_num % 2 == 0 else ( center_x + frame_num * 5 % size [ 0 ], center_y ),
# Adjust position upwards and to the right based on frame number
"up_right" : ( frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position upwards and to the left based on frame number
"up_left" : ( - frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position downwards and to the right based on frame number
"down_right" : ( frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Adjust position downwards and to the left based on frame number
"down_left" : ( - frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ])
}
# Check if direction is in the direction map
if direction in direction_map :
# Check if the direction maps to a position adjustment
if isinstance ( direction_map [ direction ], tuple ):
# Update position based on the adjustment
position = tuple ( np . add ( position , direction_map [ direction ]))
else : # If the direction maps to an image transformation
# Update the image based on the transformation
img = direction_map [ direction ]
# Return the image as a numpy array
return np . array ( img )
上面的函数用于根据所选方向移动每一帧的圆。我们只需要在其上运行一个循环,直到达到视频数量的次数即可生成所有视频。
# Iterate over the number of videos to generate
for i in range ( num_videos ):
# Randomly choose a prompt and movement from the predefined list
prompt , shape , direction = random . choice ( prompts_and_movements )
# Create a directory for the current video
video_dir = f'training_dataset/video_ { i } '
os . makedirs ( video_dir , exist_ok = True )
# Write the chosen prompt to a text file in the video directory
with open ( f' { video_dir } /prompt.txt' , 'w' ) as f :
f . write ( prompt )
# Generate frames for the current video
for frame_num in range ( frames_per_video ):
# Create an image with a moving shape based on the current frame number, shape, and direction
img = create_image_with_moving_shape ( img_size , frame_num , shape , direction )
# Save the generated image as a PNG file in the video directory
cv2 . imwrite ( f' { video_dir } /frame_ { frame_num } .png' , img )
运行上述代码后,它将生成我们的整个训练数据集。这是我们的训练数据集文件的结构。
每个训练视频文件夹都包含其帧及其文本提示。让我们看一下训练数据集的样本。
在我们的训练数据集中,我们没有包含圆圈向上移动然后向右移动的运动。我们将使用它作为测试提示来评估我们在未见过的数据上训练的模型。
需要注意的更重要的一点是,我们的训练数据确实包含许多样本,其中物体远离场景或部分出现在相机前面,类似于我们在 OpenAI Sora 演示视频中观察到的情况。
在我们的训练数据中包含此类样本的原因是为了测试当圆圈从最角落进入场景而不破坏其形状时,我们的模型是否能够保持一致性。
现在我们的训练数据已经生成,我们需要将训练视频转换为张量,这是 PyTorch 等深度学习框架中使用的主要数据类型。此外,执行归一化等转换有助于通过将数据扩展到更小的范围来提高训练架构的收敛性和稳定性。
我们必须为文本到视频任务编写一个数据集类,它可以从训练数据集目录中读取视频帧及其相应的文本提示,使其可在 PyTorch 中使用。
# Define a dataset class inheriting from torch.utils.data.Dataset
class TextToVideoDataset ( Dataset ):
def __init__ ( self , root_dir , transform = None ):
# Initialize the dataset with root directory and optional transform
self . root_dir = root_dir
self . transform = transform
# List all subdirectories in the root directory
self . video_dirs = [ os . path . join ( root_dir , d ) for d in os . listdir ( root_dir ) if os . path . isdir ( os . path . join ( root_dir , d ))]
# Initialize lists to store frame paths and corresponding prompts
self . frame_paths = []
self . prompts = []
# Loop through each video directory
for video_dir in self . video_dirs :
# List all PNG files in the video directory and store their paths
frames = [ os . path . join ( video_dir , f ) for f in os . listdir ( video_dir ) if f . endswith ( '.png' )]
self . frame_paths . extend ( frames )
# Read the prompt text file in the video directory and store its content
with open ( os . path . join ( video_dir , 'prompt.txt' ), 'r' ) as f :
prompt = f . read (). strip ()
# Repeat the prompt for each frame in the video and store in prompts list
self . prompts . extend ([ prompt ] * len ( frames ))
# Return the total number of samples in the dataset
def __len__ ( self ):
return len ( self . frame_paths )
# Retrieve a sample from the dataset given an index
def __getitem__ ( self , idx ):
# Get the path of the frame corresponding to the given index
frame_path = self . frame_paths [ idx ]
# Open the image using PIL (Python Imaging Library)
image = Image . open ( frame_path )
# Get the prompt corresponding to the given index
prompt = self . prompts [ idx ]
# Apply transformation if specified
if self . transform :
image = self . transform ( image )
# Return the transformed image and the prompt
return image , prompt
在继续对架构进行编码之前,我们需要规范化我们的训练数据。我们将使用 16 的批量大小并对数据进行打乱以引入更多随机性。
# Define a set of transformations to be applied to the data
transform = transforms . Compose ([
transforms . ToTensor (), # Convert PIL Image or numpy.ndarray to tensor
transforms . Normalize (( 0.5 ,), ( 0.5 ,)) # Normalize image with mean and standard deviation
])
# Load the dataset using the defined transform
dataset = TextToVideoDataset ( root_dir = 'training_dataset' , transform = transform )
# Create a dataloader to iterate over the dataset
dataloader = torch . utils . data . DataLoader ( dataset , batch_size = 16 , shuffle = True )
您可能已经在 Transformer 架构中看到过,其中的起点是将我们的文本输入转换为嵌入,以便在多头注意力中进一步处理,与这里类似,我们必须编写一个文本嵌入层,基于该层,GAN 架构训练将在我们的嵌入数据上进行和图像张量。
# Define a class for text embedding
class TextEmbedding ( nn . Module ):
# Constructor method with vocab_size and embed_size parameters
def __init__ ( self , vocab_size , embed_size ):
# Call the superclass constructor
super ( TextEmbedding , self ). __init__ ()
# Initialize embedding layer
self . embedding = nn . Embedding ( vocab_size , embed_size )
# Define the forward pass method
def forward ( self , x ):
# Return embedded representation of input
return self . embedding ( x )
词汇量大小将基于我们的训练数据,稍后我们将计算这些数据。嵌入大小将为 10。如果使用更大的数据集,您还可以使用自己选择的 Hugging Face 上提供的嵌入模型。
现在我们已经知道生成器在 GAN 中的作用,让我们对这一层进行编码,然后了解其内容。
class Generator ( nn . Module ):
def __init__ ( self , text_embed_size ):
super ( Generator , self ). __init__ ()
# Fully connected layer that takes noise and text embedding as input
self . fc1 = nn . Linear ( 100 + text_embed_size , 256 * 8 * 8 )
# Transposed convolutional layers to upsample the input
self . deconv1 = nn . ConvTranspose2d ( 256 , 128 , 4 , 2 , 1 )
self . deconv2 = nn . ConvTranspose2d ( 128 , 64 , 4 , 2 , 1 )
self . deconv3 = nn . ConvTranspose2d ( 64 , 3 , 4 , 2 , 1 ) # Output has 3 channels for RGB images
# Activation functions
self . relu = nn . ReLU ( True ) # ReLU activation function
self . tanh = nn . Tanh () # Tanh activation function for final output
def forward ( self , noise , text_embed ):
# Concatenate noise and text embedding along the channel dimension
x = torch . cat (( noise , text_embed ), dim = 1 )
# Fully connected layer followed by reshaping to 4D tensor
x = self . fc1 ( x ). view ( - 1 , 256 , 8 , 8 )
# Upsampling through transposed convolution layers with ReLU activation
x = self . relu ( self . deconv1 ( x ))
x = self . relu ( self . deconv2 ( x ))
# Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)
x