OpenAI の Sora、Stability AI の Stable Video Diffusion、および登場した、または今後登場する他の多くのテキストからビデオへのモデルは、大規模言語モデル (LLM) に次いで 2024 年に最も人気のある AI トレンドの 1 つです。このブログでは、小規模なテキストからビデオへのモデルを最初から構築します。テキスト プロンプトを入力すると、トレーニングされたモデルがそのプロンプトに基づいてビデオを生成します。このブログでは、理論的概念の理解からアーキテクチャ全体のコーディング、最終結果の生成までのすべてを取り上げます。
私は豪華な GPU を持っていないので、小規模なアーキテクチャをコーディングしました。以下は、さまざまなプロセッサーでモデルをトレーニングするのに必要な時間を比較したものです。
トレーニングビデオ | エポック | CPU | GPU A10 | GPU T4 |
---|---|---|---|---|
10K | 30 | 3時間以上 | 1時間 | 1時間42分 |
30K | 30 | 6時間以上 | 1時間30分 | 2時間30分 |
100K | 30 | - | 3~4時間 | 5~6時間 |
CPU 上で実行すると、モデルのトレーニングに明らかに長い時間がかかります。コードの変更を迅速にテストして結果を確認する必要がある場合、CPU は最適な選択ではありません。トレーニングをより効率的かつ高速に行うには、Colab または Kaggle の T4 GPU を使用することをお勧めします。
これは、安定した拡散を最初から作成する方法をガイドするブログ リンクです: 安定した拡散を最初からコーディングする
データセットでトレーニングし、目に見えないデータでテストする従来の機械学習または深層学習モデルに対しても同様のアプローチに従います。テキストからビデオへのコンテキストで、犬がボールを取ってくる様子や猫がネズミを追いかけている様子を撮影した 100,000 個のビデオのトレーニング データセットがあるとします。ボールを取ってくる猫やネズミを追いかける犬のビデオを生成するようにモデルをトレーニングします。
このようなトレーニング データセットはインターネットで簡単に入手できますが、必要な計算能力は非常に高くなります。したがって、Python コードから生成された移動オブジェクトのビデオ データセットを使用します。
OpenAI Sora が使用する拡散モデルの代わりに、GAN (敵対的生成ネットワーク) アーキテクチャを使用してモデルを作成します。拡散モデルを使用しようとしましたが、私の能力を超えたメモリ要件によりクラッシュしました。一方、GAN は、トレーニングとテストがより簡単かつ迅速です。
OOP (オブジェクト指向プログラミング) を使用するので、ニューラル ネットワークとともに OOP (オブジェクト指向プログラミング) の基本を理解しておく必要があります。ここではそのアーキテクチャについて説明するため、GAN (敵対的生成ネットワーク) の知識は必須ではありません。
トピック | リンク |
---|---|
OOP | ビデオリンク |
ニューラルネットワーク理論 | ビデオリンク |
GAN アーキテクチャ | ビデオリンク |
Python の基本 | ビデオリンク |
アーキテクチャの多くは GAN アーキテクチャに依存しているため、GAN アーキテクチャを理解することが重要です。それが何であるか、そのコンポーネントなどを調べてみましょう。
Generative Adversarial Network (GAN) は、2 つのニューラル ネットワークが競合する深層学習モデルです。1 つは指定されたデータセットから新しいデータ (画像や音楽など) を作成し、もう 1 つはデータが本物か偽物かを判断しようとします。このプロセスは、生成されたデータが元のデータと区別できなくなるまで続きます。
画像の生成: GAN は、テキスト プロンプトからリアルな画像を作成したり、解像度を高めたり、白黒写真に色を追加したりするなど、既存の画像を変更します。
データ拡張: 不正検出システム用の不正取引データの作成など、他の機械学習モデルをトレーニングするための合成データを生成します。
欠落情報の完全化: GAN は、エネルギー用途のために地形図から地下画像を生成するなど、欠落データを埋めることができます。
3D モデルの生成: 2D 画像を 3D モデルに変換します。これは、医療などの分野で手術計画用のリアルな臓器画像を作成するのに役立ちます。
これは、ジェネレーターとディスクリミネーターの 2 つのディープ ニューラル ネットワークで構成されます。これらのネットワークは、一方が新しいデータを生成し、もう一方がデータが本物か偽物かを評価する、敵対的な設定で一緒にトレーニングします。
GAN の仕組みの簡略化した概要は次のとおりです。
トレーニング セット分析: ジェネレーターはトレーニング セットを分析してデータ属性を特定し、ディスクリミネーターは同じデータを独立して分析してその属性を学習します。
データ変更: ジェネレーターは、データの一部の属性にノイズ (ランダムな変化) を追加します。
データの受け渡し: 変更されたデータはディスクリミネーターに渡されます。
確率計算: 弁別器は、生成されたデータが元のデータセットからのものである確率を計算します。
フィードバック ループ: ディスクリミネーターはジェネレーターにフィードバックを提供し、次のサイクルでランダム ノイズを低減するようにジェネレーターを導きます。
敵対的トレーニング: ジェネレーターはディスクリミネーターの間違いを最大化しようとしますが、ディスクリミネーターはそれ自体のエラーを最小限に抑えようとします。トレーニングを何度も繰り返すことで、両方のネットワークが改善され、進化します。
平衡状態: トレーニングは、ディスクリミネーターが実際のデータと合成データを区別できなくなるまで継続され、ジェネレーターが現実的なデータを生成する方法を正常に学習したことを示します。この時点で、トレーニングプロセスは完了です。
awsガイドからの画像
人間の顔の変更に焦点を当て、画像間の変換の例を使用して GAN モデルを説明しましょう。
入力画像: 入力は人間の顔の実画像です。
属性変更: ジェネレーターは、目にサングラスを追加するなど、顔の属性を変更します。
生成された画像: ジェネレーターは、サングラスが追加された一連の画像を作成します。
弁別者のタスク: 弁別者は、実際の画像 (サングラスをかけた人物) と生成された画像 (サングラスが追加された顔) の混合を受け取ります。
評価: ディスクリミネーターは、実際の画像と生成された画像を区別しようとします。
フィードバック ループ: ディスクリミネーターが偽の画像を正しく識別すると、ジェネレーターはパラメーターを調整して、より説得力のある画像を生成します。ジェネレーターがディスクリミネーターを騙すことに成功すると、ディスクリミネーターはパラメーターを更新して検出を改善します。
この敵対的なプロセスを通じて、両方のネットワークは継続的に改善されます。ジェネレータはリアルな画像を作成する能力が向上し、ディスクリミネータは平衡に達するまで偽物を識別する能力が向上し、ディスクリミネータは本物の画像と生成された画像の違いを見分けることができなくなります。この時点で、GAN は現実的な変更を行う方法を学習することに成功しました。
必要なライブラリをインストールすることは、テキストからビデオへのモデルを構築する最初のステップです。
pip install -r requirements.txt
さまざまな Python ライブラリを使用して作業するので、それらをインポートしてみましょう。
# Operating System module for interacting with the operating system
import os
# Module for generating random numbers
import random
# Module for numerical operations
import numpy as np
# OpenCV library for image processing
import cv2
# Python Imaging Library for image processing
from PIL import Image , ImageDraw , ImageFont
# PyTorch library for deep learning
import torch
# Dataset class for creating custom datasets in PyTorch
from torch . utils . data import Dataset
# Module for image transformations
import torchvision . transforms as transforms
# Neural network module in PyTorch
import torch . nn as nn
# Optimization algorithms in PyTorch
import torch . optim as optim
# Function for padding sequences in PyTorch
from torch . nn . utils . rnn import pad_sequence
# Function for saving images in PyTorch
from torchvision . utils import save_image
# Module for plotting graphs and images
import matplotlib . pyplot as plt
# Module for displaying rich content in IPython environments
from IPython . display import clear_output , display , HTML
# Module for encoding and decoding binary data to text
import base64
すべてのライブラリをインポートしたので、次のステップは、GAN アーキテクチャのトレーニングに使用するトレーニング データを定義することです。
トレーニング データとして少なくとも 10,000 本のビデオが必要です。なぜ?そうですね、小さい数値でテストした結果は非常に貧弱で、実質的には何も見えませんでした。次の大きな疑問は、これらのビデオは何についてのものなのかということです。私たちのトレーニング ビデオ データセットは、異なる動きで異なる方向に移動する円で構成されています。それでは、これをコーディングして 10,000 個のビデオを生成して、どのようになるかを確認してみましょう。
# Create a directory named 'training_dataset'
os . makedirs ( 'training_dataset' , exist_ok = True )
# Define the number of videos to generate for the dataset
num_videos = 10000
# Define the number of frames per video (1 Second Video)
frames_per_video = 10
# Define the size of each image in the dataset
img_size = ( 64 , 64 )
# Define the size of the shapes (Circle)
shape_size = 10
いくつかの基本パラメータを設定した後、次に生成されるトレーニング ビデオに基づいてトレーニング データセットのテキスト プロンプトを定義する必要があります。
# Define text prompts and corresponding movements for circles
prompts_and_movements = [
( "circle moving down" , "circle" , "down" ), # Move circle downward
( "circle moving left" , "circle" , "left" ), # Move circle leftward
( "circle moving right" , "circle" , "right" ), # Move circle rightward
( "circle moving diagonally up-right" , "circle" , "diagonal_up_right" ), # Move circle diagonally up-right
( "circle moving diagonally down-left" , "circle" , "diagonal_down_left" ), # Move circle diagonally down-left
( "circle moving diagonally up-left" , "circle" , "diagonal_up_left" ), # Move circle diagonally up-left
( "circle moving diagonally down-right" , "circle" , "diagonal_down_right" ), # Move circle diagonally down-right
( "circle rotating clockwise" , "circle" , "rotate_clockwise" ), # Rotate circle clockwise
( "circle rotating counter-clockwise" , "circle" , "rotate_counter_clockwise" ), # Rotate circle counter-clockwise
( "circle shrinking" , "circle" , "shrink" ), # Shrink circle
( "circle expanding" , "circle" , "expand" ), # Expand circle
( "circle bouncing vertically" , "circle" , "bounce_vertical" ), # Bounce circle vertically
( "circle bouncing horizontally" , "circle" , "bounce_horizontal" ), # Bounce circle horizontally
( "circle zigzagging vertically" , "circle" , "zigzag_vertical" ), # Zigzag circle vertically
( "circle zigzagging horizontally" , "circle" , "zigzag_horizontal" ), # Zigzag circle horizontally
( "circle moving up-left" , "circle" , "up_left" ), # Move circle up-left
( "circle moving down-right" , "circle" , "down_right" ), # Move circle down-right
( "circle moving down-left" , "circle" , "down_left" ), # Move circle down-left
]
これらのプロンプトを使用して、サークルのいくつかの動きを定義しました。ここで、プロンプトに基づいて円を移動するための数式をコーディングする必要があります。
# defining function to create image with moving shape
def create_image_with_moving_shape ( size , frame_num , shape , direction ):
# Create a new RGB image with specified size and white background
img = Image . new ( 'RGB' , size , color = ( 255 , 255 , 255 ))
# Create a drawing context for the image
draw = ImageDraw . Draw ( img )
# Calculate the center coordinates of the image
center_x , center_y = size [ 0 ] // 2 , size [ 1 ] // 2
# Initialize position with center for all movements
position = ( center_x , center_y )
# Define a dictionary mapping directions to their respective position adjustments or image transformations
direction_map = {
# Adjust position downwards based on frame number
"down" : ( 0 , frame_num * 5 % size [ 1 ]),
# Adjust position to the left based on frame number
"left" : ( - frame_num * 5 % size [ 0 ], 0 ),
# Adjust position to the right based on frame number
"right" : ( frame_num * 5 % size [ 0 ], 0 ),
# Adjust position diagonally up and to the right
"diagonal_up_right" : ( frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position diagonally down and to the left
"diagonal_down_left" : ( - frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Adjust position diagonally up and to the left
"diagonal_up_left" : ( - frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position diagonally down and to the right
"diagonal_down_right" : ( frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Rotate the image clockwise based on frame number
"rotate_clockwise" : img . rotate ( frame_num * 10 % 360 , center = ( center_x , center_y ), fillcolor = ( 255 , 255 , 255 )),
# Rotate the image counter-clockwise based on frame number
"rotate_counter_clockwise" : img . rotate ( - frame_num * 10 % 360 , center = ( center_x , center_y ), fillcolor = ( 255 , 255 , 255 )),
# Adjust position for a bouncing effect vertically
"bounce_vertical" : ( 0 , center_y - abs ( frame_num * 5 % size [ 1 ] - center_y )),
# Adjust position for a bouncing effect horizontally
"bounce_horizontal" : ( center_x - abs ( frame_num * 5 % size [ 0 ] - center_x ), 0 ),
# Adjust position for a zigzag effect vertically
"zigzag_vertical" : ( 0 , center_y - frame_num * 5 % size [ 1 ]) if frame_num % 2 == 0 else ( 0 , center_y + frame_num * 5 % size [ 1 ]),
# Adjust position for a zigzag effect horizontally
"zigzag_horizontal" : ( center_x - frame_num * 5 % size [ 0 ], center_y ) if frame_num % 2 == 0 else ( center_x + frame_num * 5 % size [ 0 ], center_y ),
# Adjust position upwards and to the right based on frame number
"up_right" : ( frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position upwards and to the left based on frame number
"up_left" : ( - frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position downwards and to the right based on frame number
"down_right" : ( frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Adjust position downwards and to the left based on frame number
"down_left" : ( - frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ])
}
# Check if direction is in the direction map
if direction in direction_map :
# Check if the direction maps to a position adjustment
if isinstance ( direction_map [ direction ], tuple ):
# Update position based on the adjustment
position = tuple ( np . add ( position , direction_map [ direction ]))
else : # If the direction maps to an image transformation
# Update the image based on the transformation
img = direction_map [ direction ]
# Return the image as a numpy array
return np . array ( img )
上記の関数は、選択した方向に基づいてフレームごとに円を移動するために使用されます。すべてのビデオを生成するには、その上でビデオの数までループを実行するだけです。
# Iterate over the number of videos to generate
for i in range ( num_videos ):
# Randomly choose a prompt and movement from the predefined list
prompt , shape , direction = random . choice ( prompts_and_movements )
# Create a directory for the current video
video_dir = f'training_dataset/video_ { i } '
os . makedirs ( video_dir , exist_ok = True )
# Write the chosen prompt to a text file in the video directory
with open ( f' { video_dir } /prompt.txt' , 'w' ) as f :
f . write ( prompt )
# Generate frames for the current video
for frame_num in range ( frames_per_video ):
# Create an image with a moving shape based on the current frame number, shape, and direction
img = create_image_with_moving_shape ( img_size , frame_num , shape , direction )
# Save the generated image as a PNG file in the video directory
cv2 . imwrite ( f' { video_dir } /frame_ { frame_num } .png' , img )
上記のコードを実行すると、トレーニング データセット全体が生成されます。トレーニング データセット ファイルの構造は次のとおりです。
各トレーニング ビデオ フォルダーには、フレームとテキスト プロンプトが含まれています。トレーニング データセットのサンプルを見てみましょう。
私たちのトレーニング データセットには、上に移動してから右に移動する円の動きは含まれていません。これをテスト プロンプトとして使用し、目に見えないデータでトレーニング済みモデルを評価します。
もう 1 つ注意すべき重要な点は、OpenAI Sora のデモ ビデオで観察されたものと同様に、オブジェクトがシーンから遠ざかったり、カメラの前に部分的に現れたりするサンプルがトレーニング データに多数含まれていることです。
このようなサンプルをトレーニング データに含める理由は、円がシーンの隅から入ったときに、形状を崩すことなくモデルが一貫性を維持できるかどうかをテストするためです。
トレーニング データが生成されたので、トレーニング ビデオをテンソルに変換する必要があります。テンソルは、PyTorch などのディープ ラーニング フレームワークで使用される主要なデータ タイプです。さらに、正規化などの変換を実行すると、データをより小さな範囲にスケールすることでトレーニング アーキテクチャの収束と安定性が向上します。
テキストからビデオへのタスク用のデータセット クラスをコーディングする必要があります。これにより、ビデオ フレームとそれに対応するテキスト プロンプトをトレーニング データセット ディレクトリから読み取って、PyTorch で使用できるようになります。
# Define a dataset class inheriting from torch.utils.data.Dataset
class TextToVideoDataset ( Dataset ):
def __init__ ( self , root_dir , transform = None ):
# Initialize the dataset with root directory and optional transform
self . root_dir = root_dir
self . transform = transform
# List all subdirectories in the root directory
self . video_dirs = [ os . path . join ( root_dir , d ) for d in os . listdir ( root_dir ) if os . path . isdir ( os . path . join ( root_dir , d ))]
# Initialize lists to store frame paths and corresponding prompts
self . frame_paths = []
self . prompts = []
# Loop through each video directory
for video_dir in self . video_dirs :
# List all PNG files in the video directory and store their paths
frames = [ os . path . join ( video_dir , f ) for f in os . listdir ( video_dir ) if f . endswith ( '.png' )]
self . frame_paths . extend ( frames )
# Read the prompt text file in the video directory and store its content
with open ( os . path . join ( video_dir , 'prompt.txt' ), 'r' ) as f :
prompt = f . read (). strip ()
# Repeat the prompt for each frame in the video and store in prompts list
self . prompts . extend ([ prompt ] * len ( frames ))
# Return the total number of samples in the dataset
def __len__ ( self ):
return len ( self . frame_paths )
# Retrieve a sample from the dataset given an index
def __getitem__ ( self , idx ):
# Get the path of the frame corresponding to the given index
frame_path = self . frame_paths [ idx ]
# Open the image using PIL (Python Imaging Library)
image = Image . open ( frame_path )
# Get the prompt corresponding to the given index
prompt = self . prompts [ idx ]
# Apply transformation if specified
if self . transform :
image = self . transform ( image )
# Return the transformed image and the prompt
return image , prompt
アーキテクチャのコーディングに進む前に、トレーニング データを正規化する必要があります。バッチ サイズ 16 を使用し、データをシャッフルして、よりランダム性を導入します。
# Define a set of transformations to be applied to the data
transform = transforms . Compose ([
transforms . ToTensor (), # Convert PIL Image or numpy.ndarray to tensor
transforms . Normalize (( 0.5 ,), ( 0.5 ,)) # Normalize image with mean and standard deviation
])
# Load the dataset using the defined transform
dataset = TextToVideoDataset ( root_dir = 'training_dataset' , transform = transform )
# Create a dataloader to iterate over the dataset
dataloader = torch . utils . data . DataLoader ( dataset , batch_size = 16 , shuffle = True )
トランスフォーマ アーキテクチャで、マルチヘッド アテンションでさらに処理するためにテキスト入力を埋め込みに変換することが開始点であることを見たことがあるかもしれません。ここでも同様に、埋め込みデータに対して GAN アーキテクチャのトレーニングが行われるベースとなるテキスト埋め込み層をコーディングする必要があります。そして画像テンソル。
# Define a class for text embedding
class TextEmbedding ( nn . Module ):
# Constructor method with vocab_size and embed_size parameters
def __init__ ( self , vocab_size , embed_size ):
# Call the superclass constructor
super ( TextEmbedding , self ). __init__ ()
# Initialize embedding layer
self . embedding = nn . Embedding ( vocab_size , embed_size )
# Define the forward pass method
def forward ( self , x ):
# Return embedded representation of input
return self . embedding ( x )
語彙サイズはトレーニング データに基づいており、後で計算します。埋め込みサイズは 10 になります。より大きなデータセットを扱う場合は、Hugging Face で利用可能な埋め込みモデルを独自に選択して使用することもできます。
GAN でジェネレーターが何を行うかはすでにわかったので、この層をコーディングしてその内容を理解しましょう。
class Generator ( nn . Module ):
def __init__ ( self , text_embed_size ):
super ( Generator , self ). __init__ ()
# Fully connected layer that takes noise and text embedding as input
self . fc1 = nn . Linear ( 100 + text_embed_size , 256 * 8 * 8 )
# Transposed convolutional layers to upsample the input
self . deconv1 = nn . ConvTranspose2d ( 256 , 128 , 4 , 2 , 1 )
self . deconv2 = nn . ConvTranspose2d ( 128 , 64 , 4 , 2 , 1 )
self . deconv3 = nn . ConvTranspose2d ( 64 , 3 , 4 , 2 , 1 ) # Output has 3 channels for RGB images
# Activation functions
self . relu = nn . ReLU ( True ) # ReLU activation function
self . tanh = nn . Tanh () # Tanh activation function for final output
def forward ( self , noise , text_embed ):
# Concatenate noise and text embedding along the channel dimension
x = torch . cat (( noise , text_embed ), dim = 1 )
# Fully connected layer followed by reshaping to 4D tensor
x = self . fc1 ( x ). view ( - 1 , 256 , 8 , 8 )
# Upsampling through transposed convolution layers with ReLU activation
x = self . relu ( self . deconv1 ( x ))
x = self . relu ( self . deconv2 ( x ))
# Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)
x