从头开始构建一个小规模的文生视频模型

OpenAI 的 Sora、Stability AI 的 Stable Video Diffusion 以及许多其他已经发布或未来将出现的文本生成视频模型,是继大语言模型 (LLM) 之后 2024 年最流行的 AI 趋势之一。

在这篇博客中,作者将展示如何将从头开始构建一个小规模的文本生成视频模型,涵盖了从理解理论概念、到编写整个架构再到生成最终结果的所有内容。

由于作者没有大算力的 GPU,所以仅编写了小规模架构。以下是在不同处理器上训练模型所需时间的比较。

图片

作者表示,在 CPU 上运行显然需要更长的时间来训练模型。如果你需要快速测试代码中的更改并查看结果,CPU 不是最佳选择。因此建议使用 Colab 或 Kaggle 的 T4 GPU 进行更高效、更快速的训练。

技术交流&资料

技术要学会分享、交流,不建议闭门造车。一个人可以走的很快、一堆人可以走的更远。

成立了算法面试和技术交流群,相关资料、技术交流&答疑,均可加我们的交流群获取,群友已超过2000人,添加时最好的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2040,备注:来自CSDN + 技术交流

构建目标

我们采用了与传统机器学习或深度学习模型类似的方法,即在数据集上进行训练,然后在未见过数据上进行测试。在文本转视频的背景下,假设有一个包含 10 万个狗捡球和猫追老鼠视频的训练数据集,然后训练模型来生成猫捡球或狗追老鼠的视频。

图片

图源:iStock, GettyImages

虽然此类训练数据集在互联网上很容易获得,但所需的算力极高。因此,我们将使用由 Python 代码生成的移动对象视频数据集。同时使用 GAN(生成对抗网络)架构来创建模型,而不是 OpenAI Sora 使用的扩散模型。

我们也尝试使用扩散模型,但内存要求超出了自己的能力。另一方面,GAN 可以更容易、更快地进行训练和测试。

准备条件

我们将使用 OOP(面向对象编程),因此必须对它以及神经网络有基本的了解。此外 GAN(生成对抗网络)的知识不是必需的,因为这里简单介绍它们的架构。

  • OOP:https://www.youtube.com/watch?v=q2SGW2VgwAM

  • 神经网络理论:https://www.youtube.com/watch?v=Jy4wM2X21u0

  • GAN 架构:https://www.youtube.com/watch?v=TpMIssRdhco

  • Python 基础:https://www.youtube.com/watch?v=eWRfhZUzrAc

了解 GAN 架构

什么是 GAN?

生成对抗网络是一种深度学习模型,其中两个神经网络相互竞争:一个从给定的数据集创建新数据(如图像或音乐),另一个则判断数据是真实的还是虚假的。这个过程一直持续到生成的数据与原始数据无法区分。

真实世界应用

  • 生成图像:GAN 根据文本 prompt 创建逼真的图像或修改现有图像,例如增强分辨率或为黑白照片添加颜色。

  • 数据增强:GAN 生成合成数据来训练其他机器学习模型,例如为欺诈检测系统创建欺诈交易数据。

  • 补充缺失信息:GAN 可以填充缺失数据,例如根据地形图生成地下图像以用于能源应用。

  • 生成 3D 模型:GAN 将 2D 图像转换为 3D 模型,在医疗保健等领域非常有用,可用于为手术规划创建逼真的器官图像。

GAN 工作原理

GAN 由两个深度神经网络组成:生成器和判别器。这两个网络在对抗设置中一起训练,其中一个网络生成新数据,另一个网络评估数据是真是假。

图片

GAN 训练示例

让我们以图像到图像的转换为例,解释一下 GAN 模型,重点是修改人脸。

1. 输入图像:输入图像是一张真实的人脸图像。

2. 属性修改:生成器会修改人脸的属性,比如给眼睛加上墨镜。

3. 生成图像:生成器会创建一组添加了太阳镜的图像。

4. 判别器的任务:判别器接收到混合的真实图像(带有太阳镜的人)和生成的图像(添加了太阳镜的人脸)。

5. 评估:判别器尝试区分真实图像和生成图像。

6. 反馈回路:如果判别器正确识别出假图像,生成器会调整其参数以生成更逼真的图像。如果生成器成功欺骗了判别器,判别器会更新其参数以提高检测能力。

通过这一对抗过程,两个网络都在不断改进。生成器越来越善于生成逼真的图像,而判别器则越来越善于识别假图像,直到达到平衡,判别器再也无法区分真实图像和生成的图像。此时,GAN 已成功学会生成逼真的修改图像。

设置背景

我们将使用一系列 Python 库,让我们导入它们。

# Operating System module for interacting with the operating system
import os# Module for generating random numbers
import random# Module for numerical operations
import numpy as np# OpenCV library for image processing
import cv2# Python Imaging Library for image processing
from PIL import Image, ImageDraw, ImageFont# PyTorch library for deep learning
import torch# Dataset class for creating custom datasets in PyTorch
from torch.utils.data import Dataset# Module for image transformations
import torchvision.transforms as transforms# Neural network module in PyTorch
import torch.nn as nn# Optimization algorithms in PyTorch
import torch.optim as optim# Function for padding sequences in PyTorch
from torch.nn.utils.rnn import pad_sequence# Function for saving images in PyTorch
from torchvision.utils import save_image# Module for plotting graphs and images
import matplotlib.pyplot as plt# Module for displaying rich content in IPython environments
from IPython.display import clear_output, display, HTML# Module for encoding and decoding binary data to text
import base64

现在我们已经导入了所有的库,下一步就是定义我们的训练数据,用于训练 GAN 架构。

对训练数据进行编码

我们需要至少 10000 个视频作为训练数据。为什么呢?因为我测试了较小数量的视频,结果非常糟糕,几乎没有任何效果。下一个重要问题是:这些视频内容是什么? 我们的训练视频数据集包括一个圆圈以不同方向和不同运动方式移动的视频。让我们来编写代码并生成 10,000 个视频,看看它的效果如何。

# Create a directory named 'training_dataset'
os.makedirs('training_dataset', exist_ok=True)# Define the number of videos to generate for the dataset
num_videos = 10000# Define the number of frames per video (1 Second Video)
frames_per_video = 10# Define the size of each image in the dataset
img_size = (64, 64)# Define the size of the shapes (Circle)
shape_size = 10 

设置一些基本参数后,接下来我们需要定义训练数据集的文本 prompt,并据此生成训练视频。

# Define text prompts and corresponding movements for circles
prompts_and_movements = [("circle moving down", "circle", "down"), # Move circle downward("circle moving left", "circle", "left"), # Move circle leftward("circle moving right", "circle", "right"), # Move circle rightward("circle moving diagonally up-right", "circle", "diagonal_up_right"), # Move circle diagonally up-right("circle moving diagonally down-left", "circle", "diagonal_down_left"), # Move circle diagonally down-left("circle moving diagonally up-left", "circle", "diagonal_up_left"), # Move circle diagonally up-left("circle moving diagonally down-right", "circle", "diagonal_down_right"), # Move circle diagonally down-right("circle rotating clockwise", "circle", "rotate_clockwise"), # Rotate circle clockwise("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"), # Rotate circle counter-clockwise("circle shrinking", "circle", "shrink"), # Shrink circle("circle expanding", "circle", "expand"), # Expand circle("circle bouncing vertically", "circle", "bounce_vertical"), # Bounce circle vertically("circle bouncing horizontally", "circle", "bounce_horizontal"), # Bounce circle horizontally("circle zigzagging vertically", "circle", "zigzag_vertical"), # Zigzag circle vertically("circle zigzagging horizontally", "circle", "zigzag_horizontal"), # Zigzag circle horizontally("circle moving up-left", "circle", "up_left"), # Move circle up-left("circle moving down-right", "circle", "down_right"), # Move circle down-right("circle moving down-left", "circle", "down_left"), # Move circle down-left
]

我们已经利用这些 prompt 定义了圆的几个运动轨迹。现在,我们需要编写一些数学公式,以便根据 prompt 移动圆。

# Define function with parameters
def create_image_with_moving_shape(size, frame_num, shape, direction):# Create a new RGB image with specified size and white backgroundimg = Image.new('RGB', size, color=(255, 255, 255)) # Create a drawing context for the imagedraw = ImageDraw.Draw(img) # Calculate the center coordinates of the imagecenter_x, center_y = size[0] // 2, size[1] // 2 # Initialize position with center for all movementsposition = (center_x, center_y) # Define a dictionary mapping directions to their respective position adjustments or image transformationsdirection_map = { # Adjust position downwards based on frame number"down": (0, frame_num * 5 % size[1]), # Adjust position to the left based on frame number"left": (-frame_num * 5 % size[0], 0), # Adjust position to the right based on frame number"right": (frame_num * 5 % size[0], 0), # Adjust position diagonally up and to the right"diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]), # Adjust position diagonally down and to the left"diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]), # Adjust position diagonally up and to the left"diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]), # Adjust position diagonally down and to the right"diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]), # Rotate the image clockwise based on frame number"rotate_clockwise": img.rotate(frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)), # Rotate the image counter-clockwise based on frame number"rotate_counter_clockwise": img.rotate(-frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)), # Adjust position for a bouncing effect vertically"bounce_vertical": (0, center_y - abs(frame_num * 5 % size[1] - center_y)), # Adjust position for a bouncing effect horizontally"bounce_horizontal": (center_x - abs(frame_num * 5 % size[0] - center_x), 0), # Adjust position for a zigzag effect vertically"zigzag_vertical": (0, center_y - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, center_y + frame_num * 5 % size[1]), # Adjust position for a zigzag effect horizontally"zigzag_horizontal": (center_x - frame_num * 5 % size[0], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0], center_y), # Adjust position upwards and to the right based on frame number"up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]), # Adjust position upwards and to the left based on frame number"up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]), # Adjust position downwards and to the right based on frame number"down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]), # Adjust position downwards and to the left based on frame number"down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]) }# Check if direction is in the direction mapif direction in direction_map: # Check if the direction maps to a position adjustmentif isinstance(direction_map[direction], tuple): # Update position based on the adjustmentposition = tuple(np.add(position, direction_map[direction])) else: # If the direction maps to an image transformation# Update the image based on the transformationimg = direction_map[direction] # Return the image as a numpy arrayreturn np.array(img)

上述函数用于根据所选方向在每一帧中移动我们的圆。我们只需在其上运行一个循环,直至生成所有视频的次数。

# Iterate over the number of videos to generate
for i in range(num_videos):# Randomly choose a prompt and movement from the predefined listprompt, shape, direction = random.choice(prompts_and_movements)# Create a directory for the current videovideo_dir = f'training_dataset/video_{i}'os.makedirs(video_dir, exist_ok=True)# Write the chosen prompt to a text file in the video directorywith open(f'{video_dir}/prompt.txt', 'w') as f:f.write(prompt)# Generate frames for the current videofor frame_num in range(frames_per_video):# Create an image with a moving shape based on the current frame number, shape, and directionimg = create_image_with_moving_shape(img_size, frame_num, shape, direction)# Save the generated image as a PNG file in the video directorycv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)

运行上述代码后,就会生成整个训练数据集。以下是训练数据集文件的结构。

图片

每个训练视频文件夹包含其帧以及对应的文本 prompt。让我们看一下我们的训练数据集样本。

在我们的训练数据集中,我们没有包含圆圈先向上移动然后向右移动的运动。我们将使用这个作为测试 prompt,来评估我们训练的模型在未见过的数据上的表现。

图片

还有一个重要的要点需要注意,我们的训练数据包含许多物体从场景中移出或部分出现在摄像机前方的样本,类似于我们在 OpenAI Sora 演示视频中观察到的情况。

图片

在我们的训练数据中包含此类样本的原因是为了测试当圆圈从角落进入场景时,模型是否能够保持一致性而不会破坏其形状。

现在我们的训练数据已经生成,需要将训练视频转换为张量,这是 PyTorch 等深度学习框架中使用的主要数据类型。此外,通过将数据缩放到较小的范围,执行归一化等转换有助于提高训练架构的收敛性和稳定性。

预处理训练数据

我们必须为文本转视频任务编写一个数据集类,它可以从训练数据集目录中读取视频帧及其相应的文本 prompt,使其可以在 PyTorch 中使用。

# Define a dataset class inheriting from torch.utils.data.Dataset
class TextToVideoDataset(Dataset):def __init__(self, root_dir, transform=None):# Initialize the dataset with root directory and optional transformself.root_dir = root_dirself.transform = transform# List all subdirectories in the root directoryself.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]# Initialize lists to store frame paths and corresponding promptsself.frame_paths = []self.prompts = []# Loop through each video directoryfor video_dir in self.video_dirs:# List all PNG files in the video directory and store their pathsframes = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')]self.frame_paths.extend(frames)# Read the prompt text file in the video directory and store its contentwith open(os.path.join(video_dir, 'prompt.txt'), 'r') as f:prompt = f.read().strip()# Repeat the prompt for each frame in the video and store in prompts listself.prompts.extend([prompt] * len(frames))# Return the total number of samples in the datasetdef __len__(self):return len(self.frame_paths)# Retrieve a sample from the dataset given an indexdef __getitem__(self, idx):# Get the path of the frame corresponding to the given indexframe_path = self.frame_paths[idx]# Open the image using PIL (Python Imaging Library)image = Image.open(frame_path)# Get the prompt corresponding to the given indexprompt = self.prompts[idx]# Apply transformation if specifiedif self.transform:image = self.transform(image)# Return the transformed image and the promptreturn image, prompt

在继续编写架构代码之前,我们需要对训练数据进行归一化处理。我们使用 16 的 batch 大小并对数据进行混洗以引入更多随机性。

实现文本嵌入层

你可能已经看到,在 Transformer 架构中,起点是将文本输入转换为嵌入,从而在多头注意力中进行进一步处理。类似地,我们在这里必须编写一个文本嵌入层。基于该层,GAN 架构训练在我们的嵌入数据和图像张量上进行。

# Define a class for text embedding
class TextEmbedding(nn.Module):# Constructor method with vocab_size and embed_size parametersdef __init__(self, vocab_size, embed_size):# Call the superclass constructorsuper(TextEmbedding, self).__init__()# Initialize embedding layerself.embedding = nn.Embedding(vocab_size, embed_size)# Define the forward pass methoddef forward(self, x):# Return embedded representation of inputreturn self.embedding(x) 

词汇量将基于我们的训练数据,在稍后进行计算。嵌入大小将为 10。如果使用更大的数据集,你还可以使用 Hugging Face 上已有的嵌入模型。

实现生成器层

现在我们已经知道生成器在 GAN 中的作用,接下来让我们对这一层进行编码,然后了解其内容。

class Generator(nn.Module):def __init__(self, text_embed_size):super(Generator, self).__init__()# Fully connected layer that takes noise and text embedding as inputself.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8)# Transposed convolutional layers to upsample the inputself.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1) # Output has 3 channels for RGB images# Activation functionsself.relu = nn.ReLU(True) # ReLU activation functionself.tanh = nn.Tanh() # Tanh activation function for final outputdef forward(self, noise, text_embed):# Concatenate noise and text embedding along the channel dimensionx = torch.cat((noise, text_embed), dim=1)# Fully connected layer followed by reshaping to 4D tensorx = self.fc1(x).view(-1, 256, 8, 8)# Upsampling through transposed convolution layers with ReLU activationx = self.relu(self.deconv1(x))x = self.relu(self.deconv2(x))# Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)x = self.tanh(self.deconv3(x))return x

该 Generator 类负责根据随机噪声和文本嵌入的组合创建视频帧,旨在根据给定的文本描述生成逼真的视频帧。该网络从完全连接层 (nn.Linear) 开始,将噪声向量和文本嵌入组合成单个特征向量。然后,该向量被重新整形并经过一系列的转置卷积层 (nn.ConvTranspose2d),这些层将特征图逐步上采样到所需的视频帧大小。

这些层使用 ReLU 激活 (nn.ReLU) 实现非线性,最后一层使用 Tanh 激活 (nn.Tanh) 将输出缩放到 [-1, 1] 的范围。因此,生成器将抽象的高维输入转换为以视觉方式表示输入文本的连贯视频帧。

实现判别器层

在编写完生成器层之后,我们需要实现另一半,即判别器部分。

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# Convolutional layers to process input imagesself.conv1 = nn.Conv2d(3, 64, 4, 2, 1)   # 3 input channels (RGB), 64 output channels, kernel size 4x4, stride 2, padding 1self.conv2 = nn.Conv2d(64, 128, 4, 2, 1) # 64 input channels, 128 output channels, kernel size 4x4, stride 2, padding 1self.conv3 = nn.Conv2d(128, 256, 4, 2, 1) # 128 input channels, 256 output channels, kernel size 4x4, stride 2, padding 1# Fully connected layer for classificationself.fc1 = nn.Linear(256 * 8 * 8, 1)  # Input size 256x8x8 (output size of last convolution), output size 1 (binary classification)# Activation functionsself.leaky_relu = nn.LeakyReLU(0.2, inplace=True)  # Leaky ReLU activation with negative slope 0.2self.sigmoid = nn.Sigmoid()  # Sigmoid activation for final output (probability)def forward(self, input):# Pass input through convolutional layers with LeakyReLU activationx = self.leaky_relu(self.conv1(input))x = self.leaky_relu(self.conv2(x))x = self.leaky_relu(self.conv3(x))# Flatten the output of convolutional layersx = x.view(-1, 256 * 8 * 8)# Pass through fully connected layer with Sigmoid activation for binary classificationx = self.sigmoid(self.fc1(x))return x

判别器类用作二元分类器,区分真实视频帧和生成的视频帧。目的是评估视频帧的真实性,从而指导生成器产生更真实的输出。该网络由卷积层 (nn.Conv2d) 组成,这些卷积层从输入视频帧中提取分层特征, Leaky ReLU 激活 (nn.LeakyReLU) 增加非线性,同时允许负值的小梯度。

然后,特征图被展平并通过完全连接层 (nn.Linear),最终以 S 形激活 (nn.Sigmoid) 输出指示帧是真实还是假的概率分数。

通过训练判别器准确地对帧进行分类,生成器同时接受训练以创建更令人信服的视频帧,从而骗过判别器。

编写训练参数

我们必须设置用于训练 GAN 的基础组件,例如损失函数、优化器等。

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Create a simple vocabulary for text prompts
all_prompts = [prompt for prompt, _, _ in prompts_and_movements]  # Extract all prompts from prompts_and_movements list
vocab = {word: idx for idx, word in enumerate(set(" ".join(all_prompts).split()))}  # Create a vocabulary dictionary where each unique word is assigned an index
vocab_size = len(vocab)  # Size of the vocabulary
embed_size = 10  # Size of the text embedding vectordef encode_text(prompt):# Encode a given prompt into a tensor of indices using the vocabularyreturn torch.tensor([vocab[word] for word in prompt.split()])# Initialize models, loss function, and optimizers
text_embedding = TextEmbedding(vocab_size, embed_size).to(device)  # Initialize TextEmbedding model with vocab_size and embed_size
netG = Generator(embed_size).to(device)  # Initialize Generator model with embed_size
netD = Discriminator().to(device)  # Initialize Discriminator model
criterion = nn.BCELoss().to(device)  # Binary Cross Entropy loss function
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Discriminator
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Generator

这是我们必须转换代码以在 GPU 上运行的部分(如果可用)。我们已经编写了代码来查找 vocab_size,并且我们正在为生成器和判别器使用 ADAM 优化器。你可以选择自己的优化器。在这里,我们将学习率设置为较小的值 0.0002,嵌入大小为 10,这比其他可供公众使用的 Hugging Face 模型要小得多。

编写训练 loop

就像其他神经网络一样,我们将以类似的方式对 GAN 架构训练进行编码。

# Number of epochs
num_epochs = 13# Iterate over each epoch
for epoch in range(num_epochs):# Iterate over each batch of datafor i, (data, prompts) in enumerate(dataloader):# Move real data to devicereal_data = data.to(device)# Convert prompts to listprompts = [prompt for prompt in prompts]# Update DiscriminatornetD.zero_grad()  # Zero the gradients of the Discriminatorbatch_size = real_data.size(0)  # Get the batch sizelabels = torch.ones(batch_size, 1).to(device)  # Create labels for real data (ones)output = netD(real_data)  # Forward pass real data through DiscriminatorlossD_real = criterion(output, labels)  # Calculate loss on real datalossD_real.backward()  # Backward pass to calculate gradients# Generate fake datanoise = torch.randn(batch_size, 100).to(device)  # Generate random noisetext_embeds = torch.stack([text_embedding(encode_text(prompt).to(device)).mean(dim=0) for prompt in prompts])  # Encode prompts into text embeddingsfake_data = netG(noise, text_embeds)  # Generate fake data from noise and text embeddingslabels = torch.zeros(batch_size, 1).to(device)  # Create labels for fake data (zeros)output = netD(fake_data.detach())  # Forward pass fake data through Discriminator (detach to avoid gradients flowing back to Generator)lossD_fake = criterion(output, labels)  # Calculate loss on fake datalossD_fake.backward()  # Backward pass to calculate gradientsoptimizerD.step()  # Update Discriminator parameters# Update GeneratornetG.zero_grad()  # Zero the gradients of the Generatorlabels = torch.ones(batch_size, 1).to(device)  # Create labels for fake data (ones) to fool Discriminatoroutput = netD(fake_data)  # Forward pass fake data (now updated) through DiscriminatorlossG = criterion(output, labels)  # Calculate loss for Generator based on Discriminator's responselossG.backward()  # Backward pass to calculate gradientsoptimizerG.step()  # Update Generator parameters# Print epoch informationprint(f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {lossD_real + lossD_fake}, Loss G: {lossG}")

通过反向传播,我们的损失将针对生成器和判别器进行调整。我们在训练 loop 中使用了 13 个 epoch。我们测试了不同的值,但如果 epoch 高于这个值,结果并没有太大差异。此外,过度拟合的风险很高。如果我们的数据集更加多样化,包含更多动作和形状,则可以考虑使用更高的 epoch,但在这里没有这样做。

当我们运行此代码时,它会开始训练,并在每个 epoch 之后 print 生成器和判别器的损失。

## OUTPUT ##Epoch [1/13] Loss D: 0.8798642754554749, Loss G: 1.300612449645996
Epoch [2/13] Loss D: 0.8235711455345154, Loss G: 1.3729925155639648
Epoch [3/13] Loss D: 0.6098687052726746, Loss G: 1.3266581296920776...

保存训练的模型

训练完成后,我们需要保存训练好的 GAN 架构的判别器和生成器,这只需两行代码即可实现。

# Save the Generator model's state dictionary to a file named 'generator.pth'
torch.save(netG.state_dict(), 'generator.pth')# Save the Discriminator model's state dictionary to a file named 'discriminator.pth'
torch.save(netD.state_dict(), 'discriminator.pth')

生成 AI 视频

正如我们所讨论的,我们在未见过的数据上测试模型的方法与我们训练数据中涉及狗取球和猫追老鼠的示例类似。因此,我们的测试 prompt 可能涉及猫取球或狗追老鼠等场景。

在我们的特定情况下,圆圈向上移动然后向右移动的运动在训练数据中不存在,因此模型不熟悉这种特定运动。但是,模型已经在其他动作上进行了训练。我们可以使用此动作作为 prompt 来测试我们训练过的模型并观察其性能。

# Inference function to generate a video based on a given text promptdef generate_video(text_prompt, num_frames=10):    # Create a directory for the generated video frames based on the text prompt    os.makedirs(f'generated_video_{text_prompt.replace(" ", "_")}', exist_ok=True)        # Encode the text prompt into a text embedding tensor    text_embed = text_embedding(encode_text(text_prompt).to(device)).mean(dim=0).unsqueeze(0)        # Generate frames for the video    for frame_num in range(num_frames):        # Generate random noise        noise = torch.randn(1, 100).to(device)                # Generate a fake frame using the Generator network        with torch.no_grad():            fake_frame = netG(noise, text_embed)                # Save the generated fake frame as an image file        save_image(fake_frame, f'generated_video_{text_prompt.replace(" ", "_")}/frame_{frame_num}.png')# usage of the generate_video function with a specific text promptgenerate_video('circle moving up-right')

当我们运行上述代码时,它将生成一个目录,其中包含我们生成视频的所有帧。我们需要使用一些代码将所有这些帧合并为一个短视频。

# Define the path to your folder containing the PNG frames
folder_path = 'generated_video_circle_moving_up-right'# Get the list of all PNG files in the folder
image_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]# Sort the images by name (assuming they are numbered sequentially)
image_files.sort()# Create a list to store the frames
frames = []# Read each image and append it to the frames list
for image_file in image_files:image_path = os.path.join(folder_path, image_file)frame = cv2.imread(image_path)frames.append(frame)# Convert the frames list to a numpy array for easier processing
frames = np.array(frames)# Define the frame rate (frames per second)
fps = 10# Create a video writer object
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('generated_video.avi', fourcc, fps, (frames[0].shape[1], frames[0].shape[0]))# Write each frame to the video
for frame in frames:out.write(frame)# Release the video writer
out.release()

确保文件夹路径指向你新生成的视频所在的位置。运行此代码后,你将成功创建 AI 视频。让我们看看它是什么样子。

图片

我们进行了多次训练,训练次数相同。在两种情况下,圆圈都是从底部开始,出现一半。好消息是,我们的模型在两种情况下都尝试执行直立运动。

例如,在尝试 1 中,圆圈沿对角线向上移动,然后执行向上运动,而在尝试 2 中,圆圈沿对角线移动,同时尺寸缩小。在两种情况下,圆圈都没有向左移动或完全消失,这是一个好兆头。

最后,作者表示已经测试了该架构的各个方面,发现训练数据是关键。通过在数据集中包含更多动作和形状,你可以增加可变性并提高模型的性能。由于数据是通过代码生成的,因此生成更多样的数据不会花费太多时间;相反,你可以专注于完善逻辑。

此外,文章中讨论的 GAN 架构相对简单。你可以通过集成高级技术或使用语言模型嵌入 (LLM) 而不是基本神经网络嵌入来使其更复杂。此外,调整嵌入大小等参数会显著影响模型的有效性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/364926.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C# 实现websocket双向通信

🎈个人主页:靓仔很忙i 💻B 站主页:👉B站👈 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:C# 🤝希望本文对您有所裨益,如有不足之处&#xff…

AWT的菜单组件

AWT的菜单组件 前言一、菜单组件的介绍常见的菜单相关组件常见菜单相关组件集成体系图菜单相关组件使用小要点 二、AWT菜单组件的代码示例示例一示例二实现思路 前言 推荐一个网站给想要了解或者学习人工智能知识的读者,这个网站里内容讲解通俗易懂且风趣幽默&…

如何使用sr2t将你的安全扫描报告转换为表格格式

关于sr2t sr2t是一款针对安全扫描报告的格式转换工具,全称为“Scanning reports to tabular”,该工具可以获取扫描工具的输出文件,并将文件数据转换为表格格式,例如CSV、XLSX或文本表格等,能够为广大研究人员提供一个…

MySQL详细介绍:开源关系数据库管理系统的魅力

学习总结 1、掌握 JAVA入门到进阶知识(持续写作中……) 2、学会Oracle数据库入门到入土用法(创作中……) 3、手把手教你开发炫酷的vbs脚本制作(完善中……) 4、牛逼哄哄的 IDEA编程利器技巧(编写中……) 5、面经吐血整理的 面试技…

理解GPT2:无监督学习的多任务语言模型

目录 一、背景与动机 二、卖点与创新 三、几个问题 四、具体是如何做的 1、更多、优质的数据,更大的模型 2、大数据量,大模型使得zero-shot成为可能 3、使用prompt做下游任务 五、一些资料 一、背景与动机 基于 Transformer 解码器的 GPT-1 证明…

容器技术-docker4

一、docker资源限制 在使用 docker 运行容器时,一台主机上可能会运行几百个容器,这些容器虽然互相隔离,但是底层却使用着相同的 CPU、内存和磁盘资源。如果不对容器使用的资源进行限制,那么容器之间会互相影响,小的来说…

【Spring】DAO 和 Repository 的区别

DAO 和 Repository 的区别 1.概述2.DAO 模式2.1 User2.2 UserDao2.3 UserDaoImpl 3.Repository 模式3.1 UserRepository3.2 UserRepositoryImpl 4.具有多个 DAO 的 Repository 模式4.1 Tweet4.2 TweetDao 和 TweetDaoImpl4.3 增强 User 域4.4 UserRepositoryImpl 5.比较两种模式…

【机器学习】机器学习的重要技术——生成对抗网络:理论、算法与实践

引言 生成对抗网络(Generative Adversarial Networks, GANs)由Ian Goodfellow等人在2014年提出,通过生成器和判别器两个神经网络的对抗训练,成功实现了高质量数据的生成。GANs在图像生成、数据增强、风格迁移等领域取得了显著成果…

详细分析Springmvc中的@ModelAttribute基本知识(附Demo)

目录 前言1. 注解用法1.1 方法参数1.2 方法1.3 类 2. 注解场景2.1 表单参数2.2 AJAX请求2.3 文件上传 3. 实战4. 总结 前言 将请求参数绑定到模型对象上,或者在请求处理之前添加模型属性 可以在方法参数、方法或者类上使用 一般适用这几种场景: 表单…

ros笔记01--初次体验ros2

ros笔记01--初次体验ros2 介绍安装ros2测试验证ros2说明 介绍 机器人操作系统(ROS)是一组用于构建机器人应用程序的软件库和工具。从驱动程序和最先进的算法到强大的开发者工具,ROS拥有我们下一个机器人项目所需的开源工具。 当前ros已经应用到各类机器人项目开发中…

【Matlab 六自由度机器人】机器人动力学之推导拉格朗日方程(附MATLAB机器人动力学拉格朗日方程推导代码)

【Matlab 六自由度机器人】机器人动力学概述 近期更新前言正文一、拉格朗日方程的推导1. 单自由度系统2. 单连杆机械臂系统3. 双连杆机械臂系统 二、MATLAB实例推导1. 机器人模型的建立2. 动力学代码 总结参考文献 近期更新 【汇总】 【Matlab 六自由度机器人】系列文章汇总 …

Elasticsearch 第四期:搜索和过滤

序 2024年4月,小组计算建设标签平台,使用ES等工具建了一个demo,由于领导变动关系,项目基本夭折。其实这两年也陆陆续续接触和使用过ES,两年前也看过ES的官网,当时刚毕业半年多,由于历史局限性导…

大数据开发如何管理项目

在面试的时候总是 会问起项目,那在大数据开发的实际工作中,如何做好一个项目呢? 目录 1. 需求分析与项目规划1.1 需求收集与梳理1.2 可行性分析1.3 项目章程与计划 2. 数据准备与处理2.1 数据源接入2.2 数据仓库建设2.3 数据质量管理 3. 系统…

ARCGIS添加在线地图

地图服务地址:http://map.geoq.cn/ArcGIS/rest/services 具体方法: 结果展示:

Python逻辑控制语句 之 循环语句--for循环

1.for 的介绍 for 循环 也称为是 for 遍历, 也可以做指定次数的循环遍历: 是从容器中将数据逐个取出的过程.容器: 字符串/列表/元组/字典 2.for 的语法 (1)for 循环遍历字符串 for 变量 in 字符串: 重复执⾏的代码 字符串中存在多少个字符, 代码就执行…

flink的窗口

目录 窗口分类 1.按照驱动类型分类 1. 时间窗口(Time window) 2.计数窗口(Count window) 2.按照窗口分配数据的规则分类 窗口API分类 API调用 窗口分配器器: 窗口函数 增量聚合函数: 全窗口函数…

网络编程常见问题

1、TCP状态迁移图 2、TCP三次握手过程 2.1、握手流程 1、TCP服务器进程先创建传输控制块TCB,时刻准备接受客户进程的连接请求,此时服务器就进入了LISTEN(监听)状态; 2、TCP客户进程也是先创建传输控制块TCB&#xff…

Echarts地图实现:杭州市困难人数分布【动画滚动播放】

Echarts地图实现:杭州市困难人数分布 实现功能 杭州市地区以及散点图分布结合的形式数据展示动画轮播可进去杭州市下级地区可返回杭州市地图展示 效果预览 实现思路 使用ECharts的地图和散点图功能结合实现地区分布通过动画轮播展示数据变化实现下级地区数据的展…

搜索引擎的原理与相关知识

搜索引擎是一种网络服务,它通过互联网帮助用户找到所需的信息。搜索引擎的工作原理主要包括以下几个步骤: 网络爬虫(Web Crawler):搜索引擎使用网络爬虫(也称为蜘蛛或机器人)来遍历互联网&#…

Hugging Face Accelerate 两个后端的故事:FSDP 与 DeepSpeed

社区中有两个流行的零冗余优化器 (Zero Redundancy Optimizer,ZeRO)算法实现,一个来自DeepSpeed,另一个来自PyTorch。Hugging FaceAccelerate对这两者都进行了集成并通过接口暴露出来,以供最终用户在训练/微调模型时自主选择其中之…