Sora จาก OpenAI, Stable Video Diffusion จาก Stability AI และโมเดลข้อความเป็นวิดีโออื่นๆ อีกมากมายที่จะเปิดตัวหรือจะปรากฏในอนาคต ถือเป็นเทรนด์ AI ที่ได้รับความนิยมมากที่สุดในปี 2024 ตามโมเดลภาษาขนาดใหญ่ (LLM) ในบล็อกนี้ เราจะสร้าง โมเดลข้อความเป็นวิดีโอขนาดเล็กตั้งแต่เริ่มต้น เราจะป้อนข้อความแจ้ง และโมเดลที่ผ่านการฝึกอบรมของเราจะสร้างวิดีโอตามข้อความแจ้งนั้น บล็อกนี้จะครอบคลุมทุกอย่างตั้งแต่การทำความเข้าใจแนวคิดทางทฤษฎีไปจนถึงการเขียนโค้ดสถาปัตยกรรมทั้งหมดและสร้างผลลัพธ์สุดท้าย
เนื่องจากฉันไม่มี GPU ที่หรูหรา ฉันจึงเขียนโค้ดสถาปัตยกรรมขนาดเล็ก ต่อไปนี้คือการเปรียบเทียบเวลาที่ต้องใช้ในการฝึกโมเดลบนโปรเซสเซอร์ที่แตกต่างกัน:
วิดีโอการฝึกอบรม | ยุค | ซีพียู | จีพียู A10 | จีพียู T4 |
---|---|---|---|---|
10ก | 30 | มากกว่า 3 ชม | 1 ชม | 1 ชม. 42น |
30ก | 30 | มากกว่า 6 ชม | 1 ชม. 30 | 2 ชม. 30 |
100K | 30 | - | 3-4 ชม | 5-6 ชม |
การรันบน CPU จะใช้เวลาในการฝึกโมเดลนานกว่ามากอย่างเห็นได้ชัด หากคุณต้องการทดสอบการเปลี่ยนแปลงโค้ดอย่างรวดเร็วและดูผลลัพธ์ CPU ไม่ใช่ตัวเลือกที่ดีที่สุด ฉันแนะนำให้ใช้ T4 GPU จาก Colab หรือ Kaggle เพื่อการฝึกที่มีประสิทธิภาพและรวดเร็วยิ่งขึ้น
นี่คือลิงค์บล็อกที่จะแนะนำคุณเกี่ยวกับวิธีการสร้าง Stable Diffusion ตั้งแต่เริ่มต้น: การเข้ารหัส Stable Diffusion ตั้งแต่เริ่มต้น
เราจะปฏิบัติตามแนวทางที่คล้ายกันกับการเรียนรู้ของเครื่องแบบดั้งเดิมหรือโมเดลการเรียนรู้เชิงลึกที่ฝึกฝนบนชุดข้อมูลแล้วทดสอบกับข้อมูลที่มองไม่เห็น ในบริบทของการแปลงข้อความเป็นวิดีโอ สมมติว่าเรามีชุดข้อมูลการฝึกอบรมที่มีวิดีโอจำนวน 100,000 รายการเกี่ยวกับสุนัขหยิบลูกบอลและแมวไล่หนู เราจะฝึกโมเดลของเราเพื่อสร้างวิดีโอแมวหยิบลูกบอลหรือสุนัขไล่หนู
แม้ว่าชุดข้อมูลการฝึกอบรมดังกล่าวจะเข้าถึงได้ง่ายบนอินเทอร์เน็ต แต่พลังในการคำนวณที่ต้องการนั้นสูงมาก ดังนั้น เราจะทำงานร่วมกับชุดข้อมูลวิดีโอของวัตถุที่กำลังเคลื่อนที่ซึ่งสร้างจากโค้ด Python
เราจะใช้สถาปัตยกรรม GAN (Generative Adversarial Networks) เพื่อสร้างโมเดลของเราแทนโมเดลการแพร่กระจายที่ OpenAI Sora ใช้ ฉันพยายามใช้โมเดลการแพร่กระจาย แต่มันล้มเหลวเนื่องจากความต้องการหน่วยความจำ ซึ่งเกินความสามารถของฉัน ในทางกลับกัน GAN นั้นฝึกและทดสอบได้ง่ายกว่าและเร็วกว่า
เราจะใช้ OOP (Object-Oriented Programming) ดังนั้นคุณต้องมีความเข้าใจพื้นฐานเกี่ยวกับมันควบคู่ไปกับโครงข่ายประสาทเทียม ความรู้เกี่ยวกับ GAN (Generative Adversarial Networks) ไม่จำเป็น เนื่องจากเราจะกล่าวถึงสถาปัตยกรรมของพวกเขาที่นี่
หัวข้อ | ลิงค์ |
---|---|
อุ๊ย | ลิงค์วิดีโอ |
ทฤษฎีโครงข่ายประสาทเทียม | ลิงค์วิดีโอ |
สถาปัตยกรรม GAN | ลิงค์วิดีโอ |
ข้อมูลพื้นฐานเกี่ยวกับหลาม | ลิงค์วิดีโอ |
การทำความเข้าใจสถาปัตยกรรม GAN มีความสำคัญเนื่องจากสถาปัตยกรรมส่วนใหญ่ของเราขึ้นอยู่กับสถาปัตยกรรมนั้น เรามาสำรวจว่ามันคืออะไร ส่วนประกอบ และอื่นๆ อีกมากมาย
Generative Adversarial Network (GAN) เป็นโมเดลการเรียนรู้เชิงลึกที่โครงข่ายประสาทสองเครือข่ายแข่งขันกัน โดยโครงข่ายหนึ่งสร้างข้อมูลใหม่ (เช่น รูปภาพหรือเพลง) จากชุดข้อมูลที่กำหนด และอีกโครงหนึ่งพยายามบอกว่าข้อมูลนั้นเป็นของจริงหรือของปลอม กระบวนการนี้จะดำเนินต่อไปจนกว่าข้อมูลที่สร้างขึ้นจะแยกไม่ออกจากต้นฉบับ
สร้างรูปภาพ : GAN สร้างภาพที่เหมือนจริงจากข้อความแจ้งหรือแก้ไขรูปภาพที่มีอยู่ เช่น การเพิ่มความละเอียดหรือการเพิ่มสีให้กับภาพถ่ายขาวดำ
การเพิ่มข้อมูล : สร้างข้อมูลสังเคราะห์เพื่อฝึกโมเดลการเรียนรู้ของเครื่องอื่นๆ เช่น การสร้างข้อมูลธุรกรรมที่ฉ้อโกงสำหรับระบบตรวจจับการฉ้อโกง
กรอกข้อมูลที่ขาดหายไป : GAN สามารถกรอกข้อมูลที่ขาดหายไป เช่น การสร้างภาพพื้นผิวย่อยจากแผนที่ภูมิประเทศเพื่อการประยุกต์ใช้พลังงาน
สร้างโมเดล 3 มิติ : แปลงรูปภาพ 2 มิติเป็นโมเดล 3 มิติ ซึ่งมีประโยชน์ในด้านต่างๆ เช่น การดูแลสุขภาพ เพื่อสร้างภาพอวัยวะที่สมจริงสำหรับการวางแผนการผ่าตัด
ประกอบด้วยโครงข่ายประสาทเทียมเชิงลึกสองเครือข่าย: เครื่องกำเนิด และ ตัวแบ่งแยก เครือข่ายเหล่านี้ฝึกฝนร่วมกันในการตั้งค่าที่ไม่ตรงกัน โดยเครือข่ายหนึ่งจะสร้างข้อมูลใหม่และอีกเครือข่ายหนึ่งจะประเมินว่าข้อมูลนั้นเป็นของจริงหรือของปลอม
ต่อไปนี้เป็นภาพรวมอย่างง่ายเกี่ยวกับวิธีการทำงานของ GAN:
การวิเคราะห์ชุดการฝึกอบรม : เครื่องกำเนิดไฟฟ้าจะวิเคราะห์ชุดการฝึกอบรมเพื่อระบุคุณลักษณะของข้อมูล ในขณะที่ผู้เลือกปฏิบัติจะวิเคราะห์ข้อมูลเดียวกันอย่างอิสระเพื่อเรียนรู้คุณลักษณะของมัน
การปรับเปลี่ยนข้อมูล : ตัวสร้างจะเพิ่มสัญญาณรบกวน (การเปลี่ยนแปลงแบบสุ่ม) ให้กับคุณลักษณะบางอย่างของข้อมูล
การส่งผ่านข้อมูล : ข้อมูลที่แก้ไขจะถูกส่งไปยังผู้แยกแยะ
การคำนวณความน่าจะเป็น : ตัวแยกแยะจะคำนวณความน่าจะเป็นที่ข้อมูลที่สร้างขึ้นมาจากชุดข้อมูลดั้งเดิม
Feedback Loop : ตัวแบ่งแยกจะให้ผลป้อนกลับไปยังเครื่องกำเนิดไฟฟ้า เพื่อเป็นแนวทางในการลดสัญญาณรบกวนแบบสุ่มในรอบถัดไป
การฝึกอบรมฝ่ายตรงข้าม : เครื่องกำเนิดไฟฟ้าพยายามเพิ่มข้อผิดพลาดของผู้เลือกปฏิบัติให้สูงสุด ในขณะที่ผู้เลือกปฏิบัติพยายามลดข้อผิดพลาดของตนเองให้เหลือน้อยที่สุด ด้วยการฝึกอบรมซ้ำหลายครั้ง ทั้งสองเครือข่ายได้รับการปรับปรุงและพัฒนา
สภาวะสมดุล : การฝึกอบรมดำเนินต่อไปจนกว่าผู้แยกแยะจะไม่สามารถแยกแยะระหว่างข้อมูลจริงและข้อมูลสังเคราะห์ได้อีกต่อไป ซึ่งบ่งชี้ว่าเครื่องกำเนิดได้เรียนรู้ที่จะสร้างข้อมูลที่เป็นจริงได้สำเร็จ เมื่อถึงจุดนี้ กระบวนการฝึกอบรมก็เสร็จสมบูรณ์
ภาพจากคู่มือ aws
เรามาอธิบายโมเดล GAN ด้วยตัวอย่างการแปลจากภาพเป็นภาพโดยเน้นไปที่การปรับเปลี่ยนใบหน้าของมนุษย์
ภาพที่นำเข้า : ข้อมูลที่นำเข้าเป็นภาพจริงของใบหน้ามนุษย์
การปรับเปลี่ยนคุณลักษณะ : ตัวสร้างจะปรับเปลี่ยนคุณลักษณะของใบหน้า เช่น การใส่แว่นกันแดดที่ดวงตา
รูปภาพที่สร้าง : ตัวสร้างจะสร้างชุดรูปภาพโดยเพิ่มแว่นกันแดด
ภารกิจของผู้เลือกปฏิบัติ : ผู้เลือกปฏิบัติจะได้รับทั้งภาพจริง (คนที่สวมแว่นกันแดด) และภาพที่สร้างขึ้น (ใบหน้าที่มีการเพิ่มแว่นกันแดด)
การประเมิน : ผู้เลือกปฏิบัติพยายามแยกความแตกต่างระหว่างภาพจริงและภาพที่สร้างขึ้น
Feedback Loop : หากผู้แยกแยะระบุภาพปลอมได้อย่างถูกต้อง ตัวสร้างจะปรับพารามิเตอร์เพื่อสร้างภาพที่น่าเชื่อถือมากขึ้น หากตัวสร้างสามารถหลอกผู้เลือกปฏิบัติได้สำเร็จ ผู้เลือกปฏิบัติจะอัปเดตพารามิเตอร์เพื่อปรับปรุงการตรวจจับ
ด้วยกระบวนการปฏิปักษ์นี้ ทั้งสองเครือข่ายจึงปรับปรุงอย่างต่อเนื่อง เครื่องกำเนิดจะดีขึ้นในการสร้างภาพที่สมจริง และผู้แยกแยะก็จะดีขึ้นในการระบุของปลอมจนกว่าจะถึงจุดสมดุล โดยที่ผู้แยกแยะไม่สามารถบอกความแตกต่างระหว่างภาพจริงและภาพที่สร้างขึ้นได้อีกต่อไป ณ จุดนี้ GAN ได้เรียนรู้ที่จะสร้างการดัดแปลงที่สมจริงได้สำเร็จ
การติดตั้งไลบรารีที่จำเป็นเป็นขั้นตอนแรกในการสร้างโมเดลข้อความเป็นวิดีโอของเรา
pip install -r requirements.txt
เราจะทำงานร่วมกับไลบรารี Python ต่างๆ มานำเข้ากันเถอะ
# Operating System module for interacting with the operating system
import os
# Module for generating random numbers
import random
# Module for numerical operations
import numpy as np
# OpenCV library for image processing
import cv2
# Python Imaging Library for image processing
from PIL import Image , ImageDraw , ImageFont
# PyTorch library for deep learning
import torch
# Dataset class for creating custom datasets in PyTorch
from torch . utils . data import Dataset
# Module for image transformations
import torchvision . transforms as transforms
# Neural network module in PyTorch
import torch . nn as nn
# Optimization algorithms in PyTorch
import torch . optim as optim
# Function for padding sequences in PyTorch
from torch . nn . utils . rnn import pad_sequence
# Function for saving images in PyTorch
from torchvision . utils import save_image
# Module for plotting graphs and images
import matplotlib . pyplot as plt
# Module for displaying rich content in IPython environments
from IPython . display import clear_output , display , HTML
# Module for encoding and decoding binary data to text
import base64
ตอนนี้เราได้นำเข้าไลบรารีทั้งหมดของเราแล้ว ขั้นตอนต่อไปคือการกำหนดข้อมูลการฝึกอบรมที่เราจะใช้ในการฝึกอบรมสถาปัตยกรรม GAN ของเรา
เราจำเป็นต้องมีวิดีโออย่างน้อย 10,000 รายการเป็นข้อมูลการฝึกอบรม ทำไม เนื่องจากฉันทดสอบด้วยตัวเลขที่น้อยกว่าและผลลัพธ์ที่ได้แย่มาก แทบจะไม่มีอะไรให้เห็นเลย คำถามสำคัญต่อไปคือ วิดีโอเหล่านี้เกี่ยวกับอะไร ชุดข้อมูลวิดีโอการฝึกอบรมของเราประกอบด้วยวงกลมที่เคลื่อนที่ไปในทิศทางที่ต่างกันและมีการเคลื่อนไหวที่ต่างกัน เรามาเขียนโค้ดและสร้างวิดีโอ 10,000 รายการเพื่อดูว่าหน้าตาเป็นอย่างไร
# Create a directory named 'training_dataset'
os . makedirs ( 'training_dataset' , exist_ok = True )
# Define the number of videos to generate for the dataset
num_videos = 10000
# Define the number of frames per video (1 Second Video)
frames_per_video = 10
# Define the size of each image in the dataset
img_size = ( 64 , 64 )
# Define the size of the shapes (Circle)
shape_size = 10
หลังจากตั้งค่าพารามิเตอร์พื้นฐานแล้ว ต่อไปเราต้องกำหนดข้อความแจ้งของชุดข้อมูลการฝึกอบรมของเราตามวิดีโอการฝึกอบรมที่จะถูกสร้างขึ้น
# Define text prompts and corresponding movements for circles
prompts_and_movements = [
( "circle moving down" , "circle" , "down" ), # Move circle downward
( "circle moving left" , "circle" , "left" ), # Move circle leftward
( "circle moving right" , "circle" , "right" ), # Move circle rightward
( "circle moving diagonally up-right" , "circle" , "diagonal_up_right" ), # Move circle diagonally up-right
( "circle moving diagonally down-left" , "circle" , "diagonal_down_left" ), # Move circle diagonally down-left
( "circle moving diagonally up-left" , "circle" , "diagonal_up_left" ), # Move circle diagonally up-left
( "circle moving diagonally down-right" , "circle" , "diagonal_down_right" ), # Move circle diagonally down-right
( "circle rotating clockwise" , "circle" , "rotate_clockwise" ), # Rotate circle clockwise
( "circle rotating counter-clockwise" , "circle" , "rotate_counter_clockwise" ), # Rotate circle counter-clockwise
( "circle shrinking" , "circle" , "shrink" ), # Shrink circle
( "circle expanding" , "circle" , "expand" ), # Expand circle
( "circle bouncing vertically" , "circle" , "bounce_vertical" ), # Bounce circle vertically
( "circle bouncing horizontally" , "circle" , "bounce_horizontal" ), # Bounce circle horizontally
( "circle zigzagging vertically" , "circle" , "zigzag_vertical" ), # Zigzag circle vertically
( "circle zigzagging horizontally" , "circle" , "zigzag_horizontal" ), # Zigzag circle horizontally
( "circle moving up-left" , "circle" , "up_left" ), # Move circle up-left
( "circle moving down-right" , "circle" , "down_right" ), # Move circle down-right
( "circle moving down-left" , "circle" , "down_left" ), # Move circle down-left
]
เราได้กำหนดการเคลื่อนไหวหลายอย่างในวงกลมของเราโดยใช้คำแนะนำเหล่านี้ ตอนนี้ เราจำเป็นต้องเขียนโค้ดสมการทางคณิตศาสตร์บางอย่างเพื่อย้ายวงกลมนั้นตามคำแนะนำ
# defining function to create image with moving shape
def create_image_with_moving_shape ( size , frame_num , shape , direction ):
# Create a new RGB image with specified size and white background
img = Image . new ( 'RGB' , size , color = ( 255 , 255 , 255 ))
# Create a drawing context for the image
draw = ImageDraw . Draw ( img )
# Calculate the center coordinates of the image
center_x , center_y = size [ 0 ] // 2 , size [ 1 ] // 2
# Initialize position with center for all movements
position = ( center_x , center_y )
# Define a dictionary mapping directions to their respective position adjustments or image transformations
direction_map = {
# Adjust position downwards based on frame number
"down" : ( 0 , frame_num * 5 % size [ 1 ]),
# Adjust position to the left based on frame number
"left" : ( - frame_num * 5 % size [ 0 ], 0 ),
# Adjust position to the right based on frame number
"right" : ( frame_num * 5 % size [ 0 ], 0 ),
# Adjust position diagonally up and to the right
"diagonal_up_right" : ( frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position diagonally down and to the left
"diagonal_down_left" : ( - frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Adjust position diagonally up and to the left
"diagonal_up_left" : ( - frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position diagonally down and to the right
"diagonal_down_right" : ( frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Rotate the image clockwise based on frame number
"rotate_clockwise" : img . rotate ( frame_num * 10 % 360 , center = ( center_x , center_y ), fillcolor = ( 255 , 255 , 255 )),
# Rotate the image counter-clockwise based on frame number
"rotate_counter_clockwise" : img . rotate ( - frame_num * 10 % 360 , center = ( center_x , center_y ), fillcolor = ( 255 , 255 , 255 )),
# Adjust position for a bouncing effect vertically
"bounce_vertical" : ( 0 , center_y - abs ( frame_num * 5 % size [ 1 ] - center_y )),
# Adjust position for a bouncing effect horizontally
"bounce_horizontal" : ( center_x - abs ( frame_num * 5 % size [ 0 ] - center_x ), 0 ),
# Adjust position for a zigzag effect vertically
"zigzag_vertical" : ( 0 , center_y - frame_num * 5 % size [ 1 ]) if frame_num % 2 == 0 else ( 0 , center_y + frame_num * 5 % size [ 1 ]),
# Adjust position for a zigzag effect horizontally
"zigzag_horizontal" : ( center_x - frame_num * 5 % size [ 0 ], center_y ) if frame_num % 2 == 0 else ( center_x + frame_num * 5 % size [ 0 ], center_y ),
# Adjust position upwards and to the right based on frame number
"up_right" : ( frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position upwards and to the left based on frame number
"up_left" : ( - frame_num * 5 % size [ 0 ], - frame_num * 5 % size [ 1 ]),
# Adjust position downwards and to the right based on frame number
"down_right" : ( frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ]),
# Adjust position downwards and to the left based on frame number
"down_left" : ( - frame_num * 5 % size [ 0 ], frame_num * 5 % size [ 1 ])
}
# Check if direction is in the direction map
if direction in direction_map :
# Check if the direction maps to a position adjustment
if isinstance ( direction_map [ direction ], tuple ):
# Update position based on the adjustment
position = tuple ( np . add ( position , direction_map [ direction ]))
else : # If the direction maps to an image transformation
# Update the image based on the transformation
img = direction_map [ direction ]
# Return the image as a numpy array
return np . array ( img )
ฟังก์ชั่นด้านบนใช้เพื่อย้ายวงกลมของเราสำหรับแต่ละเฟรมตามทิศทางที่เลือก เราแค่ต้องวนซ้ำตามจำนวนวิดีโอครั้งเพื่อสร้างวิดีโอทั้งหมด
# Iterate over the number of videos to generate
for i in range ( num_videos ):
# Randomly choose a prompt and movement from the predefined list
prompt , shape , direction = random . choice ( prompts_and_movements )
# Create a directory for the current video
video_dir = f'training_dataset/video_ { i } '
os . makedirs ( video_dir , exist_ok = True )
# Write the chosen prompt to a text file in the video directory
with open ( f' { video_dir } /prompt.txt' , 'w' ) as f :
f . write ( prompt )
# Generate frames for the current video
for frame_num in range ( frames_per_video ):
# Create an image with a moving shape based on the current frame number, shape, and direction
img = create_image_with_moving_shape ( img_size , frame_num , shape , direction )
# Save the generated image as a PNG file in the video directory
cv2 . imwrite ( f' { video_dir } /frame_ { frame_num } .png' , img )
เมื่อคุณรันโค้ดข้างต้น มันจะสร้างชุดข้อมูลการฝึกอบรมทั้งหมดของเรา โครงสร้างของไฟล์ชุดข้อมูลการฝึกอบรมของเรามีลักษณะดังนี้
โฟลเดอร์วิดีโอการฝึกอบรมแต่ละโฟลเดอร์จะมีเฟรมพร้อมกับข้อความแจ้ง มาดูตัวอย่างชุดข้อมูลการฝึกอบรมของเรากันดีกว่า
ในชุดข้อมูลการฝึกของเรา เราไม่ได้รวม การเคลื่อนที่ของวงกลมที่เลื่อนขึ้นแล้วไปทางขวา เราจะใช้สิ่งนี้เป็นพรอมต์การทดสอบเพื่อประเมินโมเดลที่ผ่านการฝึกอบรมของเรากับข้อมูลที่มองไม่เห็น
จุดสำคัญอีกประการหนึ่งที่ควรทราบก็คือ ข้อมูลการฝึกของเราประกอบด้วยตัวอย่างจำนวนมากที่มีวัตถุเคลื่อนออกจากฉากหรือปรากฏบางส่วนอยู่หน้ากล้อง คล้ายกับที่เราสังเกตเห็นในวิดีโอสาธิต OpenAI Sora
เหตุผลในการรวมตัวอย่างดังกล่าวไว้ในข้อมูลการฝึกอบรมของเราคือเพื่อทดสอบว่าแบบจำลองของเราสามารถรักษาความสอดคล้องเมื่อวงกลมเข้าสู่ฉากจากมุมหนึ่งโดยไม่ทำให้รูปร่างเสียหายหรือไม่
เมื่อข้อมูลการฝึกอบรมของเราถูกสร้างขึ้นแล้ว เราจำเป็นต้องแปลงวิดีโอการฝึกอบรมเป็นเทนเซอร์ ซึ่งเป็นประเภทข้อมูลหลักที่ใช้ในเฟรมเวิร์กการเรียนรู้เชิงลึก เช่น PyTorch นอกจากนี้ การดำเนินการเปลี่ยนแปลง เช่น การทำให้เป็นมาตรฐานจะช่วยปรับปรุงการบรรจบกันและความเสถียรของสถาปัตยกรรมการฝึกอบรมโดยปรับขนาดข้อมูลให้อยู่ในช่วงที่เล็กลง
เราต้องเขียนโค้ดคลาสชุดข้อมูลสำหรับงานข้อความเป็นวิดีโอ ซึ่งสามารถอ่านเฟรมวิดีโอและข้อความแจ้งที่เกี่ยวข้องจากไดเร็กทอรีชุดข้อมูลการฝึกอบรม ทำให้สามารถใช้งานได้ใน PyTorch
# Define a dataset class inheriting from torch.utils.data.Dataset
class TextToVideoDataset ( Dataset ):
def __init__ ( self , root_dir , transform = None ):
# Initialize the dataset with root directory and optional transform
self . root_dir = root_dir
self . transform = transform
# List all subdirectories in the root directory
self . video_dirs = [ os . path . join ( root_dir , d ) for d in os . listdir ( root_dir ) if os . path . isdir ( os . path . join ( root_dir , d ))]
# Initialize lists to store frame paths and corresponding prompts
self . frame_paths = []
self . prompts = []
# Loop through each video directory
for video_dir in self . video_dirs :
# List all PNG files in the video directory and store their paths
frames = [ os . path . join ( video_dir , f ) for f in os . listdir ( video_dir ) if f . endswith ( '.png' )]
self . frame_paths . extend ( frames )
# Read the prompt text file in the video directory and store its content
with open ( os . path . join ( video_dir , 'prompt.txt' ), 'r' ) as f :
prompt = f . read (). strip ()
# Repeat the prompt for each frame in the video and store in prompts list
self . prompts . extend ([ prompt ] * len ( frames ))
# Return the total number of samples in the dataset
def __len__ ( self ):
return len ( self . frame_paths )
# Retrieve a sample from the dataset given an index
def __getitem__ ( self , idx ):
# Get the path of the frame corresponding to the given index
frame_path = self . frame_paths [ idx ]
# Open the image using PIL (Python Imaging Library)
image = Image . open ( frame_path )
# Get the prompt corresponding to the given index
prompt = self . prompts [ idx ]
# Apply transformation if specified
if self . transform :
image = self . transform ( image )
# Return the transformed image and the prompt
return image , prompt
ก่อนที่จะดำเนินการเขียนโค้ดสถาปัตยกรรม เราจำเป็นต้องทำให้ข้อมูลการฝึกอบรมของเราเป็นมาตรฐาน เราจะใช้ขนาดแบตช์ 16 และสับเปลี่ยนข้อมูลเพื่อให้เกิดการสุ่มมากขึ้น
# Define a set of transformations to be applied to the data
transform = transforms . Compose ([
transforms . ToTensor (), # Convert PIL Image or numpy.ndarray to tensor
transforms . Normalize (( 0.5 ,), ( 0.5 ,)) # Normalize image with mean and standard deviation
])
# Load the dataset using the defined transform
dataset = TextToVideoDataset ( root_dir = 'training_dataset' , transform = transform )
# Create a dataloader to iterate over the dataset
dataloader = torch . utils . data . DataLoader ( dataset , batch_size = 16 , shuffle = True )
คุณอาจเคยเห็นในสถาปัตยกรรมหม้อแปลงไฟฟ้าที่จุดเริ่มต้นคือการแปลงอินพุตข้อความของเราเป็นการฝังเพื่อการประมวลผลเพิ่มเติมในความสนใจแบบหลายหัวที่คล้ายกันที่นี่เราต้องเขียนโค้ดเลเยอร์การฝังข้อความตามการฝึกอบรมสถาปัตยกรรม GAN จะเกิดขึ้นกับข้อมูลการฝังของเรา และเทนเซอร์รูปภาพ
# Define a class for text embedding
class TextEmbedding ( nn . Module ):
# Constructor method with vocab_size and embed_size parameters
def __init__ ( self , vocab_size , embed_size ):
# Call the superclass constructor
super ( TextEmbedding , self ). __init__ ()
# Initialize embedding layer
self . embedding = nn . Embedding ( vocab_size , embed_size )
# Define the forward pass method
def forward ( self , x ):
# Return embedded representation of input
return self . embedding ( x )
ขนาดคำศัพท์จะขึ้นอยู่กับข้อมูลการฝึกอบรมของเราซึ่งเราจะคำนวณในภายหลัง ขนาดการฝังจะเป็น 10 หากทำงานกับชุดข้อมูลที่ใหญ่กว่า คุณยังสามารถใช้โมเดลการฝังที่คุณเลือกเองซึ่งมีอยู่บน Hugging Face ได้
ตอนนี้เรารู้แล้วว่าตัวสร้างทำอะไรใน GAN มาเขียนโค้ดเลเยอร์นี้แล้วทำความเข้าใจเนื้อหากันดีกว่า
class Generator ( nn . Module ):
def __init__ ( self , text_embed_size ):
super ( Generator , self ). __init__ ()
# Fully connected layer that takes noise and text embedding as input
self . fc1 = nn . Linear ( 100 + text_embed_size , 256 * 8 * 8 )
# Transposed convolutional layers to upsample the input
self . deconv1 = nn . ConvTranspose2d ( 256 , 128 , 4 , 2 , 1 )
self . deconv2 = nn . ConvTranspose2d ( 128 , 64 , 4 , 2 , 1 )
self . deconv3 = nn . ConvTranspose2d ( 64 , 3 , 4 , 2 , 1 ) # Output has 3 channels for RGB images
# Activation functions
self . relu = nn . ReLU ( True ) # ReLU activation function
self . tanh = nn . Tanh () # Tanh activation function for final output
def forward ( self , noise , text_embed ):
# Concatenate noise and text embedding along the channel dimension
x = torch . cat (( noise , text_embed ), dim = 1 )
# Fully connected layer followed by reshaping to 4D tensor
x = self . fc1 ( x ). view ( - 1 , 256 , 8 , 8 )
# Upsampling through transposed convolution layers with ReLU activation
x = self . relu ( self . deconv1 ( x ))
x = self . relu ( self . deconv2 ( x ))
# Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)
x