Sora от OpenAI, Stable Video Diffusion от Stability AI и многие другие модели преобразования текста в видео, которые уже появились или появятся в будущем, входят в число самых популярных тенденций ИИ в 2024 году вслед за моделями больших языков (LLM). В этом блоге мы с нуля построим небольшую модель преобразования текста в видео . Мы введем текстовое приглашение, и наша обученная модель сгенерирует видео на основе этого приглашения. В этом блоге будет рассмотрено все: от понимания теоретических концепций до кодирования всей архитектуры и получения конечного результата.
Поскольку у меня нет мощного графического процессора, я написал код мелкомасштабной архитектуры. Вот сравнение времени, необходимого для обучения модели на разных процессорах:
Обучающие видео | Эпохи | Процессор | графический процессор A10 | графический процессор Т4 |
---|---|---|---|---|
10 тыс. | 30 | более 3 часов | 1 час | 1 час 42 минуты |
30 тыс. | 30 | более 6 часов | 1 час 30 | 2 часа 30 |
100 тыс. | 30 | - | 3-4 часа | 5-6 часов |
Очевидно, что обучение модели на процессоре займет гораздо больше времени. Если вам нужно быстро протестировать изменения в коде и увидеть результаты, CPU — не лучший выбор. Я рекомендую использовать графический процессор T4 от Colab или Kaggle для более эффективного и быстрого обучения.
Вот ссылка на блог, которая поможет вам создать Stable Diffusion с нуля: Кодирование Stable Diffusion с нуля.
Мы будем следовать подходу, аналогичному традиционному машинному обучению или моделям глубокого обучения, которые обучаются на наборе данных, а затем тестируются на невидимых данных. В контексте преобразования текста в видео, предположим, у нас есть обучающий набор данных из 100 тысяч видеороликов, на которых собаки ловят мячи, а кошки гоняются за мышами. Мы научим нашу модель генерировать видеоролики, на которых кошка несет мяч или собака гонится за мышью.
Хотя такие наборы обучающих данных легко доступны в Интернете, требуемая вычислительная мощность чрезвычайно высока. Поэтому мы будем работать с набором видеоданных движущихся объектов, сгенерированным из кода Python.
Мы будем использовать архитектуру GAN (генеративно-состязательные сети) для создания нашей модели вместо диффузионной модели, которую использует OpenAI Sora. Я попытался использовать диффузионную модель, но она вышла из строя из-за требований к памяти, что превышает мои возможности. С другой стороны, GAN легче и быстрее обучать и тестировать.
Мы будем использовать ООП (объектно-ориентированное программирование), поэтому вы должны иметь базовое представление об этом, а также о нейронных сетях. Знание GAN (генеративно-состязательных сетей) не является обязательным, поскольку здесь мы рассмотрим их архитектуру.
Тема | Связь |
---|---|
ООП | Ссылка на видео |
Теория нейронных сетей | Ссылка на видео |
ГАН Архитектура | Ссылка на видео |
Основы Python | Ссылка на видео |
Понимание архитектуры GAN важно, потому что от нее зависит большая часть нашей архитектуры. Давайте рассмотрим, что это такое, его компоненты и многое другое.
Генеративно-состязательная сеть (GAN) — это модель глубокого обучения, в которой конкурируют две нейронные сети: одна создает новые данные (например, изображения или музыку) из заданного набора данных, а другая пытается определить, являются ли данные реальными или поддельными. Этот процесс продолжается до тех пор, пока сгенерированные данные не станут неотличимы от оригинала.
Генерация изображений : GAN создают реалистичные изображения из текстовых подсказок или изменяют существующие изображения, например, повышая разрешение или добавляя цвет к черно-белым фотографиям.
Увеличение данных : они генерируют синтетические данные для обучения других моделей машинного обучения, например, для создания данных о мошеннических транзакциях для систем обнаружения мошенничества.
Полная недостающая информация : GAN могут заполнять недостающие данные, например, генерировать изображения недр на основе карт местности для энергетических приложений.
Создание 3D-моделей : они преобразуют 2D-изображения в 3D-модели, что полезно в таких областях, как здравоохранение, для создания реалистичных изображений органов для хирургического планирования.
Он состоит из двух глубоких нейронных сетей: генератора и дискриминатора . Эти сети тренируются вместе в состязательной конфигурации, где одна генерирует новые данные, а другая оценивает, настоящие они или фальшивые.
Вот упрощенный обзор того, как работает GAN:
Анализ обучающего набора : генератор анализирует обучающий набор для определения атрибутов данных, в то время как дискриминатор независимо анализирует те же данные, чтобы изучить его атрибуты.
Модификация данных : генератор добавляет шум (случайные изменения) к некоторым атрибутам данных.
Передача данных : измененные данные затем передаются дискриминатору.
Расчет вероятности : дискриминатор вычисляет вероятность того, что сгенерированные данные взяты из исходного набора данных.
Петля обратной связи : дискриминатор обеспечивает обратную связь с генератором, направляя его на уменьшение случайного шума в следующем цикле.
Состязательное обучение : генератор пытается максимизировать ошибки дискриминатора, в то время как дискриминатор пытается минимизировать свои собственные ошибки. Благодаря множеству итераций обучения обе сети совершенствуются и развиваются.
Состояние равновесия : обучение продолжается до тех пор, пока дискриминатор больше не сможет различать реальные и синтезированные данные, что указывает на то, что генератор успешно научился создавать реалистичные данные. На этом процесс обучения завершен.
изображение из руководства AWS
Давайте объясним модель GAN на примере перевода изображения в изображение, уделяя особое внимание изменению человеческого лица.
Входное изображение : входное изображение представляет собой реальное изображение человеческого лица.
Модификация атрибутов : генератор изменяет атрибуты лица, например добавляя солнцезащитные очки к глазам.
Сгенерированные изображения : генератор создает набор изображений с добавленными солнцезащитными очками.
Задача дискриминатора : Дискриминатор получает смесь реальных изображений (людей в темных очках) и сгенерированных изображений (лиц, на которых были добавлены солнцезащитные очки).
Оценка : Дискриминатор пытается отличить реальные и сгенерированные изображения.
Петля обратной связи : если дискриминатор правильно идентифицирует поддельные изображения, генератор корректирует свои параметры для создания более убедительных изображений. Если генератор успешно обманывает дискриминатор, дискриминатор обновляет свои параметры, чтобы улучшить обнаружение.
Благодаря этому состязательному процессу обе сети постоянно совершенствуются. Генератор становится лучше в создании реалистичных изображений, а дискриминатор — в выявлении подделок до тех пор, пока не будет достигнуто равновесие, когда дискриминатор больше не сможет отличить реальные и сгенерированные изображения. На данный момент ГАН успешно научился производить реалистичные модификации.
Установка необходимых библиотек — это первый шаг в построении нашей модели преобразования текста в видео.
pip install -r requirements.txt
Мы будем работать с рядом библиотек Python. Давайте их импортируем.
# Operating System module for interacting with the operating system
import os
# Module for generating random numbers
import random
# Module for numerical operations
import numpy as np
# OpenCV library for image processing
import cv2
# Python Imaging Library for image processing
from PIL import Image , ImageDraw , ImageFont
# PyTorch library for deep learning
import torch
# Dataset class for creating custom datasets in PyTorch
from torch . utils . data import Dataset
# Module for image transformations
import torchvision . transforms as transforms
# Neural network module in PyTorch
import torch . nn as nn
# Optimization algorithms in PyTorch
import torch . optim as optim
# Function for padding sequences in PyTorch
from torch . nn . utils . rnn import pad_sequence
# Function for saving images in PyTorch
from torchvision . utils import save_image
# Module for plotting graphs and images
import matplotlib . pyplot as plt
# Module for displaying rich content in IPython environments
from IPython . display import clear_output , display , HTML
# Module for encoding and decoding binary data to text
import base64
Теперь, когда мы импортировали все наши библиотеки, следующим шагом будет определение наших обучающих данных, которые мы будем использовать для обучения нашей архитектуры GAN.
Нам нужно иметь как минимум 10 000 видео в качестве обучающих данных. Почему? Ну, потому что я тестировал с меньшими числами, и результаты были очень плохими, практически ничего не видно. Следующий большой вопрос: о чем эти видео? Наш набор обучающих видеоданных состоит из круга, движущегося в разных направлениях с разными движениями. Итак, давайте напишем его и создадим 10 000 видеороликов, чтобы посмотреть, как это выглядит.
# Create a directory named 'training_dataset'
os . makedirs ( 'training_dataset' , exist_ok = True )
# Define the number of videos to generate for the dataset
num_videos = 10000
# Define the number of frames per video (1 Second Video)
frames_per_video = 10
# Define the size of each image in the dataset
img_size = ( 64 , 64 )
# Define the size of the shapes (Circle)
shape_size = 10
после настройки некоторых основных параметров нам нужно определить текстовые подсказки нашего набора обучающих данных, на основе которых будут создаваться обучающие видеоролики.
# Define text prompts and corresponding movements for circles
prompts_and_movements = [
( "circle moving down" , "circle" , "down" ), # Move circle downward
( "circle moving left" , "circle" , "left" ), # Move circle leftward
( "circle moving right" , "circle" , "right" ), # Move circle rightward
( "circle moving diagonally up-right" , "circle" , "diagonal_up_right" ), # Move circle diagonally up-right
( "circle moving diagonally down-left" , "circle" , "diagonal_down_left" ), # Move circle diagonally down-left
( "circle moving diagonally up-left" , "circle" , "diagonal_up_left" ), # Move circle diagonally up-left
( "circle moving diagonally down-right" , "circle" , "diagonal_down_right" ), # Move circle diagonally down-right
( "circle rotating clockwise" , "circle" , "rotate_clockwise" ), # Rotate circle clockwise
( "circle rotating counter-clockwise" , "circle" , "rotate_counter_clockwise" ), # Rotate circle counter-clockwise
( "circle shrinking" , "circle" , "shrink" ), # Shrink circle
( "circle expanding" , "circle" , "expand" ), # Expand circle
( "circle bouncing vertically" , "circle" , "bounce_vertical" ), # Bounce circle vertically
( "circle bouncing horizontally" , "circle" , "bounce_horizontal" ), # Bounce circle horizontally
( "circle zigzagging vertically" , "circle" , "zigzag_vertical" ), # Zigzag circle vertically
( "circle zigzagging horizontally" , "circle" , "zigzag_horizontal" ), # Zigzag circle horizontally
( "circle moving up-left" , "circle" , "up_left" ), # Move circle up-left
( "circle moving down-right" , "circle" , "down_right" ), # Move circle down-right
( "circle moving down-left" , "circle" , "down_left" ), # Move circle down-left
]
С помощью этих подсказок мы определили несколько движений нашего круга. Теперь нам нужно написать несколько математических уравнений, чтобы переместить этот круг на основе подсказок.
# defining function to create image with moving shape
def create_image_with_moving_shape ( size , frame_num , shape , direction ):
# Create a new RGB image with specified size and white background
img = Image . new ( 'RGB' , size , color = ( 255 , 255 , 255 ))
# Create a drawing context for the image
draw = ImageDraw . Draw ( img )
# Calculate the center coordinates of the image
center_x , center_y = size [ 0 ] // 2 , size [ 1 ] // 2
# Initialize position with center for all movements
position = ( center_x , center_y )
# Define a dictionary mapping directions to their respective position adjustments or image transformations
direction_map = {
# Adjust position downwards based on frame number
"down" : ( 0 , frame_num * 5 % size [ 1 ]),
# Adjust position to the left based on frame number
"left" : ( - frame_num * 5 % size [ 0 ], 0 ),
# Adjust position to the right based on frame number
"right" : ( frame_num * 5 % size [ 0 ], 0 ),
# Adjust position diagonally up and to the right
"diagonal_up_right" : ( frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position diagonally down and to the left
"diagonal_down_left" : ( - frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Adjust position diagonally up and to the left
"diagonal_up_left" : ( - frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position diagonally down and to the right
"diagonal_down_right" : ( frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Rotate the image clockwise based on frame number
"rotate_clockwise" : img . rotate ( frame_num * 10 % 360 , center = ( center_x , center_y ), fillcolor = ( 255 , 255 , 255 )),
# Rotate the image counter-clockwise based on frame number
"rotate_counter_clockwise" : img . rotate ( - frame_num * 10 % 360 , center = ( center_x , center_y ), fillcolor = ( 255 , 255 , 255 )),
# Adjust position for a bouncing effect vertically
"bounce_vertical" : ( 0 , center_y - abs ( frame_num * 5 % size [ 1 ] - center_y )),
# Adjust position for a bouncing effect horizontally
"bounce_horizontal" : ( center_x - abs ( frame_num * 5 % size [ 0 ] - center_x ), 0 ),
# Adjust position for a zigzag effect vertically
"zigzag_vertical" : ( 0 , center_y - frame_num * 5 % size [ 1 ]) if frame_num % 2 == 0 else ( 0 , center_y + frame_num * 5 % size [ 1 ]),
# Adjust position for a zigzag effect horizontally
"zigzag_horizontal" : ( center_x - frame_num * 5 % size [ 0 ], center_y ) if frame_num % 2 == 0 else ( center_x + frame_num * 5 % size [ 0 ], center_y ),
# Adjust position upwards and to the right based on frame number
"up_right" : ( frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position upwards and to the left based on frame number
"up_left" : ( - frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position downwards and to the right based on frame number
"down_right" : ( frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Adjust position downwards and to the left based on frame number
"down_left" : ( - frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ])
}
# Check if direction is in the direction map
if direction in direction_map :
# Check if the direction maps to a position adjustment
if isinstance ( direction_map [ direction ], tuple ):
# Update position based on the adjustment
position = tuple ( np . add ( position , direction_map [ direction ]))
else : # If the direction maps to an image transformation
# Update the image based on the transformation
img = direction_map [ direction ]
# Return the image as a numpy array
return np . array ( img )
Функция выше используется для перемещения нашего круга для каждого кадра в зависимости от выбранного направления. Нам просто нужно запустить цикл поверх него до нужного количества видео раз, чтобы сгенерировать все видео.
# Iterate over the number of videos to generate
for i in range ( num_videos ):
# Randomly choose a prompt and movement from the predefined list
prompt , shape , direction = random . choice ( prompts_and_movements )
# Create a directory for the current video
video_dir = f'training_dataset/video_ { i } '
os . makedirs ( video_dir , exist_ok = True )
# Write the chosen prompt to a text file in the video directory
with open ( f' { video_dir } /prompt.txt' , 'w' ) as f :
f . write ( prompt )
# Generate frames for the current video
for frame_num in range ( frames_per_video ):
# Create an image with a moving shape based on the current frame number, shape, and direction
img = create_image_with_moving_shape ( img_size , frame_num , shape , direction )
# Save the generated image as a PNG file in the video directory
cv2 . imwrite ( f' { video_dir } /frame_ { frame_num } .png' , img )
Как только вы запустите приведенный выше код, он сгенерирует весь наш набор обучающих данных. Вот как выглядит структура наших файлов набора обучающих данных.
Каждая папка с обучающими видео содержит свои кадры и текстовую подсказку. Давайте посмотрим на образец нашего набора обучающих данных.
В наш набор обучающих данных мы не включили движение круга вверх, а затем вправо . Мы будем использовать это в качестве подсказки для тестирования, чтобы оценить нашу обученную модель на невидимых данных.
Еще один важный момент, на который следует обратить внимание: наши данные обучения содержат множество примеров, в которых объекты удаляются от сцены или частично появляются перед камерой, подобно тому, что мы наблюдали в демонстрационных видеороликах OpenAI Sora.
Причина включения таких образцов в наши обучающие данные — проверить, может ли наша модель сохранять согласованность, когда круг входит в сцену из самого угла, не нарушая при этом своей формы.
Теперь, когда наши обучающие данные сгенерированы, нам нужно преобразовать обучающие видеоролики в тензоры, которые являются основным типом данных, используемым в средах глубокого обучения, таких как PyTorch. Кроме того, выполнение таких преобразований, как нормализация, помогает улучшить сходимость и стабильность архитектуры обучения за счет масштабирования данных до меньшего диапазона.
Нам нужно написать класс набора данных для задач преобразования текста в видео, который может считывать видеокадры и соответствующие им текстовые подсказки из каталога набора обучающих данных, делая их доступными для использования в PyTorch.
# Define a dataset class inheriting from torch.utils.data.Dataset
class TextToVideoDataset ( Dataset ):
def __init__ ( self , root_dir , transform = None ):
# Initialize the dataset with root directory and optional transform
self . root_dir = root_dir
self . transform = transform
# List all subdirectories in the root directory
self . video_dirs = [ os . path . join ( root_dir , d ) for d in os . listdir ( root_dir ) if os . path . isdir ( os . path . join ( root_dir , d ))]
# Initialize lists to store frame paths and corresponding prompts
self . frame_paths = []
self . prompts = []
# Loop through each video directory
for video_dir in self . video_dirs :
# List all PNG files in the video directory and store their paths
frames = [ os . path . join ( video_dir , f ) for f in os . listdir ( video_dir ) if f . endswith ( '.png' )]
self . frame_paths . extend ( frames )
# Read the prompt text file in the video directory and store its content
with open ( os . path . join ( video_dir , 'prompt.txt' ), 'r' ) as f :
prompt = f . read (). strip ()
# Repeat the prompt for each frame in the video and store in prompts list
self . prompts . extend ([ prompt ] * len ( frames ))
# Return the total number of samples in the dataset
def __len__ ( self ):
return len ( self . frame_paths )
# Retrieve a sample from the dataset given an index
def __getitem__ ( self , idx ):
# Get the path of the frame corresponding to the given index
frame_path = self . frame_paths [ idx ]
# Open the image using PIL (Python Imaging Library)
image = Image . open ( frame_path )
# Get the prompt corresponding to the given index
prompt = self . prompts [ idx ]
# Apply transformation if specified
if self . transform :
image = self . transform ( image )
# Return the transformed image and the prompt
return image , prompt
Прежде чем приступить к кодированию архитектуры, нам необходимо нормализовать наши обучающие данные. Мы будем использовать размер пакета 16 и перетасовать данные, чтобы добавить больше случайности.
# Define a set of transformations to be applied to the data
transform = transforms . Compose ([
transforms . ToTensor (), # Convert PIL Image or numpy.ndarray to tensor
transforms . Normalize (( 0.5 ,), ( 0.5 ,)) # Normalize image with mean and standard deviation
])
# Load the dataset using the defined transform
dataset = TextToVideoDataset ( root_dir = 'training_dataset' , transform = transform )
# Create a dataloader to iterate over the dataset
dataloader = torch . utils . data . DataLoader ( dataset , batch_size = 16 , shuffle = True )
Возможно, вы видели в архитектуре преобразователя, где отправной точкой является преобразование нашего текстового ввода во встраивание для дальнейшей обработки с вниманием нескольких голов, аналогично здесь нам нужно закодировать слой встраивания текста, на основе которого будет происходить обучение архитектуры GAN на наших данных встраивания. и тензор изображений.
# Define a class for text embedding
class TextEmbedding ( nn . Module ):
# Constructor method with vocab_size and embed_size parameters
def __init__ ( self , vocab_size , embed_size ):
# Call the superclass constructor
super ( TextEmbedding , self ). __init__ ()
# Initialize embedding layer
self . embedding = nn . Embedding ( vocab_size , embed_size )
# Define the forward pass method
def forward ( self , x ):
# Return embedded representation of input
return self . embedding ( x )
Размер словарного запаса будет основан на наших данных обучения, которые мы рассчитаем позже. Размер встраивания будет равен 10. Если вы работаете с большим набором данных, вы также можете использовать свой собственный выбор модели встраивания, доступной на Hugging Face.
Теперь, когда мы уже знаем, что делает генератор в GAN, давайте закодируем этот уровень, а затем разберемся с его содержимым.
class Generator ( nn . Module ):
def __init__ ( self , text_embed_size ):
super ( Generator , self ). __init__ ()
# Fully connected layer that takes noise and text embedding as input
self . fc1 = nn . Linear ( 100 + text_embed_size , 256 * 8 * 8 )
# Transposed convolutional layers to upsample the input
self . deconv1 = nn . ConvTranspose2d ( 256 , 128 , 4 , 2 , 1 )
self . deconv2 = nn . ConvTranspose2d ( 128 , 64 , 4 , 2 , 1 )
self . deconv3 = nn . ConvTranspose2d ( 64 , 3 , 4 , 2 , 1 ) # Output has 3 channels for RGB images
# Activation functions
self . relu = nn . ReLU ( True ) # ReLU activation function
self . tanh = nn . Tanh () # Tanh activation function for final output
def forward ( self , noise , text_embed ):
# Concatenate noise and text embedding along the channel dimension
x = torch . cat (( noise , text_embed ), dim = 1 )
# Fully connected layer followed by reshaping to 4D tensor
x = self . fc1 ( x ). view ( - 1 , 256 , 8 , 8 )
# Upsampling through transposed convolution layers with ReLU activation
x = self . relu ( self . deconv1 ( x ))
x = self . relu ( self . deconv2 ( x ))
# Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)
x