OpenAI 的 Sora、Stability AI 的 Stable Video Diffusion 以及許多其他已經出現或未來將出現的文本到視頻模型都是繼大型語言模型(LLM)之後的 2024 年最流行的人工智能趨勢之一。在本部落格中,我們將從頭開始建立一個小型文字到視訊模型。我們將輸入一個文字提示,我們訓練的模型將根據該提示產生一個影片。這篇部落格將涵蓋從理解理論概念到編碼整個架構並產生最終結果的所有內容。
由於我沒有高級的 GPU,因此我編寫了小型架構。以下是在不同處理器上訓練模型所需時間的比較:
培訓影片 | 紀元 | 中央處理器 | 顯示卡A10 | 顯示卡T4 |
---|---|---|---|---|
10K | 30 | 超過3小時 | 1小時 | 1小時42m |
30K | 30 | 超過6小時 | 1小時30分 | 2小時30分 |
10萬 | 30 | - | 3-4小時 | 5-6小時 |
在 CPU 上運行顯然需要更長的時間來訓練模型。如果您需要快速測試程式碼中的變更並查看結果,CPU 並不是最佳選擇。我建議使用 Colab 或 Kaggle 的 T4 GPU 來實現更有效率、更快的訓練。
以下部落格連結指導您如何從頭開始創建穩定擴展:從頭開始編碼穩定擴散
我們將遵循與傳統機器學習或深度學習模型類似的方法,在資料集上進行訓練,然後在未見過的資料上進行測試。在文字到影片的背景下,假設我們有一個包含 10 萬個狗撿球和貓追老鼠影片的訓練資料集。我們將訓練我們的模型來產生貓撿球或狗追老鼠的影片。
儘管此類訓練資料集很容易在網路上取得,但所需的運算能力非常高。因此,我們將使用由 Python 程式碼產生的移動物件的視訊資料集。
我們將使用 GAN(生成對抗網路)架構來創建我們的模型,而不是 OpenAI Sora 使用的擴散模型。我嘗試使用擴散模型,但由於記憶體要求而崩潰,這超出了我的能力。另一方面,GAN 的訓練和測試更容易、更快捷。
我們將使用 OOP(物件導向程式設計),因此您必須對它和神經網路有基本的了解。 GAN(生成對抗網路)的知識不是強制性的,因為我們將在這裡介紹它們的架構。
話題 | 關聯 |
---|---|
物件導向程式設計 | 影片連結 |
神經網路理論 | 影片連結 |
生成式對抗網路架構 | 影片連結 |
Python 基礎知識 | 影片連結 |
理解 GAN 架構很重要,因為我們的大部分架構都依賴它。讓我們來探討一下它是什麼、它的組件等等。
生成對抗網路 (GAN) 是一種深度學習模型,其中兩個神經網路相互競爭:一個從給定的資料集中創建新資料(如圖像或音樂),另一個嘗試判斷資料是真是假。這個過程一直持續到產生的資料與原始資料無法區分為止。
生成圖像:GAN 根據文字提示創建逼真的圖像或修改現有圖像,例如增強解析度或為黑白照片添加顏色。
資料增強:它們產生合成資料來訓練其他機器學習模型,例如為詐欺偵測系統建立詐欺交易資料。
完整缺失的資訊:GAN 可以填充缺失的數據,例如從地形圖產生地下影像以用於能源應用。
產生 3D 模型:它們將 2D 影像轉換為 3D 模型,這在醫療保健等領域非常有用,可以為手術規劃創建逼真的器官影像。
它由兩個深度神經網路組成:生成器和鑑別器。這些網路在對抗性設定中一起訓練,其中一個網路產生新數據,另一個網路評估數據是真實的還是虛假的。
以下是 GAN 工作原理的簡單概述:
訓練集分析:生成器分析訓練集以識別資料屬性,而判別器獨立分析相同的資料以學習其屬性。
資料修改:生成器為資料的某些屬性添加雜訊(隨機變化)。
資料傳遞:修改後的資料然後傳遞到鑑別器。
機率計算:判別器計算產生的資料來自原始資料集的機率。
回饋循環:鑑別器向生成器提供回饋,指導生成器減少下一個週期的隨機雜訊。
對抗性訓練:生成器試圖最大化判別器的錯誤,而判別器則試圖最小化自己的錯誤。透過多次訓練迭代,兩個網路都得到改進和發展。
平衡狀態:訓練持續,直到判別器無法再區分真實數據和合成數據,顯示生成器已成功學會產生真實數據。至此,訓練過程就完成了。
圖片來自aws指南
讓我們以圖像到圖像轉換的範例來解釋 GAN 模型,重點是修改人臉。
輸入影像:輸入是人臉的真實影像。
屬性修改:生成器修改臉部的屬性,例如為眼睛添加太陽眼鏡。
產生的圖像:生成器創建一組添加了太陽眼鏡的圖像。
鑑別器的任務:鑑別器接收真實影像(戴太陽眼鏡的人)和生成影像(加上太陽眼鏡的人臉)的混合。
評估:鑑別器嘗試區分真實影像和生成影像。
反饋循環:如果鑑別器正確識別假圖像,則生成器會調整其參數以產生更令人信服的圖像。如果生成器成功欺騙了鑑別器,鑑別器就會更新其參數以改善其偵測。
透過這個對抗過程,兩個網路都在不斷改進。生成器在創建真實圖像方面變得更好,鑑別器在識別贗品方面也變得更好,直到達到平衡,鑑別器不再能夠區分真實圖像和生成圖像之間的差異。至此,GAN 已經成功學會了產生現實的修改。
安裝所需的庫是建立文字到視訊模型的第一步。
pip install -r requirements.txt
我們將使用一系列 Python 庫,讓我們導入它們。
# Operating System module for interacting with the operating system
import os
# Module for generating random numbers
import random
# Module for numerical operations
import numpy as np
# OpenCV library for image processing
import cv2
# Python Imaging Library for image processing
from PIL import Image , ImageDraw , ImageFont
# PyTorch library for deep learning
import torch
# Dataset class for creating custom datasets in PyTorch
from torch . utils . data import Dataset
# Module for image transformations
import torchvision . transforms as transforms
# Neural network module in PyTorch
import torch . nn as nn
# Optimization algorithms in PyTorch
import torch . optim as optim
# Function for padding sequences in PyTorch
from torch . nn . utils . rnn import pad_sequence
# Function for saving images in PyTorch
from torchvision . utils import save_image
# Module for plotting graphs and images
import matplotlib . pyplot as plt
# Module for displaying rich content in IPython environments
from IPython . display import clear_output , display , HTML
# Module for encoding and decoding binary data to text
import base64
現在我們已經匯入了所有庫,下一步是定義我們將用於訓練 GAN 架構的訓練資料。
我們需要至少 10,000 個影片作為訓練資料。為什麼?嗯,因為我用較小的數字進行了測試,結果很差,幾乎沒有什麼可看的。下一個大問題是:這些影片是關於什麼的?我們的訓練影片資料集由一個以不同運動向不同方向移動的圓圈組成。那麼,讓我們對其進行編碼並產生 10,000 個影片來看看它是什麼樣子。
# Create a directory named 'training_dataset'
os . makedirs ( 'training_dataset' , exist_ok = True )
# Define the number of videos to generate for the dataset
num_videos = 10000
# Define the number of frames per video (1 Second Video)
frames_per_video = 10
# Define the size of each image in the dataset
img_size = ( 64 , 64 )
# Define the size of the shapes (Circle)
shape_size = 10
設定一些基本參數後,接下來我們需要定義訓練資料集的文字提示,根據該文字提示將產生訓練影片。
# Define text prompts and corresponding movements for circles
prompts_and_movements = [
( "circle moving down" , "circle" , "down" ), # Move circle downward
( "circle moving left" , "circle" , "left" ), # Move circle leftward
( "circle moving right" , "circle" , "right" ), # Move circle rightward
( "circle moving diagonally up-right" , "circle" , "diagonal_up_right" ), # Move circle diagonally up-right
( "circle moving diagonally down-left" , "circle" , "diagonal_down_left" ), # Move circle diagonally down-left
( "circle moving diagonally up-left" , "circle" , "diagonal_up_left" ), # Move circle diagonally up-left
( "circle moving diagonally down-right" , "circle" , "diagonal_down_right" ), # Move circle diagonally down-right
( "circle rotating clockwise" , "circle" , "rotate_clockwise" ), # Rotate circle clockwise
( "circle rotating counter-clockwise" , "circle" , "rotate_counter_clockwise" ), # Rotate circle counter-clockwise
( "circle shrinking" , "circle" , "shrink" ), # Shrink circle
( "circle expanding" , "circle" , "expand" ), # Expand circle
( "circle bouncing vertically" , "circle" , "bounce_vertical" ), # Bounce circle vertically
( "circle bouncing horizontally" , "circle" , "bounce_horizontal" ), # Bounce circle horizontally
( "circle zigzagging vertically" , "circle" , "zigzag_vertical" ), # Zigzag circle vertically
( "circle zigzagging horizontally" , "circle" , "zigzag_horizontal" ), # Zigzag circle horizontally
( "circle moving up-left" , "circle" , "up_left" ), # Move circle up-left
( "circle moving down-right" , "circle" , "down_right" ), # Move circle down-right
( "circle moving down-left" , "circle" , "down_left" ), # Move circle down-left
]
我們已經使用這些提示定義了圓的幾種運動。現在,我們需要編寫一些數學方程式來根據提示移動該圓圈。
# defining function to create image with moving shape
def create_image_with_moving_shape ( size , frame_num , shape , direction ):
# Create a new RGB image with specified size and white background
img = Image . new ( 'RGB' , size , color = ( 255 , 255 , 255 ))
# Create a drawing context for the image
draw = ImageDraw . Draw ( img )
# Calculate the center coordinates of the image
center_x , center_y = size [ 0 ] // 2 , size [ 1 ] // 2
# Initialize position with center for all movements
position = ( center_x , center_y )
# Define a dictionary mapping directions to their respective position adjustments or image transformations
direction_map = {
# Adjust position downwards based on frame number
"down" : ( 0 , frame_num * 5 % size [ 1 ]),
# Adjust position to the left based on frame number
"left" : ( - frame_num * 5 % size [ 0 ], 0 ),
# Adjust position to the right based on frame number
"right" : ( frame_num * 5 % size [ 0 ], 0 ),
# Adjust position diagonally up and to the right
"diagonal_up_right" : ( frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position diagonally down and to the left
"diagonal_down_left" : ( - frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Adjust position diagonally up and to the left
"diagonal_up_left" : ( - frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position diagonally down and to the right
"diagonal_down_right" : ( frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Rotate the image clockwise based on frame number
"rotate_clockwise" : img . rotate ( frame_num * 10 % 360 , center = ( center_x , center_y ), fillcolor = ( 255 , 255 , 255 )),
# Rotate the image counter-clockwise based on frame number
"rotate_counter_clockwise" : img . rotate ( - frame_num * 10 % 360 , center = ( center_x , center_y ), fillcolor = ( 255 , 255 , 255 )),
# Adjust position for a bouncing effect vertically
"bounce_vertical" : ( 0 , center_y - abs ( frame_num * 5 % size [ 1 ] - center_y )),
# Adjust position for a bouncing effect horizontally
"bounce_horizontal" : ( center_x - abs ( frame_num * 5 % size [ 0 ] - center_x ), 0 ),
# Adjust position for a zigzag effect vertically
"zigzag_vertical" : ( 0 , center_y - frame_num * 5 % size [ 1 ]) if frame_num % 2 == 0 else ( 0 , center_y + frame_num * 5 % size [ 1 ]),
# Adjust position for a zigzag effect horizontally
"zigzag_horizontal" : ( center_x - frame_num * 5 % size [ 0 ], center_y ) if frame_num % 2 == 0 else ( center_x + frame_num * 5 % size [ 0 ], center_y ),
# Adjust position upwards and to the right based on frame number
"up_right" : ( frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position upwards and to the left based on frame number
"up_left" : ( - frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position downwards and to the right based on frame number
"down_right" : ( frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Adjust position downwards and to the left based on frame number
"down_left" : ( - frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ])
}
# Check if direction is in the direction map
if direction in direction_map :
# Check if the direction maps to a position adjustment
if isinstance ( direction_map [ direction ], tuple ):
# Update position based on the adjustment
position = tuple ( np . add ( position , direction_map [ direction ]))
else : # If the direction maps to an image transformation
# Update the image based on the transformation
img = direction_map [ direction ]
# Return the image as a numpy array
return np . array ( img )
上面的函數用於根據所選方向移動每一幀的圓。我們只需要在其上運行一個循環,直到達到影片數量的次數即可產生所有影片。
# Iterate over the number of videos to generate
for i in range ( num_videos ):
# Randomly choose a prompt and movement from the predefined list
prompt , shape , direction = random . choice ( prompts_and_movements )
# Create a directory for the current video
video_dir = f'training_dataset/video_ { i } '
os . makedirs ( video_dir , exist_ok = True )
# Write the chosen prompt to a text file in the video directory
with open ( f' { video_dir } /prompt.txt' , 'w' ) as f :
f . write ( prompt )
# Generate frames for the current video
for frame_num in range ( frames_per_video ):
# Create an image with a moving shape based on the current frame number, shape, and direction
img = create_image_with_moving_shape ( img_size , frame_num , shape , direction )
# Save the generated image as a PNG file in the video directory
cv2 . imwrite ( f' { video_dir } /frame_ { frame_num } .png' , img )
運行上述程式碼後,它將產生我們的整個訓練資料集。這是我們的訓練資料集檔案的結構。
每個訓練影片資料夾都包含其幀及其文字提示。讓我們來看看訓練資料集的樣本。
在我們的訓練資料集中,我們沒有包含圓圈向上移動然後向右移動的運動。我們將使用它作為測試提示來評估我們在未見過的資料上訓練的模型。
需要注意的更重要的一點是,我們的訓練資料確實包含許多樣本,其中物體遠離場景或部分出現在相機前面,類似於我們在 OpenAI Sora 演示影片中觀察到的情況。
在我們的訓練資料中包含此類樣本的原因是為了測試當圓圈從最角落進入場景而不破壞其形狀時,我們的模型是否能夠保持一致性。
現在我們的訓練資料已經生成,我們需要將訓練影片轉換為張量,這是 PyTorch 等深度學習框架中使用的主要資料類型。此外,執行歸一化等轉換有助於透過將資料擴展到較小的範圍來提高訓練架構的收斂性和穩定性。
我們必須為文字到視訊任務編寫一個資料集類,它可以從訓練資料集目錄中讀取視訊幀及其相應的文字提示,使其可在 PyTorch 中使用。
# Define a dataset class inheriting from torch.utils.data.Dataset
class TextToVideoDataset ( Dataset ):
def __init__ ( self , root_dir , transform = None ):
# Initialize the dataset with root directory and optional transform
self . root_dir = root_dir
self . transform = transform
# List all subdirectories in the root directory
self . video_dirs = [ os . path . join ( root_dir , d ) for d in os . listdir ( root_dir ) if os . path . isdir ( os . path . join ( root_dir , d ))]
# Initialize lists to store frame paths and corresponding prompts
self . frame_paths = []
self . prompts = []
# Loop through each video directory
for video_dir in self . video_dirs :
# List all PNG files in the video directory and store their paths
frames = [ os . path . join ( video_dir , f ) for f in os . listdir ( video_dir ) if f . endswith ( '.png' )]
self . frame_paths . extend ( frames )
# Read the prompt text file in the video directory and store its content
with open ( os . path . join ( video_dir , 'prompt.txt' ), 'r' ) as f :
prompt = f . read (). strip ()
# Repeat the prompt for each frame in the video and store in prompts list
self . prompts . extend ([ prompt ] * len ( frames ))
# Return the total number of samples in the dataset
def __len__ ( self ):
return len ( self . frame_paths )
# Retrieve a sample from the dataset given an index
def __getitem__ ( self , idx ):
# Get the path of the frame corresponding to the given index
frame_path = self . frame_paths [ idx ]
# Open the image using PIL (Python Imaging Library)
image = Image . open ( frame_path )
# Get the prompt corresponding to the given index
prompt = self . prompts [ idx ]
# Apply transformation if specified
if self . transform :
image = self . transform ( image )
# Return the transformed image and the prompt
return image , prompt
在繼續對架構進行編碼之前,我們需要標準化我們的訓練資料。我們將使用 16 的批量大小並對資料進行打亂以引入更多隨機性。
# Define a set of transformations to be applied to the data
transform = transforms . Compose ([
transforms . ToTensor (), # Convert PIL Image or numpy.ndarray to tensor
transforms . Normalize (( 0.5 ,), ( 0.5 ,)) # Normalize image with mean and standard deviation
])
# Load the dataset using the defined transform
dataset = TextToVideoDataset ( root_dir = 'training_dataset' , transform = transform )
# Create a dataloader to iterate over the dataset
dataloader = torch . utils . data . DataLoader ( dataset , batch_size = 16 , shuffle = True )
您可能已經在Transformer 架構中看到過,其中的起點是將我們的文字輸入轉換為嵌入,以便在多頭注意力中進一步處理,類似於這裡,我們必須編寫一個文字嵌入層,基於該層,GAN架構訓練將在我們的嵌入資料上進行和圖片張量。
# Define a class for text embedding
class TextEmbedding ( nn . Module ):
# Constructor method with vocab_size and embed_size parameters
def __init__ ( self , vocab_size , embed_size ):
# Call the superclass constructor
super ( TextEmbedding , self ). __init__ ()
# Initialize embedding layer
self . embedding = nn . Embedding ( vocab_size , embed_size )
# Define the forward pass method
def forward ( self , x ):
# Return embedded representation of input
return self . embedding ( x )
詞彙量大小將基於我們的訓練數據,稍後我們將計算這些數據。嵌入大小將為 10。
現在我們已經知道生成器在 GAN 中的作用,讓我們對這一層進行編碼,然後了解其內容。
class Generator ( nn . Module ):
def __init__ ( self , text_embed_size ):
super ( Generator , self ). __init__ ()
# Fully connected layer that takes noise and text embedding as input
self . fc1 = nn . Linear ( 100 + text_embed_size , 256 * 8 * 8 )
# Transposed convolutional layers to upsample the input
self . deconv1 = nn . ConvTranspose2d ( 256 , 128 , 4 , 2 , 1 )
self . deconv2 = nn . ConvTranspose2d ( 128 , 64 , 4 , 2 , 1 )
self . deconv3 = nn . ConvTranspose2d ( 64 , 3 , 4 , 2 , 1 ) # Output has 3 channels for RGB images
# Activation functions
self . relu = nn . ReLU ( True ) # ReLU activation function
self . tanh = nn . Tanh () # Tanh activation function for final output
def forward ( self , noise , text_embed ):
# Concatenate noise and text embedding along the channel dimension
x = torch . cat (( noise , text_embed ), dim = 1 )
# Fully connected layer followed by reshaping to 4D tensor
x = self . fc1 ( x ). view ( - 1 , 256 , 8 , 8 )
# Upsampling through transposed convolution layers with ReLU activation
x = self . relu ( self . deconv1 ( x ))
x = self . relu ( self . deconv2 ( x ))
# Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)
x