OpenAI의 Sora, Stability AI의 Stable Video Diffusion 등 현재 나왔거나 앞으로 등장할 수많은 텍스트-비디오 모델은 LLM(Large Language Model)에 이어 2024년 가장 인기 있는 AI 트렌드 중 하나입니다. 이 블로그에서는 소규모 텍스트-비디오 모델을 처음부터 구축해 보겠습니다. 텍스트 프롬프트를 입력하면 훈련된 모델이 해당 프롬프트를 기반으로 비디오를 생성합니다. 이 블로그에서는 이론적 개념 이해부터 전체 아키텍처 코딩 및 최종 결과 생성까지 모든 내용을 다룹니다.
화려한 GPU가 없기 때문에 소규모 아키텍처를 코딩했습니다. 다음은 다양한 프로세서에서 모델을 학습하는 데 필요한 시간을 비교한 것입니다.
교육 비디오 | 시대 | CPU | GPU A10 | GPU T4 |
---|---|---|---|---|
10K | 30 | 3시간 이상 | 1시간 | 1시간 42분 |
30K | 30 | 6시간 이상 | 1시간 30분 | 2시간 30분 |
100K | 30 | - | 3~4시간 | 5~6시간 |
CPU에서 실행하면 모델을 훈련하는 데 훨씬 더 오랜 시간이 걸립니다. 코드 변경 사항을 빠르게 테스트하고 결과를 확인해야 한다면 CPU는 최선의 선택이 아닙니다. 보다 효율적이고 빠른 학습을 위해서는 Colab 또는 Kaggle의 T4 GPU를 사용하는 것이 좋습니다.
다음은 처음부터 Stable Diffusion을 생성하는 방법을 안내하는 블로그 링크입니다. Coding Stable Diffusion from Scratch
우리는 데이터 세트를 학습한 다음 보이지 않는 데이터를 테스트하는 기존 기계 학습 또는 딥 러닝 모델과 유사한 접근 방식을 따를 것입니다. 텍스트-비디오의 맥락에서 공을 가져오는 개와 쥐를 쫓는 고양이에 대한 10만 개의 비디오로 구성된 훈련 데이터 세트가 있다고 가정해 보겠습니다. 우리는 공을 가져오는 고양이나 쥐를 쫓는 개가 나오는 비디오를 생성하도록 모델을 훈련시킬 것입니다.
이러한 교육 데이터 세트는 인터넷에서 쉽게 사용할 수 있지만 필요한 계산 능력은 매우 높습니다. 따라서 우리는 Python 코드에서 생성된 움직이는 객체의 비디오 데이터 세트를 사용하여 작업할 것입니다.
OpenAI Sora가 사용하는 확산 모델 대신 GAN(Generative Adversarial Networks) 아키텍처를 사용하여 모델을 생성하겠습니다. 확산 모델을 사용하려고 시도했지만 용량을 초과하는 메모리 요구 사항으로 인해 충돌이 발생했습니다. 반면 GAN은 훈련과 테스트가 더 쉽고 빠릅니다.
우리는 OOP(객체 지향 프로그래밍)을 사용할 것이므로 신경망과 함께 OOP에 대한 기본적인 이해가 있어야 합니다. 여기서는 해당 아키텍처를 다루므로 GAN(Generative Adversarial Networks)에 대한 지식은 필수는 아닙니다.
주제 | 링크 |
---|---|
이런! | 비디오 링크 |
신경망 이론 | 비디오 링크 |
GAN 아키텍처 | 비디오 링크 |
파이썬 기초 | 비디오 링크 |
우리 아키텍처의 대부분이 GAN 아키텍처에 의존하기 때문에 GAN 아키텍처를 이해하는 것이 중요합니다. 그것이 무엇인지, 구성요소 등을 살펴보겠습니다.
GAN(Generative Adversarial Network)은 두 개의 신경망이 경쟁하는 딥 러닝 모델입니다. 하나는 주어진 데이터 세트에서 새로운 데이터(예: 이미지 또는 음악)를 생성하고 다른 하나는 데이터가 진짜인지 가짜인지 확인하려고 시도합니다. 이 프로세스는 생성된 데이터가 원본과 구별되지 않을 때까지 계속됩니다.
이미지 생성 : GAN은 텍스트 프롬프트에서 사실적인 이미지를 생성하거나 해상도를 높이거나 흑백 사진에 색상을 추가하는 등 기존 이미지를 수정합니다.
데이터 증강 : 사기 탐지 시스템을 위한 사기 거래 데이터 생성과 같은 다른 기계 학습 모델을 교육하기 위해 합성 데이터를 생성합니다.
완전한 누락 정보 : GAN은 에너지 애플리케이션을 위해 지형 지도에서 지하 이미지를 생성하는 것과 같이 누락된 데이터를 채울 수 있습니다.
3D 모델 생성 : 2D 이미지를 3D 모델로 변환합니다. 이는 의료와 같은 분야에서 수술 계획을 위한 사실적인 장기 이미지를 생성하는 데 유용합니다.
이는 생성기 와 판별기라는 두 개의 심층 신경망으로 구성됩니다. 이러한 네트워크는 하나는 새로운 데이터를 생성하고 다른 하나는 데이터가 진짜인지 가짜인지 평가하는 적대적 설정에서 함께 훈련됩니다.
다음은 GAN 작동 방식에 대한 간략한 개요입니다.
훈련 세트 분석 : 생성기는 훈련 세트를 분석하여 데이터 속성을 식별하고, 판별기는 동일한 데이터를 독립적으로 분석하여 속성을 학습합니다.
데이터 수정 : 생성기는 데이터의 일부 속성에 노이즈(무작위 변경)를 추가합니다.
데이터 전달(Data Passing) : 수정된 데이터가 판별자에게 전달됩니다.
확률 계산 : 판별자는 생성된 데이터가 원본 데이터 세트에서 나온 확률을 계산합니다.
피드백 루프(Feedback Loop) : 판별기는 생성기에 피드백을 제공하여 다음 주기에서 무작위 잡음을 줄이도록 안내합니다.
적대적 훈련(Adversarial Training) : 생성자는 판별자의 실수를 최대화하려고 시도하고 판별자는 자신의 오류를 최소화하려고 시도합니다. 많은 훈련 반복을 통해 두 네트워크 모두 개선되고 발전합니다.
평형 상태(Equilibrium State) : 판별자가 더 이상 실제 데이터와 합성 데이터를 구별할 수 없을 때까지 훈련이 계속됩니다. 이는 생성자가 현실적인 데이터를 생성하는 방법을 성공적으로 학습했음을 나타냅니다. 이 시점에서 학습 프로세스가 완료됩니다.
AWS 가이드의 이미지
인간의 얼굴 수정에 초점을 맞춰 이미지 간 변환의 예를 들어 GAN 모델을 설명하겠습니다.
입력 이미지 : 입력은 사람 얼굴의 실제 이미지입니다.
속성 수정 : 생성기는 눈에 선글라스를 추가하는 것과 같이 얼굴의 속성을 수정합니다.
생성된 이미지 : 생성기는 선글라스가 추가된 이미지 세트를 생성합니다.
판별자의 임무 : 판별자는 실제 이미지(선글라스를 낀 사람)와 생성된 이미지(선글라스를 낀 얼굴)를 혼합하여 수신합니다.
평가 : 판별자는 실제 이미지와 생성된 이미지를 구별하려고 시도합니다.
피드백 루프 : 판별기가 가짜 이미지를 올바르게 식별하면 생성기는 매개변수를 조정하여 보다 설득력 있는 이미지를 생성합니다. 생성기가 판별자를 속이는 데 성공하면 판별자는 매개변수를 업데이트하여 감지 기능을 향상시킵니다.
이러한 적대적 프로세스를 통해 두 네트워크 모두 지속적으로 개선됩니다. 생성기는 사실적인 이미지를 더 잘 생성하고, 판별기는 판별기가 더 이상 실제 이미지와 생성된 이미지의 차이를 구분할 수 없는 평형에 도달할 때까지 가짜를 식별하는 데 더 능숙합니다. 이 시점에서 GAN은 현실적인 수정을 생성하는 방법을 성공적으로 학습했습니다.
필요한 라이브러리를 설치하는 것은 텍스트-비디오 모델을 구축하는 첫 번째 단계입니다.
pip install -r requirements.txt
우리는 다양한 Python 라이브러리를 사용하여 작업할 것입니다. 이를 가져와 보겠습니다.
# Operating System module for interacting with the operating system
import os
# Module for generating random numbers
import random
# Module for numerical operations
import numpy as np
# OpenCV library for image processing
import cv2
# Python Imaging Library for image processing
from PIL import Image , ImageDraw , ImageFont
# PyTorch library for deep learning
import torch
# Dataset class for creating custom datasets in PyTorch
from torch . utils . data import Dataset
# Module for image transformations
import torchvision . transforms as transforms
# Neural network module in PyTorch
import torch . nn as nn
# Optimization algorithms in PyTorch
import torch . optim as optim
# Function for padding sequences in PyTorch
from torch . nn . utils . rnn import pad_sequence
# Function for saving images in PyTorch
from torchvision . utils import save_image
# Module for plotting graphs and images
import matplotlib . pyplot as plt
# Module for displaying rich content in IPython environments
from IPython . display import clear_output , display , HTML
# Module for encoding and decoding binary data to text
import base64
이제 모든 라이브러리를 가져왔으므로 다음 단계는 GAN 아키텍처를 훈련하는 데 사용할 훈련 데이터를 정의하는 것입니다.
훈련 데이터로 최소 10,000개의 비디오가 필요합니다. 왜? 글쎄요, 더 작은 숫자로 테스트했는데 결과가 매우 나빴기 때문에 사실상 아무것도 볼 수 없었습니다. 다음 큰 질문은: 이 비디오는 무엇에 관한 것입니까? 우리의 교육 비디오 데이터 세트는 다양한 동작으로 다양한 방향으로 움직이는 원으로 구성됩니다. 그럼, 코드를 작성하고 10,000개의 비디오를 생성하여 어떻게 보이는지 살펴보겠습니다.
# Create a directory named 'training_dataset'
os . makedirs ( 'training_dataset' , exist_ok = True )
# Define the number of videos to generate for the dataset
num_videos = 10000
# Define the number of frames per video (1 Second Video)
frames_per_video = 10
# Define the size of each image in the dataset
img_size = ( 64 , 64 )
# Define the size of the shapes (Circle)
shape_size = 10
몇 가지 기본 매개변수를 설정한 후 생성할 교육 비디오를 기반으로 교육 데이터세트의 텍스트 프롬프트를 정의해야 합니다.
# Define text prompts and corresponding movements for circles
prompts_and_movements = [
( "circle moving down" , "circle" , "down" ), # Move circle downward
( "circle moving left" , "circle" , "left" ), # Move circle leftward
( "circle moving right" , "circle" , "right" ), # Move circle rightward
( "circle moving diagonally up-right" , "circle" , "diagonal_up_right" ), # Move circle diagonally up-right
( "circle moving diagonally down-left" , "circle" , "diagonal_down_left" ), # Move circle diagonally down-left
( "circle moving diagonally up-left" , "circle" , "diagonal_up_left" ), # Move circle diagonally up-left
( "circle moving diagonally down-right" , "circle" , "diagonal_down_right" ), # Move circle diagonally down-right
( "circle rotating clockwise" , "circle" , "rotate_clockwise" ), # Rotate circle clockwise
( "circle rotating counter-clockwise" , "circle" , "rotate_counter_clockwise" ), # Rotate circle counter-clockwise
( "circle shrinking" , "circle" , "shrink" ), # Shrink circle
( "circle expanding" , "circle" , "expand" ), # Expand circle
( "circle bouncing vertically" , "circle" , "bounce_vertical" ), # Bounce circle vertically
( "circle bouncing horizontally" , "circle" , "bounce_horizontal" ), # Bounce circle horizontally
( "circle zigzagging vertically" , "circle" , "zigzag_vertical" ), # Zigzag circle vertically
( "circle zigzagging horizontally" , "circle" , "zigzag_horizontal" ), # Zigzag circle horizontally
( "circle moving up-left" , "circle" , "up_left" ), # Move circle up-left
( "circle moving down-right" , "circle" , "down_right" ), # Move circle down-right
( "circle moving down-left" , "circle" , "down_left" ), # Move circle down-left
]
우리는 이러한 프롬프트를 사용하여 원의 여러 움직임을 정의했습니다. 이제 프롬프트에 따라 해당 원을 이동하기 위해 몇 가지 수학 방정식을 코딩해야 합니다.
# defining function to create image with moving shape
def create_image_with_moving_shape ( size , frame_num , shape , direction ):
# Create a new RGB image with specified size and white background
img = Image . new ( 'RGB' , size , color = ( 255 , 255 , 255 ))
# Create a drawing context for the image
draw = ImageDraw . Draw ( img )
# Calculate the center coordinates of the image
center_x , center_y = size [ 0 ] // 2 , size [ 1 ] // 2
# Initialize position with center for all movements
position = ( center_x , center_y )
# Define a dictionary mapping directions to their respective position adjustments or image transformations
direction_map = {
# Adjust position downwards based on frame number
"down" : ( 0 , frame_num * 5 % size [ 1 ]),
# Adjust position to the left based on frame number
"left" : ( - frame_num * 5 % size [ 0 ], 0 ),
# Adjust position to the right based on frame number
"right" : ( frame_num * 5 % size [ 0 ], 0 ),
# Adjust position diagonally up and to the right
"diagonal_up_right" : ( frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position diagonally down and to the left
"diagonal_down_left" : ( - frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Adjust position diagonally up and to the left
"diagonal_up_left" : ( - frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position diagonally down and to the right
"diagonal_down_right" : ( frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Rotate the image clockwise based on frame number
"rotate_clockwise" : img . rotate ( frame_num * 10 % 360 , center = ( center_x , center_y ), fillcolor = ( 255 , 255 , 255 )),
# Rotate the image counter-clockwise based on frame number
"rotate_counter_clockwise" : img . rotate ( - frame_num * 10 % 360 , center = ( center_x , center_y ), fillcolor = ( 255 , 255 , 255 )),
# Adjust position for a bouncing effect vertically
"bounce_vertical" : ( 0 , center_y - abs ( frame_num * 5 % size [ 1 ] - center_y )),
# Adjust position for a bouncing effect horizontally
"bounce_horizontal" : ( center_x - abs ( frame_num * 5 % size [ 0 ] - center_x ), 0 ),
# Adjust position for a zigzag effect vertically
"zigzag_vertical" : ( 0 , center_y - frame_num * 5 % size [ 1 ]) if frame_num % 2 == 0 else ( 0 , center_y + frame_num * 5 % size [ 1 ]),
# Adjust position for a zigzag effect horizontally
"zigzag_horizontal" : ( center_x - frame_num * 5 % size [ 0 ], center_y ) if frame_num % 2 == 0 else ( center_x + frame_num * 5 % size [ 0 ], center_y ),
# Adjust position upwards and to the right based on frame number
"up_right" : ( frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position upwards and to the left based on frame number
"up_left" : ( - frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position downwards and to the right based on frame number
"down_right" : ( frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Adjust position downwards and to the left based on frame number
"down_left" : ( - frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ])
}
# Check if direction is in the direction map
if direction in direction_map :
# Check if the direction maps to a position adjustment
if isinstance ( direction_map [ direction ], tuple ):
# Update position based on the adjustment
position = tuple ( np . add ( position , direction_map [ direction ]))
else : # If the direction maps to an image transformation
# Update the image based on the transformation
img = direction_map [ direction ]
# Return the image as a numpy array
return np . array ( img )
위의 기능은 선택한 방향을 기준으로 각 프레임의 원을 이동하는 데 사용됩니다. 모든 비디오를 생성하려면 비디오 횟수만큼 루프를 실행하면 됩니다.
# Iterate over the number of videos to generate
for i in range ( num_videos ):
# Randomly choose a prompt and movement from the predefined list
prompt , shape , direction = random . choice ( prompts_and_movements )
# Create a directory for the current video
video_dir = f'training_dataset/video_ { i } '
os . makedirs ( video_dir , exist_ok = True )
# Write the chosen prompt to a text file in the video directory
with open ( f' { video_dir } /prompt.txt' , 'w' ) as f :
f . write ( prompt )
# Generate frames for the current video
for frame_num in range ( frames_per_video ):
# Create an image with a moving shape based on the current frame number, shape, and direction
img = create_image_with_moving_shape ( img_size , frame_num , shape , direction )
# Save the generated image as a PNG file in the video directory
cv2 . imwrite ( f' { video_dir } /frame_ { frame_num } .png' , img )
위 코드를 실행하면 전체 교육 데이터 세트가 생성됩니다. 훈련 데이터 세트 파일의 구조는 다음과 같습니다.
각 교육 비디오 폴더에는 텍스트 프롬프트와 함께 프레임이 포함되어 있습니다. 훈련 데이터 세트의 샘플을 살펴보겠습니다.
훈련 데이터 세트에는 원이 위로 이동한 후 오른쪽으로 이동하는 모션이 포함되지 않았습니다. 우리는 이것을 보이지 않는 데이터에 대해 훈련된 모델을 평가하기 위한 테스트 프롬프트로 사용할 것입니다.
주목해야 할 또 다른 중요한 점은 훈련 데이터에 OpenAI Sora 데모 비디오에서 관찰한 것과 유사하게 객체가 장면에서 멀어지거나 카메라 앞에 부분적으로 나타나는 많은 샘플이 포함되어 있다는 것입니다.
훈련 데이터에 이러한 샘플을 포함시키는 이유는 원이 모양을 깨지 않고 장면의 맨 구석에서 들어올 때 모델이 일관성을 유지할 수 있는지 테스트하기 위한 것입니다.
이제 훈련 데이터가 생성되었으므로 훈련 비디오를 PyTorch와 같은 딥 러닝 프레임워크에서 사용되는 기본 데이터 유형인 텐서로 변환해야 합니다. 또한 정규화와 같은 변환을 수행하면 데이터를 더 작은 범위로 확장하여 훈련 아키텍처의 수렴과 안정성을 향상시키는 데 도움이 됩니다.
텍스트-비디오 작업을 위한 데이터 세트 클래스를 코딩해야 합니다. 이 클래스는 훈련 데이터 세트 디렉터리에서 비디오 프레임과 해당 텍스트 프롬프트를 읽어 PyTorch에서 사용할 수 있도록 합니다.
# Define a dataset class inheriting from torch.utils.data.Dataset
class TextToVideoDataset ( Dataset ):
def __init__ ( self , root_dir , transform = None ):
# Initialize the dataset with root directory and optional transform
self . root_dir = root_dir
self . transform = transform
# List all subdirectories in the root directory
self . video_dirs = [ os . path . join ( root_dir , d ) for d in os . listdir ( root_dir ) if os . path . isdir ( os . path . join ( root_dir , d ))]
# Initialize lists to store frame paths and corresponding prompts
self . frame_paths = []
self . prompts = []
# Loop through each video directory
for video_dir in self . video_dirs :
# List all PNG files in the video directory and store their paths
frames = [ os . path . join ( video_dir , f ) for f in os . listdir ( video_dir ) if f . endswith ( '.png' )]
self . frame_paths . extend ( frames )
# Read the prompt text file in the video directory and store its content
with open ( os . path . join ( video_dir , 'prompt.txt' ), 'r' ) as f :
prompt = f . read (). strip ()
# Repeat the prompt for each frame in the video and store in prompts list
self . prompts . extend ([ prompt ] * len ( frames ))
# Return the total number of samples in the dataset
def __len__ ( self ):
return len ( self . frame_paths )
# Retrieve a sample from the dataset given an index
def __getitem__ ( self , idx ):
# Get the path of the frame corresponding to the given index
frame_path = self . frame_paths [ idx ]
# Open the image using PIL (Python Imaging Library)
image = Image . open ( frame_path )
# Get the prompt corresponding to the given index
prompt = self . prompts [ idx ]
# Apply transformation if specified
if self . transform :
image = self . transform ( image )
# Return the transformed image and the prompt
return image , prompt
아키텍처 코딩을 진행하기 전에 훈련 데이터를 정규화해야 합니다. 배치 크기 16을 사용하고 데이터를 섞어서 더 많은 무작위성을 도입하겠습니다.
# Define a set of transformations to be applied to the data
transform = transforms . Compose ([
transforms . ToTensor (), # Convert PIL Image or numpy.ndarray to tensor
transforms . Normalize (( 0.5 ,), ( 0.5 ,)) # Normalize image with mean and standard deviation
])
# Load the dataset using the defined transform
dataset = TextToVideoDataset ( root_dir = 'training_dataset' , transform = transform )
# Create a dataloader to iterate over the dataset
dataloader = torch . utils . data . DataLoader ( dataset , batch_size = 16 , shuffle = True )
시작점이 멀티 헤드 어텐션의 추가 처리를 위해 텍스트 입력을 임베딩으로 변환하는 변환기 아키텍처에서 본 적이 있을 것입니다. 여기서는 임베딩 데이터에 대해 GAN 아키텍처 훈련이 수행될 기반으로 텍스트 임베딩 레이어를 코딩해야 합니다. 그리고 이미지 텐서.
# Define a class for text embedding
class TextEmbedding ( nn . Module ):
# Constructor method with vocab_size and embed_size parameters
def __init__ ( self , vocab_size , embed_size ):
# Call the superclass constructor
super ( TextEmbedding , self ). __init__ ()
# Initialize embedding layer
self . embedding = nn . Embedding ( vocab_size , embed_size )
# Define the forward pass method
def forward ( self , x ):
# Return embedded representation of input
return self . embedding ( x )
어휘 크기는 훈련 데이터를 기반으로 하며 나중에 계산할 것입니다. 임베딩 크기는 10입니다. 더 큰 데이터 세트로 작업하는 경우 Hugging Face에서 사용할 수 있는 임베딩 모델을 직접 선택할 수도 있습니다.
이제 생성기가 GAN에서 수행하는 작업을 이미 알았으므로 이 계층을 코딩한 다음 해당 내용을 이해해 보겠습니다.
class Generator ( nn . Module ):
def __init__ ( self , text_embed_size ):
super ( Generator , self ). __init__ ()
# Fully connected layer that takes noise and text embedding as input
self . fc1 = nn . Linear ( 100 + text_embed_size , 256 * 8 * 8 )
# Transposed convolutional layers to upsample the input
self . deconv1 = nn . ConvTranspose2d ( 256 , 128 , 4 , 2 , 1 )
self . deconv2 = nn . ConvTranspose2d ( 128 , 64 , 4 , 2 , 1 )
self . deconv3 = nn . ConvTranspose2d ( 64 , 3 , 4 , 2 , 1 ) # Output has 3 channels for RGB images
# Activation functions
self . relu = nn . ReLU ( True ) # ReLU activation function
self . tanh = nn . Tanh () # Tanh activation function for final output
def forward ( self , noise , text_embed ):
# Concatenate noise and text embedding along the channel dimension
x = torch . cat (( noise , text_embed ), dim = 1 )
# Fully connected layer followed by reshaping to 4D tensor
x = self . fc1 ( x ). view ( - 1 , 256 , 8 , 8 )
# Upsampling through transposed convolution layers with ReLU activation
x = self . relu ( self . deconv1 ( x ))
x = self . relu ( self . deconv2 ( x ))
# Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)
x