Sora from OpenAI, Stable Video Diffusion from Stability AI, and many other text-to-video models that have come out or will appear in the future are among the most popular AI trends in 2024, following large language models (LLMs). In this blog, we will build a small scale text-to-video model from scratch. We will input a text prompt, and our trained model will generate a video based on that prompt. This blog will cover everything from understanding the theoretical concepts to coding the entire architecture and generating the final result.
Since I don’t have a fancy GPU, I’ve coded a small-scale architecture. Here’s a comparison of the time required to train the model on different processors:
Training Videos | Epochs | CPU | GPU A10 | GPU T4 |
---|---|---|---|---|
10K | 30 | more than 3 hr | 1 hr | 1 hr 42 min |
30K | 30 | more than 6 hr | 1 hr 30 min | 2 hr 30 min |
100K | 30 | - | 3-4 hr | 5-6 hr |
Running on a CPU will obviously take much longer to train the model. If you need to quickly test changes in the code and see results, CPU is not the best choice. I recommend using a T4 GPU from Colab or Kaggle for more efficient and faster training.
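If you do have access to a GPU, a quick check like the one below (using PyTorch, which we import properly later) lets the rest of the code run on whatever hardware is available:

# Select the GPU if one is available, otherwise fall back to the CPU
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training will run on: {device}")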
Here is the blog link which guides you on how to create Stable Diffusion from scratch: Coding Stable Diffusion from Scratch
We will follow a similar approach to traditional machine learning or deep learning models that train on a dataset and are then tested on unseen data. In the context of text-to-video, let’s say we have a training dataset of 100K videos of dogs fetching balls and cats chasing mice. We will train our model to generate videos of a cat fetching a ball or a dog chasing a mouse.
Although such training datasets are easily available on the internet, the required computational power is extremely high. Therefore, we will work with a video dataset of moving objects generated from Python code.
We will use a GAN (Generative Adversarial Network) architecture to create our model instead of the diffusion model that OpenAI Sora uses. I attempted to use a diffusion model, but its memory requirements were beyond my hardware’s capacity and it crashed. GANs, on the other hand, are easier and quicker to train and test.
We will be using OOP (Object-Oriented Programming), so you must have a basic understanding of it along with neural networks. Knowledge of GANs (Generative Adversarial Networks) is not mandatory, as we will be covering their architecture here.
Topic | Link |
---|---|
OOP | Video Link |
Neural Networks Theory | Video Link |
GAN Architecture | Video Link |
Python basics | Video Link |
Understanding GAN architecture is important because much of our architecture depends on it. Let’s explore what it is, its components, and more.
Generative Adversarial Network (GAN) is a deep learning model where two neural networks compete: one creates new data (like images or music) from a given dataset, and the other tries to tell if the data is real or fake. This process continues until the generated data is indistinguishable from the original.
GANs have a wide range of applications:
- Generate Images: GANs create realistic images from text prompts or modify existing images, such as enhancing resolution or adding color to black-and-white photos.
- Data Augmentation: They generate synthetic data to train other machine learning models, such as creating fraudulent transaction data for fraud detection systems.
- Complete Missing Information: GANs can fill in missing data, such as generating sub-surface images from terrain maps for energy applications.
- Generate 3D Models: They convert 2D images into 3D models, which is useful in fields like healthcare for creating realistic organ images for surgical planning.
A GAN consists of two deep neural networks: the generator and the discriminator. These networks are trained together in an adversarial setup, where one generates new data and the other evaluates whether that data is real or fake.
Here’s a simplified overview of how a GAN works (a minimal code sketch of this training loop follows below):
1. Training Set Analysis: The generator analyzes the training set to identify data attributes, while the discriminator independently analyzes the same data to learn its attributes.
2. Data Modification: The generator adds noise (random changes) to some attributes of the data.
3. Data Passing: The modified data is then passed to the discriminator.
4. Probability Calculation: The discriminator calculates the probability that the generated data is from the original dataset.
5. Feedback Loop: The discriminator provides feedback to the generator, guiding it to reduce random noise in the next cycle.
6. Adversarial Training: The generator tries to maximize the discriminator’s mistakes, while the discriminator tries to minimize its own errors. Through many training iterations, both networks improve and evolve.
7. Equilibrium State: Training continues until the discriminator can no longer distinguish between real and synthesized data, indicating that the generator has successfully learned to produce realistic data. At this point, the training process is complete.
(Image: GAN training workflow, from the AWS guide on GANs.)
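To make this loop concrete, here is a minimal, self-contained PyTorch sketch of a single adversarial training step. The toy generator and discriminator here are placeholders operating on small vectors, not the models we build later in this blog:

# Toy generator and discriminator operating on small vectors (placeholders, not our final models)
import torch
import torch.nn as nn

generator = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
discriminator = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1), nn.Sigmoid())

criterion = nn.BCELoss()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

real_data = torch.randn(4, 8)     # stand-in for a batch of real samples
real_labels = torch.ones(4, 1)    # label 1 = real
fake_labels = torch.zeros(4, 1)   # label 0 = fake

# Discriminator step: learn to separate real samples from generated ones
d_optimizer.zero_grad()
fake_data = generator(torch.randn(4, 16))
d_loss = criterion(discriminator(real_data), real_labels) + \
         criterion(discriminator(fake_data.detach()), fake_labels)
d_loss.backward()
d_optimizer.step()

# Generator step: try to make the discriminator classify fakes as real
g_optimizer.zero_grad()
g_loss = criterion(discriminator(fake_data), real_labels)
g_loss.backward()
g_optimizer.step()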
Let’s explain the GAN model with an example of image-to-image translation, focusing on modifying a human face.
1. Input Image: The input is a real image of a human face.
2. Attribute Modification: The generator modifies attributes of the face, like adding sunglasses to the eyes.
3. Generated Images: The generator creates a set of images with sunglasses added.
4. Discriminator’s Task: The discriminator receives a mix of real images (people with sunglasses) and generated images (faces where sunglasses were added).
5. Evaluation: The discriminator tries to differentiate between real and generated images.
6. Feedback Loop: If the discriminator correctly identifies fake images, the generator adjusts its parameters to produce more convincing images. If the generator successfully fools the discriminator, the discriminator updates its parameters to improve its detection.
Through this adversarial process, both networks continuously improve. The generator gets better at creating realistic images, and the discriminator gets better at identifying fakes until equilibrium is reached, where the discriminator can no longer tell the difference between real and generated images. At this point, the GAN has successfully learned to produce realistic modifications.
Installing the required libraries is the first step in building our text-to-video model.
pip install -r requirements.txt
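The requirements file itself isn’t shown in this blog, but based on the libraries we import next, it boils down to something like the following (add version pins if you need reproducibility):

numpy
opencv-python
pillow
torch
torchvision
matplotlib
ipython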
We will be working with a range of Python libraries; let’s import them.
# Operating System module for interacting with the operating system
import os
# Module for generating random numbers
import random
# Module for numerical operations
import numpy as np
# OpenCV library for image processing
import cv2
# Python Imaging Library for image processing
from PIL import Image, ImageDraw, ImageFont
# PyTorch library for deep learning
import torch
# Dataset class for creating custom datasets in PyTorch
from torch.utils.data import Dataset
# Module for image transformations
import torchvision.transforms as transforms
# Neural network module in PyTorch
import torch.nn as nn
# Optimization algorithms in PyTorch
import torch.optim as optim
# Function for padding sequences in PyTorch
from torch.nn.utils.rnn import pad_sequence
# Function for saving images in PyTorch
from torchvision.utils import save_image
# Module for plotting graphs and images
import matplotlib.pyplot as plt
# Module for displaying rich content in IPython environments
from IPython.display import clear_output, display, HTML
# Module for encoding and decoding binary data to text
import base64
Now that we’ve imported all of our libraries, the next step is to define our training data that we will be using to train our GAN architecture.
We need to have at least 10,000 videos as training data. Why? Well, because I tested with smaller numbers and the results were very poor, practically nothing to see. The next big question is: what are these videos about? Our training video dataset consists of a circle moving in different directions with different motions. So, let’s code it and generate 10,000 videos to see what it looks like.
# Create a directory named 'training_dataset'
os.makedirs('training_dataset', exist_ok=True)
# Define the number of videos to generate for the dataset
num_videos = 10000
# Define the number of frames per video (1 Second Video)
frames_per_video = 10
# Define the size of each image in the dataset
img_size = (64, 64)
# Define the size of the shapes (Circle)
shape_size = 10
After setting some basic parameters, we next need to define the text prompts of our training dataset, based on which the training videos will be generated.
# Define text prompts and corresponding movements for circles
prompts_and_movements = [
("circle moving down", "circle", "down"), # Move circle downward
("circle moving left", "circle", "left"), # Move circle leftward
("circle moving right", "circle", "right"), # Move circle rightward
("circle moving diagonally up-right", "circle", "diagonal_up_right"), # Move circle diagonally up-right
("circle moving diagonally down-left", "circle", "diagonal_down_left"), # Move circle diagonally down-left
("circle moving diagonally up-left", "circle", "diagonal_up_left"), # Move circle diagonally up-left
("circle moving diagonally down-right", "circle", "diagonal_down_right"), # Move circle diagonally down-right
("circle rotating clockwise", "circle", "rotate_clockwise"), # Rotate circle clockwise
("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"), # Rotate circle counter-clockwise
("circle shrinking", "circle", "shrink"), # Shrink circle
("circle expanding", "circle", "expand"), # Expand circle
("circle bouncing vertically", "circle", "bounce_vertical"), # Bounce circle vertically
("circle bouncing horizontally", "circle", "bounce_horizontal"), # Bounce circle horizontally
("circle zigzagging vertically", "circle", "zigzag_vertical"), # Zigzag circle vertically
("circle zigzagging horizontally", "circle", "zigzag_horizontal"), # Zigzag circle horizontally
("circle moving up-left", "circle", "up_left"), # Move circle up-left
("circle moving down-right", "circle", "down_right"), # Move circle down-right
("circle moving down-left", "circle", "down_left"), # Move circle down-left
]
We’ve defined several movements of our circle using these prompts. Now, we need to code some mathematical equations to move that circle based on the prompts.
# Define a function to create an image with a moving shape
def create_image_with_moving_shape(size, frame_num, shape, direction):
    # Create a new RGB image with specified size and white background
    img = Image.new('RGB', size, color=(255, 255, 255))
    # Create a drawing context for the image
    draw = ImageDraw.Draw(img)
    # Calculate the center coordinates of the image
    center_x, center_y = size[0] // 2, size[1] // 2
    # Initialize position with center for all movements
    position = (center_x, center_y)

    # Define a dictionary mapping directions to their respective position adjustments or image transformations
    direction_map = {
        # Adjust position downwards based on frame number
        "down": (0, frame_num * 5 % size[1]),
        # Adjust position to the left based on frame number
        "left": (-frame_num * 5 % size[0], 0),
        # Adjust position to the right based on frame number
        "right": (frame_num * 5 % size[0], 0),
        # Adjust position diagonally up and to the right
        "diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        # Adjust position diagonally down and to the left
        "diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]),
        # Adjust position diagonally up and to the left
        "diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        # Adjust position diagonally down and to the right
        "diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),
        # Rotate the image clockwise based on frame number
        "rotate_clockwise": img.rotate(frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),
        # Rotate the image counter-clockwise based on frame number
        "rotate_counter_clockwise": img.rotate(-frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),
        # Adjust position for a bouncing effect vertically
        "bounce_vertical": (0, center_y - abs(frame_num * 5 % size[1] - center_y)),
        # Adjust position for a bouncing effect horizontally
        "bounce_horizontal": (center_x - abs(frame_num * 5 % size[0] - center_x), 0),
        # Adjust position for a zigzag effect vertically
        "zigzag_vertical": (0, center_y - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, center_y + frame_num * 5 % size[1]),
        # Adjust position for a zigzag effect horizontally
        "zigzag_horizontal": (center_x - frame_num * 5 % size[0], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0], center_y),
        # Adjust position upwards and to the right based on frame number
        "up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        # Adjust position upwards and to the left based on frame number
        "up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        # Adjust position downwards and to the right based on frame number
        "down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),
        # Adjust position downwards and to the left based on frame number
        "down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1])
    }

    # Check if direction is in the direction map
    if direction in direction_map:
        # Check if the direction maps to a position adjustment
        if isinstance(direction_map[direction], tuple):
            # Update position based on the adjustment
            position = tuple(np.add(position, direction_map[direction]))
        else:  # If the direction maps to an image transformation
            # Update the image based on the transformation and refresh the drawing context
            img = direction_map[direction]
            draw = ImageDraw.Draw(img)

    # Draw the circle (using the globally defined shape_size) at the calculated position,
    # so that every frame actually contains the shape
    if shape == "circle":
        draw.ellipse([position[0] - shape_size // 2, position[1] - shape_size // 2,
                      position[0] + shape_size // 2, position[1] + shape_size // 2],
                     fill=(0, 0, 255))

    # Return the image as a numpy array
    return np.array(img)
The function above draws our circle in each frame, positioned according to the selected direction. Now we just need to loop over it, once per video, to generate the entire dataset.
# Iterate over the number of videos to generate
for i in range(num_videos):
    # Randomly choose a prompt and movement from the predefined list
    prompt, shape, direction = random.choice(prompts_and_movements)
    # Create a directory for the current video
    video_dir = f'training_dataset/video_{i}'
    os.makedirs(video_dir, exist_ok=True)
    # Write the chosen prompt to a text file in the video directory
    with open(f'{video_dir}/prompt.txt', 'w') as f:
        f.write(prompt)
    # Generate frames for the current video
    for frame_num in range(frames_per_video):
        # Create an image with a moving shape based on the current frame number, shape, and direction
        img = create_image_with_moving_shape(img_size, frame_num, shape, direction)
        # Save the generated image as a PNG file in the video directory
        cv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)
Once you run the above code, it will generate our entire training dataset. Here is what the structure of our training dataset files looks like.
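Since each video gets its own folder containing a prompt file and ten frames, the generated directory tree looks roughly like this:

training_dataset/
├── video_0/
│   ├── frame_0.png
│   ├── ...
│   ├── frame_9.png
│   └── prompt.txt
├── video_1/
│   └── ...
└── video_9999/
    └── ...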
Each training video folder contains its frames along with its text prompt. Let’s take a look at the sample of our training dataset.
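If you want to inspect a sample yourself, you can load the frames of any video folder and display them with matplotlib (a small inspection helper, assuming the dataset was generated as above):

# Display the 10 frames of one training video side by side, titled with its prompt
sample_dir = 'training_dataset/video_0'
with open(f'{sample_dir}/prompt.txt', 'r') as f:
    sample_prompt = f.read().strip()
fig, axes = plt.subplots(1, frames_per_video, figsize=(20, 2))
for frame_num, ax in enumerate(axes):
    # Read each saved frame with PIL, exactly as our dataset class will do later
    ax.imshow(Image.open(f'{sample_dir}/frame_{frame_num}.png'))
    ax.axis('off')
fig.suptitle(sample_prompt)
plt.show()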
In our training dataset, we haven’t included the prompt for the circle moving up-right (the up_right entry in the direction map). We will use it as our testing prompt to evaluate our trained model on unseen data.
One more important point to note is that our training data contains many samples where the object moves out of the scene or appears only partially in the frame, similar to what we have observed in the OpenAI Sora demo videos.
The reason for including such samples in our training data is to test whether our model can maintain consistency when the circle enters the scene from the very corner without breaking its shape.
Now that our training data has been generated, we need to convert the training videos to tensors, which are the primary data type used in deep learning frameworks like PyTorch. Additionally, performing transformations like normalization helps improve the convergence and stability of the training architecture by scaling the data to a smaller range.
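For intuition on the normalization we apply below: ToTensor first scales pixel values to [0, 1], and Normalize((0.5,), (0.5,)) then computes (x - 0.5) / 0.5, mapping them to [-1, 1] — the same range the generator’s final Tanh activation will produce:

# The normalization maps pixel values from [0, 1] to [-1, 1]
x = torch.tensor([0.0, 0.5, 1.0])
print((x - 0.5) / 0.5)  # tensor([-1.,  0.,  1.])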
We have to code a dataset class for text-to-video tasks, which can read video frames and their corresponding text prompts from the training dataset directory, making them available for use in PyTorch.
# Define a dataset class inheriting from torch.utils.data.Dataset
class TextToVideoDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        # Initialize the dataset with root directory and optional transform
        self.root_dir = root_dir
        self.transform = transform
        # List all subdirectories in the root directory
        self.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        # Initialize lists to store frame paths and corresponding prompts
        self.frame_paths = []
        self.prompts = []

        # Loop through each video directory
        for video_dir in self.video_dirs:
            # List all PNG files in the video directory and store their paths
            frames = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')]
            self.frame_paths.extend(frames)
            # Read the prompt text file in the video directory and store its content
            with open(os.path.join(video_dir, 'prompt.txt'), 'r') as f:
                prompt = f.read().strip()
            # Repeat the prompt for each frame in the video and store in prompts list
            self.prompts.extend([prompt] * len(frames))

    # Return the total number of samples in the dataset
    def __len__(self):
        return len(self.frame_paths)

    # Retrieve a sample from the dataset given an index
    def __getitem__(self, idx):
        # Get the path of the frame corresponding to the given index
        frame_path = self.frame_paths[idx]
        # Open the image using PIL (Python Imaging Library)
        image = Image.open(frame_path)
        # Get the prompt corresponding to the given index
        prompt = self.prompts[idx]
        # Apply transformation if specified
        if self.transform:
            image = self.transform(image)
        # Return the transformed image and the prompt
        return image, prompt
Before proceeding to code the architecture, we need to normalize our training data. We will use a batch size of 16 and shuffle the data to introduce more randomness.
# Define a set of transformations to be applied to the data
transform = transforms.Compose([
    transforms.ToTensor(),                # Convert PIL Image or numpy.ndarray to tensor
    transforms.Normalize((0.5,), (0.5,))  # Normalize image with mean and standard deviation
])
# Load the dataset using the defined transform
dataset = TextToVideoDataset(root_dir='training_dataset', transform=transform)
# Create a dataloader to iterate over the dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
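Before moving on, it’s worth pulling a single batch from the dataloader to confirm the shapes are what we expect (a quick sanity check only):

# Fetch one batch: 16 frames of shape (3, 64, 64) plus their 16 prompt strings
images, prompts = next(iter(dataloader))
print(images.shape)  # torch.Size([16, 3, 64, 64])
print(prompts[:2])   # two of the prompt strings in this batch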
You may have seen that in the transformer architecture, the starting point is to convert the text input into embeddings for further processing in multi-head attention. Similarly, here we have to code a text embedding layer, and the GAN architecture will be trained on these text embeddings together with the image tensors.
# Define a class for text embedding
class TextEmbedding(nn.Module):
    # Constructor method with vocab_size and embed_size parameters
    def __init__(self, vocab_size, embed_size):
        # Call the superclass constructor
        super(TextEmbedding, self).__init__()
        # Initialize embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size)

    # Define the forward pass method
    def forward(self, x):
        # Return embedded representation of input
        return self.embedding(x)
The vocabulary size will be based on our training data, which we will calculate later. The embedding size will be 10. If working with a larger dataset, you can also use your own choice of embedding model available on Hugging Face.
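We compute the actual vocabulary later from the training prompts, but conceptually it is just a word-to-index mapping over every word that appears in them. Here is a minimal sketch of that idea (variable names such as encode_text are illustrative, not necessarily the ones used later):

# Build a simple word-level vocabulary from the defined prompts
all_prompts = [prompt for prompt, _, _ in prompts_and_movements]
vocab = {word: idx for idx, word in enumerate(sorted({w for p in all_prompts for w in p.split()}))}
vocab_size = len(vocab)
embed_size = 10

def encode_text(prompt):
    # Convert a prompt string into a tensor of word indices for the embedding layer
    return torch.tensor([vocab[word] for word in prompt.split()])

text_embedding = TextEmbedding(vocab_size, embed_size)
print(text_embedding(encode_text("circle moving down")).shape)  # torch.Size([3, 10])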
Now that we already know what the generator does in GANs, let’s code this layer and then understand its contents.
class Generator(nn.Module):
    def __init__(self, text_embed_size):
        super(Generator, self).__init__()
        # Fully connected layer that takes noise and text embedding as input
        self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8)
        # Transposed convolutional layers to upsample the input
        self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1)  # Output has 3 channels for RGB images
        # Activation functions
        self.relu = nn.ReLU(True)  # ReLU activation function
        self.tanh = nn.Tanh()      # Tanh activation function for final output

    def forward(self, noise, text_embed):
        # Concatenate noise and text embedding along the channel dimension
        x = torch.cat((noise, text_embed), dim=1)
        # Fully connected layer followed by reshaping to a 4D tensor
        x = self.fc1(x).view(-1, 256, 8, 8)
        # Upsampling through transposed convolution layers with ReLU activation
        x = self.relu(self.deconv1(x))
        x = self.relu(self.deconv2(x))
        # Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)
        x = self.tanh(self.deconv3(x))
        return x
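A quick way to verify the upsampling path (8 → 16 → 32 → 64) is to push a dummy noise vector and text embedding through the generator and check the output shape (a throwaway check, assuming an embedding size of 10 as above):

# Instantiate the generator and run a single dummy forward pass
gen = Generator(text_embed_size=10)
dummy_noise = torch.randn(1, 100)  # latent noise vector
dummy_text = torch.randn(1, 10)    # stand-in for a text embedding
print(gen(dummy_noise, dummy_text).shape)  # torch.Size([1, 3, 64, 64])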