用 LoRA 微调 Stable Diffusion:拆开炼丹炉,动手实现你的第一次 AI 绘画

总得拆开炼丹炉看看是什么样的。这篇文章将带你从代码层面一步步实现 AI 文本生成图像(Text-to-Image)中的 LoRA 微调过程,你将:

  • 了解 Trigger Words(触发词)到底是什么,以及它们如何影响生成结果。
  • 掌握 LoRA 微调的基本原理。
  • 学习数据集的准备与结构,并知道如何根据需求定制自己的数据集。
  • 理解 Stable Diffusion 模型的微调步骤。
  • 明白在画图界面(UI)下到底发生了什么。
  • 使用代码实现 AI 绘画。

如果你想制作属于自己的数据集,最好遵循以下建议:

  1. 至少准备 20 张图片:想学到的概念越复杂就需要越多的图片。你可以尝试将样例数据集的图片数量减少到 20 张,看看效果会有什么变化。
  2. 裁剪图片:建议对图片进行裁剪,当然你也可以不裁剪,如果你不追求效果的话。这里会自动 resize 到自定义的分辨率。

与其花费大量时间去调参,更优的选择是处理好你的数据集和 Prompts。当然,这两件事情可以同步进行。

注意,当前文章使用的是自然语言标注(而非 Tag)。当然,你也可以使用 Tag,这两种方式本质上是一致的。

同时,如果你对深度学习有所了解,那么代码中的一切,都将是你曾经见过的内容翻版,没有什么新的,除了 LoRA。另外,这篇文章也为生成式人工智能导论课程中 HW10: Stable Diffusion Fine-tuning 提供中文引导。所以,我们将同步使用演员 Brad Pitt(布拉德·皮特)的图片作为训练集,共计一百张。

代码文件下载:镜像交互版 | 精简学习版

前言

下面是使用 prompt:"A man in a graphic tee and sport coat.",在默认设置下训练 2000 个步骤后模型生成的图像,训练时长约为 18 分钟。乍一看,是不是还挺不错的?

valid_image

你可能会注意到,我们的 prompt 中并没有提到 Brad Pitt(布拉德·皮特)这个演员(尽管我们的数据集完全来自于他),但模型却能够绘制长得像 Brad Pitt 的人。

这是因为,如果我们在 prompt 中直接指定 “Brad Pitt”,模型可能无法完全学习到他的特征风格。举个例子:

  • “A man in a graphic tee and sport coat. Brad Pitt.”
  • “A man in a graphic tee and sport coat.”

第一条 prompt 显然更精准,但精准并不意味着模型训练得更好。如果你用一系列包含 “Brad Pitt” 的 prompt 来训练,模型更有可能学到的是:只有在加上 “Brad Pitt” 时才进行风格转变。你可能会说:“我就是想要这个效果”,那么很好,“Brad Pitt” 就是你模型的 Trigger Word(触发词)。但有可能还有同学:“我希望模型只为 Brad Pitt 服务,我要把所有的 ‘man’ 都变成 Brad Pitt”,那么在训练时就不要在 prompt 中增加 “Brad Pitt”。简而言之:反着来

这实际上并没有反直觉,跳出来想一想:

  1. 想象一下你是一位画家,生活在一个从不变暗的世界里,整个世界永远是白天,你已经习惯画出白天背景下的各种景象,但你不知道白天是什么,这就是你所熟知的「日常」。

  2. 有一天,有人给你看了一些照片,说:“Hey,实际上世界可以是黑的,叫做夜晚”,这时候你就会理解到,日常是有另一种状态的,叫做夜晚,即便你以前从来没有过概念,但现在,你将认知到它,你将这部分新的概念聚焦到了「夜晚」。于是,从此以后,你的画作被分为了「日常」和「日常,夜晚」。

  3. 同时,在另一个平行世界,有人告诉你:“你眼中看到的世界是不对的”,他们“治”好了你的眼睛,向你展示了一个完全陌生的漆黑世界,并承诺只要你学会画出这种风格的画作,将会获得丰厚的回报,否则将无人问津你的画摊。于是你开始画“夜晚”风格的「日常」。

这是杜攥的三个小片段,希望你喜欢。

你可以分别将它理解为:

  1. 原始模型:活在自己世界的画家。
  2. LoRA 微调:当新标签(Tag)“夜晚”被引入,画家学会了夜晚的概念。Prompt:夜晚,日常。
  3. 另一个 LoRA 微调:迁移风格,画家将“夜晚”视为真正的日常风格。Prompt:日常。

因此,训练模型就像教小朋友认知世界。如果你将世界分解为不同的概念并逐一传授,孩子会学到不同的知识。这就类似于模型学习不同的标签和风格。如果你不明确区分概念,并将新概念混杂在已有的认知中,孩子的认知会被重塑,或许会将鹿“误”认为马。这是合理的,模型也是如此,取决于你如何教导(prompt)它。

Prompt 小技巧:

  • 明确你的目标:在训练前,思考你是希望模型学习特定的风格、特定的人物,还是希望模型在特定的场景下才生成特定的效果。到底是希望所有的 man 都是 Brad Pitt,还是希望模型知道 Brad Pitt 是一个 man。
  • 保持一致性:如果你希望将某个概念拆分出来,应该为它创建一个特定的标签(tag),并应用于具有相同概念的图像上。

大模型很聪明,它会自动将图像中的共性归因于共用的标签上。因此,如果不给它新的标签,它会将新学到的内容融入到已有的标签中。

这些是关于 AI 绘画 Prompt + 微调背后逻辑的大白话。扯远了,让我们回到代码部分 😃

开始动手

下面,我将带你从代码层面一步步实现 LoRA 微调 Stable Diffusion 模型。注意,这里的知识是通用的,你完全可以推广至任何需要 LoRA 微调的领域。

安装必要的库

首先,确保安装以下必要的 Python 库:

pip install timm==1.0.7
pip install fairscale==0.4.13
pip install transformers==4.41.2
pip install requests==2.31.0
pip install accelerate==0.31.0
pip install diffusers==0.29.1
pip install einops==0.6.1
pip install safetensors==0.4.3
pip install voluptuous==0.15.1
pip install jax==0.4.26
pip install peft==0.11.1
pip install deepface==0.0.92
pip install tensorflow==2.15.0
pip install keras==2.15.0
pip install opencv-python

说明:版本并没有强制要求。

导入

# ========== 标准库模块 ==========
import os
import math
import glob
import shutil
import subprocess# ========== 第三方库 ==========
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import cv2
from tqdm.auto import tqdm# ========== 深度学习相关库 ==========
from torchvision import transforms# Transformers (Hugging Face)
from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPProcessor# Diffusers (Hugging Face)
from diffusers import (AutoencoderKL,DDPMScheduler,UNet2DConditionModel,DiffusionPipeline
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr# ========== LoRA 模型库 ==========
from peft import LoraConfig, get_peft_model, PeftModel# ========== 面部检测库 ==========
from deepface import DeepFace

准备数据

当前演示使用的是 Brad Pitt(布拉德·皮特),我们的目标是让模型绘制的 man 是 Brad Pitt,粗略地换个表述:AI 换脸。

那根据我们之前的描述,标注应该长什么样呢?

:都带 “man”,下面是我们当前数据集的标注示例:

  1. a man with a beard and a suit jacket
  2. a man in a suit and tie standing in front of a crowd
  3. a man with long hair and a tie

相信你发现了,所有的标注,都不会含有 “Brad Pitt”,那这篇文章训练出的 LoRA 模型的 Trigger Words(触发词)是什么?

:“a man”。

是不是很有趣,看似简单的 Prompt 中也有一些真实有用的小技巧和逻辑。别急着去炼丹,我们继续往下看。

在这里,我们使用 Brad Pitt 的 100 张图片进行演示,数据集已经上传到了Demos/data/14,你可以下载后放到当前目录下的 ./data/14 下。这个路径没有什么说法,单纯是为了对齐示例代码,你也可以修改代码关于数据的路径,这里不会有限制,你甚至可以直接用其他的数据集,只要它的文件组织如下:

-- 图片1
-- 图片1.txt
-- 图片2
-- 图片2.txt
...

注意:图片和对应的文本标注需要同名,且位于同一文件夹中。

值得一提的是,样例数据集的裁剪大小和比例都是不一致的,只是接近正方形,但这没有太大的关系,因为在数据预处理的时候会自动放缩(resize),所以在这里不用担心你的数据集无法训练。

设置项目路径

很好!现在你已经知道这篇文章数据集相关的所有前置知识,直接复制下面的代码运行,不用在意其中的任何代码细节,你只需要知道会创建一个文件夹SD,之后的所有结果都会被存放在其中:

# 项目名称和数据集名称
project_name = "Brad"
dataset_name = "Brad"# 根目录和主要目录
root_dir = "./"  # 当前目录
main_dir = os.path.join(root_dir, "SD")  # 主目录# 项目目录
project_dir = os.path.join(main_dir, project_name)  # 项目目录# 数据集和模型路径
images_folder = os.path.join(main_dir, "Datasets", dataset_name)
prompts_folder = os.path.join(main_dir, "Datasets", "prompts")
captions_folder = images_folder  # 与原始代码一致
output_folder = os.path.join(project_dir, "logs")  # 存放 model checkpoints 和 validation 的文件夹# prompt 文件路径
validation_prompt_name = "validation_prompt.txt"
validation_prompt_path = os.path.join(prompts_folder, validation_prompt_name)# 模型检查点路径
model_path = os.path.join(project_dir, "logs", "checkpoint-last")# 其他路径设置
zip_file = os.path.join("./", "data/14/Datasets.zip")
inference_path = os.path.join(project_dir, "inference")  # 保存推理结果的文件夹os.makedirs(images_folder, exist_ok=True)
os.makedirs(prompts_folder, exist_ok=True)
os.makedirs(output_folder, exist_ok=True)
os.makedirs(inference_path, exist_ok=True)# 检查并解压数据集
print("📂 正在检查并解压样例数据集...")if not os.path.exists(zip_file):print("❌ 未找到数据集压缩文件 Datasets.zip!")print("请下载数据集:\nhttps://github.com/Hoper-J/AI-Guide-and-Demos-zh_CN/blob/master/Demos/data/14/Datasets.zip\n并放在 ./data/14 文件夹下")
else:subprocess.run(f"unzip -q -o {zip_file} -d {main_dir}", shell=True)print(f"✅ 项目 {project_name} 已准备好!")

如果你用的是自己的数据集,修改 zip_file 即可(压缩为 zip 格式):

zip_file = # 改为你自己的数据集路径

导入数据

下面,我们需要自定义一个 Dataset 类,它的作用是告诉模型如何处理你的数据集,这个自定义的类能够返回图像和文本标注分别作为 datalabel。接下来的内容会有点“干”,你也可以将其先当作黑盒,我会在每个函数之后提供一个简练的解释帮你理解。

怎么扩充数据集?

这里有一个非常熟悉的词:transform,但这个跟我们耳熟能详的 transformer 可不同,transform 就是单纯的对图像进行操作,比如说调整大小,翻转,又或者随机的裁剪一部分区域,这些操作统称为数据增强。

数据增强就是扩充数据集的外挂,以下图为例,即便进行水平翻转+颜色变化+中心裁剪,它也是一只企鹅。

Transform

这大大地扩充了数据集。知道了概念后,我们简单定义当前的数据增强如下:

# 训练图像的分辨率
resolution = 512# 数据增强操作
train_transform = transforms.Compose([transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),  # 调整图像大小transforms.CenterCrop(resolution),  # 中心裁剪图像transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.ToTensor(),  # 将图像转换为张量]
)

怎么让模型理解文本?

使用 CLIPTokenizer,这是 Hugging Face transformers 库中的一个类,专门用于对文本进行分词(tokenization)操作。CLIP,全称 Contrastive Language-Image Pretraining(对比语言-图像预训练),Contrastive 这个词说透了它的由来,这是一个非常有意思的自训练思想:通过最大化对应文本-图像对的相似性,同时最小化不同文本-图像对的相似性实现训练。

学习资料

论文链接:Learning Transferable Visual Models From Natural Language Supervision
对理论感兴趣的话可以进一步查看以下四个非常棒的视频:

  1. 对比学习论文综述【论文精读】
  2. CLIP 论文逐段精读【论文精读】
  3. CLIP 改进工作串讲(上)【论文精读·42】
  4. CLIP 改进工作串讲(下)【论文精读·42】

你将发现两个宝藏 UP 主,我无法用语言表达对他们的赞美,只能道一句:“导师好!”。

具体来说,CLIPTokenizer 将输入的 prompt 拆解为 token(单词或子词),并将这些 token 映射为input_ids 供 CLIP 模型的 text_encoder 处理,从而生成 prompt 的嵌入向量,以让模型理解。

就像一切数据到了计算机中都变成 0,1 让其处理,所以向上抽象一下,CLIP 就是将人类可以阅读的文本描述变成模型能够理解的形式。

拓展:看看 Tokenizer 实际上做了什么

from transformers import CLIPTokenizer# 初始化 CLIPTokenizer
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")# 示例 prompt
prompt_text = "A man in a graphic tee and sport coat."# 先使用 tokenizer.tokenize 查看分词后的 token
tokens = tokenizer.tokenize(prompt_text)
print("Tokens:", tokens)# 将文本转化为 token
inputs = tokenizer(prompt_text,padding="max_length",  # 如果输入长度不足最大长度,进行填充truncation=True,       # 如果输入过长,进行截断return_tensors="pt"    # 返回 PyTorch 张量
)# 打印分词后的结果
print("Tokenized Input IDs:", inputs.input_ids)
print("Attention Mask:", inputs.attention_mask)

输出:

Tokens: ['a</w>', 'man</w>', 'in</w>', 'a</w>', 'graphic</w>', 'tee</w>', 'and</w>', 'sport</w>', 'coat</w>', '.</w>']
Tokenized Input IDs: tensor([[49406,   320,   786,   530,   320,  4245,  3385,   537,  2364,  7356,269, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,49407, 49407, 49407, 49407, 49407, 49407, 49407]])
Attention Mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0]])

问:49407 是什么?我们的 prompt 中似乎没有重复的词。

答:结束标记,这是因为我们设置了 padding="max_length"。思考一下,设置padding=False后输出应该是什么样的?先不要往下滑。

具体解释:
  • Tokenized Input IDs:这个张量展示了输入文本 A man in a graphic tee and sport coat. 被转换为的数字 ID 序列。每个数字 ID 对应于词汇表中的一个 token,49406 是起始标记,49407 是结束标记。
  • Attention Mask:用于标记哪些 token 需要模型的关注,1 表示有效 token,0 表示填充的无效 token。

padding=False时的输出:

Tokens: ['a</w>', 'man</w>', 'in</w>', 'a</w>', 'graphic</w>', 'tee</w>', 'and</w>', 'sport</w>', 'coat</w>', '.</w>']
Tokenized Input IDs: tensor([[49406,   320,   786,   530,   320,  4245,  3385,   537,  2364,  7356,269, 49407]])
Attention Mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

是不是和预期一致呢?

接下来,input_ids 将被传入 text_encoder,生成文本的嵌入向量。

自定义数据集

在认识 transformtokenizer 之后,我们可以定义自己的数据集。这个 Text2ImageDataset 负责将图像和文本配对,并进行数据的预处理,以便输入到模型中。

# 识别图片后缀
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]class Text2ImageDataset(torch.utils.data.Dataset):"""(1) 目标:- 用于构建文本到图像模型的微调数据集"""def __init__(self, images_folder, captions_folder, transform, tokenizer):"""(2) 参数:- images_folder: str, 图像文件夹路径- captions_folder: str, 标注文件夹路径- transform: function, 将原始图像转换为 torch.Tensor- tokenizer: CLIPTokenizer, 将文本标注转为 word ids"""# 初始化图像路径列表,并根据指定的扩展名找到所有图像文件self.image_paths = []for ext in IMAGE_EXTENSIONS:self.image_paths.extend(glob.glob(os.path.join(images_folder, f"*{ext}")))self.image_paths = sorted(self.image_paths)# 加载对应的文本标注,依次读取每个文本文件中的内容caption_paths = sorted(glob.glob(os.path.join(captions_folder, "*.txt")))captions = []for p in caption_paths:with open(p, "r", encoding="utf-8") as f:captions.append(f.readline().strip())# 确保图像和文本标注数量一致if len(captions) != len(self.image_paths):raise ValueError("图像数量与文本标注数量不一致,请检查数据集。")# 使用 tokenizer 将文本标注转换为 word idsinputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt")self.input_ids = inputs.input_idsself.transform = transformdef __getitem__(self, idx):img_path = self.image_paths[idx]input_id = self.input_ids[idx]try:# 加载图像并将其转换为 RGB 模式,然后应用数据增强image = Image.open(img_path).convert("RGB")tensor = self.transform(image)except Exception as e:print(f"⚠️ 无法加载图像路径: {img_path}, 错误: {e}")# 返回一个全零的张量和空的输入 ID 以避免崩溃tensor = torch.zeros((3, resolution, resolution))input_id = torch.zeros_like(input_id)return tensor, input_id  # 返回处理后的图像和相应的文本标注def __len__(self):return len(self.image_paths)

解释

  • IMAGE_EXTENSIONS:定义可接受的图像文件扩展名列表。
  • __init__ 方法
    • 图像路径:通过遍历指定的图像文件夹,获取所有符合扩展名的图像文件路径,并排序。
    • 文本标注:在标注文件夹中查找所有 .txt 文件,读取其内容并存储为列表。
    • 一致性检查:确保图像数量与文本标注数量一致。
    • 文本编码:使用 tokenizer 将文本标注转换为 token IDs。
    • 数据转换:存储图像的预处理方法 transform
  • __getitem__ 方法
    • 根据索引获取图像路径和对应的文本 token ID。
    • 尝试加载并预处理图像,失败时返回全零张量。
  • __len__ 方法:返回数据集的长度。

定义微调相关的函数

加载 LoRA

先阅读这两篇文章来加深理解:

  • 认识 LoRA:从线性层到注意力机制
  • PEFT微调:在大模型中快速应用 LoRA

LoRA(Low-Rank Adaptation) 是一种非常高效的参数微调方法,通过在预训练模型的特定层添加小的低秩矩阵(可以联想线性代数中的奇异值分解),来实现模型的微调,这也是一类 Adapter。

LoRA 的核心思想是将大模型中的某些权重矩阵近似为两个低秩矩阵进行更新,从而大幅减少需要微调的参数数量,提高训练效率和节省存储空间。一般而言,模型越大,减小比例越夸张,对于 GPT-3,LoRA 微调的训练参数量为原来的 1/10000。

通常,在微调时我们只对模型的特定部分(如注意力机制中的 Q、K、V 矩阵)进行 LoRA 微调,而不是微调整个模型。这里选择对 unettext_encoder 增加 LoRA,因为这两个模块直接负责图像生成和文本引导中的关键任务:unet 处理扩散过程的逆运算,text_encoder 将输入文本转换为特征向量。下面,我们定义一个函数来应用 LoRA 模型。

def prepare_lora_model(lora_config, pretrained_model_name_or_path, model_path=None, resume=False, merge_lora=False):"""(1) 目标:- 加载完整的 Stable Diffusion 模型,包括 LoRA 层,并根据需要合并 LoRA 权重。这包括 Tokenizer、噪声调度器、UNet、VAE 和文本编码器。(2) 参数:- lora_config: LoraConfig, LoRA 的配置对象- pretrained_model_name_or_path: str, Hugging Face 上的模型名称或路径- model_path: str, 预训练模型的路径- resume: bool, 是否从上一次训练中恢复- merge_lora: bool, 是否在推理时合并 LoRA 权重(3) 返回:- tokenizer: CLIPTokenizer- noise_scheduler: DDPMScheduler- unet: UNet2DConditionModel- vae: AutoencoderKL- text_encoder: CLIPTextModel"""# 加载噪声调度器,用于控制扩散模型的噪声添加和移除过程noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")# 加载 Tokenizer,用于将文本标注转换为 tokenstokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path,subfolder="tokenizer")# 加载 CLIP 文本编码器,用于将文本标注转换为特征向量text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path,torch_dtype=weight_dtype,subfolder="text_encoder")# 加载 VAE 模型,用于在扩散模型中处理图像的潜在表示vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path,subfolder="vae")# 加载 UNet 模型,负责处理扩散模型中的图像生成和推理过程unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path,torch_dtype=weight_dtype,subfolder="unet")# 如果设置为继续训练,则加载上一次的模型权重if resume:if model_path is None or not os.path.exists(model_path):raise ValueError("当 resume 设置为 True 时,必须提供有效的 model_path")# 使用 PEFT 的 from_pretrained 方法加载 LoRA 模型text_encoder = PeftModel.from_pretrained(text_encoder, os.path.join(model_path, "text_encoder"))unet = PeftModel.from_pretrained(unet, os.path.join(model_path, "unet"))# 确保 UNet 的可训练参数的 requires_grad 为 Truefor param in unet.parameters():if param.requires_grad is False:param.requires_grad = True# 确保文本编码器的可训练参数的 requires_grad 为 Truefor param in text_encoder.parameters():if param.requires_grad is False:param.requires_grad = Trueprint(f"✅ 已从 {model_path} 恢复模型权重")else:# 将 LoRA 配置应用到 text_encoder 和 unettext_encoder = get_peft_model(text_encoder, lora_config)unet = get_peft_model(unet, lora_config)# 打印可训练参数数量print("📊 Text Encoder 可训练参数:")text_encoder.print_trainable_parameters()print("📊 UNet 可训练参数:")unet.print_trainable_parameters()if merge_lora:# 合并 LoRA 权重到基础模型,仅在推理时调用text_encoder = text_encoder.merge_and_unload()unet = unet.merge_and_unload()# 切换为评估模式text_encoder.eval()unet.eval()# 冻结 VAE 参数vae.requires_grad_(False)# 将模型移动到 GPU 上并设置权重的数据类型unet.to(DEVICE, dtype=weight_dtype)vae.to(DEVICE, dtype=weight_dtype)text_encoder.to(DEVICE, dtype=weight_dtype)return tokenizer, noise_scheduler, unet, vae, text_encoder

解释:

  • 加载模型组件: 依次加载了噪声调度器、Tokenizer、文本编码器(text_encoder)、VAE 和 UNet 模型。
  • 应用 LoRA: 使用 get_peft_model 函数将 LoRA 配置应用到 text_encoderunet 模型中。这会在模型中插入可训练的 LoRA 层。
  • 打印可训练参数: 调用 print_trainable_parameters() 来查看 LoRA 添加了多少可训练参数。
  • 恢复训练: 如果设置了 resume=True,则从指定的 model_path 加载之前保存的模型权重。
  • 合并 LoRA 权重: 如果 merge_lora=True,则将 LoRA 的权重合并到基础模型中,以便在推理时使用,感兴趣的话阅读:PEFT:在大模型中快速应用 LoRA。
  • 冻结 VAE 参数: 调用 vae.requires_grad_(False) 来冻结 VAE 的参数,使其在训练中不更新。
  • 移动模型到设备: 将所有模型组件移动到指定的设备(CPU 或 GPU),并设置数据类型。

为什么只微调 unettext_encoder 最终却返回这么多模块?

因为在后面的微调中,我们将从文本开始处理而非将其当作又一个黑盒。

准备优化器

接下来,需要对于应用了 LoRA 的 UNet 和文本编码器(text_encoder)分别使用不同的学习率,这也是炼丹炉 UI 中常需要调节的选项。

def prepare_optimizer(unet, text_encoder, unet_learning_rate=5e-4, text_encoder_learning_rate=1e-4):"""(1) 目标:- 为 UNet 和文本编码器的可训练参数分别设置优化器,并指定不同的学习率。(2) 参数:- unet: UNet2DConditionModel, Hugging Face 的 UNet 模型- text_encoder: CLIPTextModel, Hugging Face 的文本编码器- unet_learning_rate: float, UNet 的学习率- text_encoder_learning_rate: float, 文本编码器的学习率(3) 返回:- 输出: 优化器 Optimizer"""# 筛选出 UNet 中需要训练的 Lora 层参数unet_lora_layers = [p for p in unet.parameters() if p.requires_grad]# 筛选出文本编码器中需要训练的 Lora 层参数text_encoder_lora_layers = [p for p in text_encoder.parameters() if p.requires_grad]# 将需要训练的参数分组并设置不同的学习率trainable_params = [{"params": unet_lora_layers, "lr": unet_learning_rate},{"params": text_encoder_lora_layers, "lr": text_encoder_learning_rate}]# 使用 AdamW 优化器optimizer = torch.optim.AdamW(trainable_params)return optimizer

定义 collate_fn 函数

在大多数常见的机器学习任务中(例如图像分类或回归),数据集通常是简单的 (data, label) 结构,PyTorch 的 DataLoader 默认能够处理这样的简单数据结构,将样本打包成批次(batch)。在我们的项目中,每个样本也是一个包含图像张量和文本编码的元组 (tensor, input_id)。默认的 collate_fn 可以将这些样本打包成批次,但返回的批次是一个元组,访问时需要使用索引,例如 batch[0]batch[1]

为了使代码更具可读性,我们可以自定义一个 collate_fn 函数,将批次数据组织成字典的形式,方便通过键名直接访问,例如 batch["pixel_values"]batch["input_ids"]。自定义的 collate_fn 定义如下:

def collate_fn(examples):pixel_values = []input_ids = []for tensor, input_id in examples:pixel_values.append(tensor)input_ids.append(input_id)pixel_values = torch.stack(pixel_values, dim=0).float()input_ids = torch.stack(input_ids, dim=0)# 如果你喜欢列表推导式的话,使用下面的方法#pixel_values = torch.stack([example[0] for example in examples], dim=0).float()#input_ids = torch.stack([example[1] for example in examples], dim=0)return {"pixel_values": pixel_values, "input_ids": input_ids}

解释:

  • examples 是什么?
    • examples 是一个列表,包含了一个批次中的多个样本。
    • 其中的每个样本都是从我们自定义的 Text2ImageDataset 数据集中获取的,形式为 (tensor, input_id)
      • tensor:经过预处理的图像张量,形状为 (C, H, W),即通道数(Channel)和图像的高度(Height)、宽度(Weight)。
      • input_id:对应的文本标注经过 tokenizer 编码后的张量,形状为 (sequence_length,)

补充:PyTorch 的 torch.stack() 函数会将多个张量沿新维度拼接在一起。例如,将一批图像张量拼接成 (batch_size, C, H, W) 的形式,确保每个批次数据的组织结构一致。

拓展:自定义和默认 collate_fn 的对比

下面提供了一个对比函数,来展示自定义 collate_fn 和默认 collate_fn 在处理当前数据时的不同。你可以通过运行代码来观察自定义和默认方式的使用差异。

import torch
from torch.utils.data import DataLoader, Datasetdef compare_dataloaders(dataset, batch_size):# 第一种情况:使用自定义的 collate_fntrain_dataloader_custom = DataLoader(dataset,shuffle=True,collate_fn=collate_fn,  # 使用自定义的 collate_fnbatch_size=batch_size,)# 第二种情况:不使用自定义的 collate_fn(默认方式)train_dataloader_default = DataLoader(dataset,shuffle=True,batch_size=batch_size,)# 从每个数据加载器中取一个批次进行对比custom_batch = next(iter(train_dataloader_custom))default_batch = next(iter(train_dataloader_default))# 打印自定义 collate_fn 的输出结果print("使用自定义 collate_fn:")print("批次的类型:", type(custom_batch))print("批次 pixel_values 的形状:", custom_batch["pixel_values"].shape)print("批次 input_ids 的形状:", custom_batch["input_ids"].shape)# 打印默认 DataLoader 的输出结果print("\n使用默认 collate_fn:")print("批次的类型:", type(default_batch))pixel_values, input_ids = default_batchprint("批次 pixel_values 的形状:", pixel_values.shape)print("批次 input_ids 的形状:", input_ids.shape)return custom_batch, default_batch# 对比
custom_batch, default_batch = compare_dataloaders(dataset, batch_size=2)

输出

使用自定义 collate_fn:
批次的类型: <class 'dict'>
批次 pixel_values 的形状: torch.Size([2, 3, 224, 224])
批次 input_ids 的形状: torch.Size([2, 16])使用默认 collate_fn:
批次的类型: <class 'list'>
批次 pixel_values 的形状: torch.Size([2, 3, 224, 224])
批次 input_ids 的形状: torch.Size([2, 16])

具体选择哪一种由你决定,默认的方法实际上更普遍。

设置相关参数

设备配置

当前的微调毫无疑问需要用到显卡(GPU),对于 Apple 芯片的 Mac 来说,把 “cuda” 改为 “mps”,也就是使用第二行代码,但需要注意的是,对于PyTorch版本过低的环境, torch.backends.mps.is_available() 会报错,所以这里选择注释。

# 设备配置
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")# For Mac M1, M2...
# DEVICE = torch.device("mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu"))print(f"🖥 当前使用的设备: {DEVICE}")

模型与训练参数配置

这里的参数大多与之前的函数相关,下面是你可以调节的内容:

  • 训练参数:设置批次大小、数据类型、随机种子等。
    • train_batch_size = 2 时,微调显存要求为 5G,在命令行输入 nvidia-smi 可以查看当前显存占用。
  • 优化器参数:为 UNet 和文本编码器分别设置学习率。
  • 学习率调度器:选择 cosine_with_restarts 调度器,这一点一般无关紧要。
  • 预训练模型:指定预训练的 Stable Diffusion 模型。
  • LoRA 配置:设置 LoRA 的相关参数,如秩 rlora_alpha、应用模块等。
# 训练相关参数
train_batch_size = 2  # 训练批次大小,即每次训练中处理的样本数量
weight_dtype = torch.bfloat16  # 权重数据类型,使用 bfloat16 以节省内存并加快计算速度
snr_gamma = 5  # SNR 参数,用于信噪比加权损失的调节系数# 设置随机数种子以确保可重复性
seed = 1126  # 随机数种子
torch.manual_seed(seed)
if torch.cuda.is_available():torch.cuda.manual_seed_all(seed)# Stable Diffusion LoRA 的微调参数# 优化器参数
unet_learning_rate = 1e-4  # UNet 的学习率,控制 UNet 参数更新的步长
text_encoder_learning_rate = 1e-4  # 文本编码器的学习率,控制文本嵌入层的参数更新步长# 学习率调度器参数
lr_scheduler_name = "cosine_with_restarts"  # 设置学习率调度器为 Cosine annealing with restarts,逐渐减少学习率并定期重启
lr_warmup_steps = 100  # 学习率预热步数,在最初的 100 步中逐渐增加学习率到最大值
max_train_steps = 2000  # 总训练步数,决定了整个训练过程的迭代次数
num_cycles = 3  # Cosine 调度器的周期数量,在训练期间会重复 3 次学习率周期性递减并重启# 预训练的 Stable Diffusion 模型路径,用于加载模型进行微调
pretrained_model_name_or_path = "stablediffusionapi/cyberrealistic-41"  # LoRA 配置
lora_config = LoraConfig(r=32,  # LoRA 的秩,即低秩矩阵的维度,决定了参数调整的自由度lora_alpha=16,  # 缩放系数,控制 LoRA 权重对模型的影响target_modules=["q_proj", "v_proj", "k_proj", "out_proj",  # 指定 Text encoder 的 LoRA 应用对象(用于调整注意力机制中的投影矩阵)"to_k", "to_q", "to_v", "to_out.0"  # 指定 UNet 的 LoRA 应用对象(用于调整 UNet 中的注意力机制)],lora_dropout=0  # LoRA dropout 概率,0 表示不使用 dropout
)

微调前的准备

准备数据集

# 初始化 tokenizer
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path,subfolder="tokenizer"
)# 准备数据集
dataset = Text2ImageDataset(images_folder=images_folder,captions_folder=captions_folder,transform=train_transform,tokenizer=tokenizer,
)train_dataloader = torch.utils.data.DataLoader(dataset,shuffle=True,collate_fn=collate_fn,  # 之前定义的collate_fn()batch_size=train_batch_size,num_workers=8,
)print("✅ 数据集准备完成!")

解释:

  • 加载 Tokenizer: 使用与预训练模型相同的 Tokenizer。
  • 创建数据集: 使用我们之前定义的 Text2ImageDataset
  • 创建数据加载器: 使用 PyTorch 的 DataLoader

准备模型和优化器

# 准备模型
tokenizer, noise_scheduler, unet, vae, text_encoder = prepare_lora_model(lora_config,pretrained_model_name_or_path,model_path,resume=False,merge_lora=False
)# 准备优化器
optimizer = prepare_optimizer(unet, text_encoder, unet_learning_rate=unet_learning_rate, text_encoder_learning_rate=text_encoder_learning_rate
)# 设置学习率调度器
lr_scheduler = get_scheduler(lr_scheduler_name,optimizer=optimizer,num_warmup_steps=lr_warmup_steps,num_training_steps=max_train_steps,num_cycles=num_cycles
)print("✅ 模型和优化器准备完成!可以开始训练。")

解释:

  • 准备模型: 调用之前定义的 prepare_lora_model 函数。
  • 准备优化器: 调用之前定义的 prepare_optimizer 函数。
  • 设置学习率调度器: 使用 Hugging Face 的 get_scheduler 函数。

开始微调

主要流程和结构如下:

  • 训练循环: 我们在多个 epoch 中进行训练,直到达到 max_train_steps。每个 epoch 代表一轮数据的完整训练,在常见的 UI 界面中也可以看到 epochmax_train_steps 的参数。
  • 编码图像: 使用 VAE(变分自编码器)将图像编码为潜在表示(latent space),以便后续在扩散模型中添加噪声并进行处理。
  • 添加噪声: 使用噪声调度器(noise_scheduler)为潜在表示添加随机噪声,模拟图像从清晰到噪声的退化过程。这是扩散模型的关键步骤,训练时模型通过学习如何还原噪声,从而在推理过程中通过逐步去噪生成清晰的图像。
  • 获取文本嵌入: 使用文本编码器(text_encoder)将输入的文本 prompt 转换为隐藏状态(我们见过很多类似的表达:隐藏向量/特征向量/embedding/…),为图像生成提供文本引导信息。
  • 计算目标值: 根据扩散模型的类型(epsilonv_prediction),确定模型的目标输出(噪声或速度向量)。
  • UNet 预测: 使用 UNet 模型对带噪声的潜在表示进行预测,生成的输出用于还原噪声或预测速度向量。
  • 计算损失: 通过加权均方误差(MSE)计算模型损失,并进行反向传播。
  • 优化与保存:通过优化器更新模型参数,并在适当时保存检查点。
# 禁用并行化,避免警告
os.environ["TOKENIZERS_PARALLELISM"] = "false"# 初始化
global_step = 0
best_face_score = float("inf")  # 初始化为正无穷大,存储最佳面部相似度分数# 进度条显示训练进度
progress_bar = tqdm(range(max_train_steps),  # 根据 num_training_steps 设置desc="训练步骤",
)# 训练循环
for epoch in range(math.ceil(max_train_steps / len(train_dataloader))):# 如果你想在训练中增加评估,那在循环中增加 train() 是有必要的unet.train()text_encoder.train()for step, batch in enumerate(train_dataloader):if global_step >= max_train_steps:break# 编码图像为潜在表示(latent)latents = vae.encode(batch["pixel_values"].to(DEVICE, dtype=weight_dtype)).latent_dist.sample()latents = latents * vae.config.scaling_factor  # 根据 VAE 的缩放因子调整潜在空间# 为潜在表示添加噪声,生成带噪声的图像noise = torch.randn_like(latents)  # 生成与潜在表示相同形状的随机噪声timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=DEVICE).long()noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)# 获取文本的嵌入表示encoder_hidden_states = text_encoder(batch["input_ids"].to(DEVICE))[0]# 计算目标值if noise_scheduler.config.prediction_type == "epsilon":target = noise  # 预测噪声elif noise_scheduler.config.prediction_type == "v_prediction":target = noise_scheduler.get_velocity(latents, noise, timesteps)  # 预测速度向量# UNet 模型预测model_pred = unet(noisy_latents, timesteps, encoder_hidden_states)[0]# 计算损失if not snr_gamma:loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")else:# 计算信噪比 (SNR) 并根据 SNR 加权 MSE 损失snr = compute_snr(noise_scheduler, timesteps)mse_loss_weights = torch.stack([snr, snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0]if noise_scheduler.config.prediction_type == "epsilon":mse_loss_weights = mse_loss_weights / snrelif noise_scheduler.config.prediction_type == "v_prediction":mse_loss_weights = mse_loss_weights / (snr + 1)# 计算加权的 MSE 损失loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weightsloss = loss.mean()# 反向传播loss.backward()optimizer.step()lr_scheduler.step()optimizer.zero_grad()progress_bar.update(1)global_step += 1# 打印训练损失if global_step % 100 == 0 or global_step == max_train_steps:print(f"🔥 步骤 {global_step}, 损失: {loss.item()}")# 保存中间检查点,当前简单设置为每 500 步保存一次if global_step % 500 == 0:save_path = os.path.join(output_folder, f"checkpoint-{global_step}")os.makedirs(save_path, exist_ok=True)# 使用 save_pretrained 保存 PeftModelunet.save_pretrained(os.path.join(save_path, "unet"))text_encoder.save_pretrained(os.path.join(save_path, "text_encoder"))print(f"💾 已保存中间模型到 {save_path}")# 保存最终模型到 checkpoint-last
save_path = os.path.join(output_folder, "checkpoint-last")
os.makedirs(save_path, exist_ok=True)
unet.save_pretrained(os.path.join(save_path, "unet"))
text_encoder.save_pretrained(os.path.join(save_path, "text_encoder"))
print(f"💾 已保存最终模型到 {save_path}")print("🎉 微调完成!")

训练完成后的 checkpoint 会保存到 ./SD/Brad/logs/checkpoint-last 中,以 max_train_steps=200 为例,模型输出如下:

image-20240929230118158

生成图像和评估

什么是 pipeline

pipeline 是 Hugging Face 库中一种高层次的封装工具,通常用于推理。默认情况下,pipelineeval 模式加载模型,因此适合用于生成或评估场景。我们这里使用的是 Diffusers.DiffusionPipeline,它将之前提到的多个模型组件(如 UNet、VAE、文本编码器等)组合在一起,实现从文本到图像的生成。

pipeline 的工作原理也跟之前微调过程类似:

  1. 文本编码pipeline 中的文本编码器会将输入的 prompt 转换为特征向量。
  2. 噪声注入:在潜在空间中,模型从随机噪声开始生成图像。
  3. 迭代去噪:UNet 使用从文本编码器得到的特征向量指导去噪过程,逐步将噪声还原为高质量图像。
  4. 图像解码:最终,VAE 将潜在表示解码为实际的图像。

推理相关的参数

  1. 什么是推理步数(num_inference_steps)?

    • 推理步数控制扩散模型生成图像时的去噪迭代次数。步数越多,生成的图像质量越高,但推理时间也相应增加。这是一个需要你根据图像质量和时间需求去权衡的参数,通常在肉眼觉得够好的时候,就可以了。
  2. 如何决定 prompt 的影响程度(guidance_scale)?

    • guidance_scale 决定了文本提示对生成图像的影响程度。较高的 guidance_scale 会让模型更严格地按照 prompt 生成图像,数值通常在 7.5 到 10 之间调整,过高可能会导致图像失真,同样需要你去权衡。这个参数与文本生成任务中的 temperature 参数类似,适用于不同场景。
  3. 怎么确保相同 prompt 生成相同的图像?

    • 设置固定的随机数种子(seed),可以确保同样的 prompt 在每次运行时生成相同的图像。可以通过使用 torch.Generator 生成随机数并设置种子(seed),示例如下:
    generator = torch.Generator().manual_seed(42)
    

加载用于验证的 prompts

这是一组用于生图的文本提示(prompts),本实验中位于./SD/Datasets/prompts/validation_prompt.txt,下面摘取几行 prompt 预览:

  • A man in a black hoodie and khaki pants.
  • A man sports a red polo and denim jacket.
  • A man wears a blue shirt and brown blazer.

定义加载 prompts 的函数如下:

def load_validation_prompts(validation_prompt_path):"""(1) 目标:- 加载验证提示文本。(2) 参数:- validation_prompt_path: str, 验证提示文件的路径(3) 返回:- validation_prompt: list, 验证提示的字符串列表,每一行就是一个prompt"""with open(validation_prompt_path, "r", encoding="utf-8") as f:validation_prompt = [line.strip() for line in f.readlines()]return validation_prompt

定义生成图像的函数

结合之前的讨论,我们可以定义一个生成图像的函数:

def generate_images(pipeline, prompts, num_inference_steps=50, guidance_scale=7.5, output_folder="inference", generator=None):"""(1) 目标:- 使用 DiffusionPipeline 生成图像,保存到指定文件夹并返回生成的图像列表。(2) 参数:- pipeline: DiffusionPipeline, 已加载并配置好的 Pipeline- prompts: list, 文本提示列表- num_inference_steps: int, 推理步骤数,越高图像质量越好,但推理时间也会增加- guidance_scale: float, 决定文本提示对生成图像的影响程度- output_folder: str, 保存生成图像的文件夹路径- generator: torch.Generator, 控制生成随机数的种子,确保图像生成的一致性。如果不提供,生成的图像每次可能不同(3) 返回:- 生成的图像列表,同时图像也会保存到指定文件夹。"""print("🎨 正在生成图像...")os.makedirs(output_folder, exist_ok=True)generated_images = []for i, prompt in enumerate(tqdm(prompts, desc="生成图像中")):# 使用 pipeline 生成图像image = pipeline(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator).images[0]# 保存图像到指定文件夹save_file = os.path.join(output_folder, f"generated_{i+1}.png")image.save(save_file)# 将图像保存到列表中,稍后返回generated_images.append(image)print(f"✅ 已生成并保存 {len(prompts)} 张图像到 {output_folder}")return generated_images

定义评估函数

虽然图像生成的好与坏现在更多的由人去判断,但最基础的模块还是可以交给机器,以当前实验为例,我们的目的是 “AI 换脸”,那就可以有两个新的度量:

  • 无脸图像的数量

    • 使用 DeepFace 库检测生成图像中的人脸。如果没有检测到人脸,则该图像计为无脸图像,数量加 1。
  • 面部相似性

    • 利用 DeepFace 库提取生成图像中的人脸特征,然后与训练集中人脸的特征进行对比。通过计算欧氏距离来衡量相似度,距离越小,表示生成的人脸与训练集中人脸的相似度越高。

      拓展:什么是欧式距离?

      听起来很复杂,实际上非常简单,以二维空间为例:

      如果 p = ( x 1 , y 1 ) \mathbf{p} = (x_1, y_1) p=(x1,y1) q = ( x 2 , y 2 ) \mathbf{q} = (x_2, y_2) q=(x2,y2),它们之间的欧式距离公式为:
      d ( p , q ) = ( x 1 − x 2 ) 2 + ( y 1 − y 2 ) 2 d(\mathbf{p}, \mathbf{q}) = \sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2} d(p,q)=(x1x2)2+(y1y2)2
      是不是很熟悉?这就是我们在几何学中学过的两点之间的距离公式。

      将其拓展到 n n n 维空间,对于两个点 p = ( p 1 , p 2 , … , p n ) \mathbf{p} = (p_1, p_2, \dots, p_n) p=(p1,p2,,pn) q = ( q 1 , q 2 , … , q n ) \mathbf{q} = (q_1, q_2, \dots, q_n) q=(q1,q2,,qn) ,欧式距离的公式为:
      d ( p , q ) = ( p 1 − q 1 ) 2 + ( p 2 − q 2 ) 2 + ⋯ + ( p n − q n ) 2 d(\mathbf{p}, \mathbf{q}) = \sqrt{(p_1 - q_1)^2 + (p_2 - q_2)^2 + \dots + (p_n - q_n)^2} d(p,q)=(p1q1)2+(p2q2)2++(pnqn)2
      P.S. 虽然欧式距离通常适用于欧几里得空间,但我们不需要特别关注这些数学限制。

除了人脸生成之外,AI 图像生成领域还有很多其他应用场景。那么,有没有通用的评估方法来衡量生成图像与文本提示的匹配度呢?

有,CLIP 评分

是的,CLIP 除了可以处理文本输入,还可以评估最终的模型,无论生成的是人脸、风景还是物体,它都可以帮助我们判断生成图像与文本提示的相关性。

对于当前实验,我们采取这三种方式对模型进行度量,完整流程如下:

  1. 使用 load_validation_prompts() 函数从文件中加载 prompts。
  2. 使用 prepare_lora_model() 函数加载已经经过 LoRA 微调的 UNet 和文本编码器(text_encoder),并合并 LoRA 权重。模型会从上一次训练保存的文件中恢复权重。
  3. 使用已经微调的 UNet 和文本编码器来创建 DiffusionPipeline
  4. 加载 CLIP 模型后续用于评估。
  5. 使用 DeepFace 提取训练图像的面部嵌入 train_emb 与生成的图像进行对比,计算面部相似度。
  6. 进行评估,最后打印结果。
def evaluate(lora_config):"""加载模型、生成图像并评估。"""print("📂 加载验证提示...")validation_prompts = load_validation_prompts(validation_prompt_path)print("🔧 准备 LoRA 模型...")# 准备 LoRA 模型(用于推理,合并权重)tokenizer, noise_scheduler, unet, vae, text_encoder = prepare_lora_model(lora_config,pretrained_model_name_or_path,model_path=model_path,resume=True,  # 从检查点恢复merge_lora=True  # 合并 LoRA 权重)# 创建 DiffusionPipeline 并更新其组件print("🔄 创建 DiffusionPipeline...")pipeline = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path,unet=unet,  # 传递基础模型text_encoder=text_encoder,  # 传递基础模型torch_dtype=weight_dtype,safety_checker=None,)pipeline = pipeline.to(DEVICE)# 加载 CLIP 模型和处理器print("🎯 加载 CLIP 模型...")clip_model_name = "openai/clip-vit-base-patch32"clip_model = CLIPModel.from_pretrained(clip_model_name).to(DEVICE)clip_processor = CLIPProcessor.from_pretrained(clip_model_name)# CLIP 模型设置为评估模式clip_model.eval()# 设置随机数种子generator = torch.Generator(device=DEVICE)generator.manual_seed(seed)# 加载训练图像的面部嵌入print("📂 加载训练图像的面部嵌入...")train_image_paths = sorted([p for p in glob.glob(os.path.join(images_folder, "*")) if any(p.endswith(ext) for ext in IMAGE_EXTENSIONS)])train_emb_list = []for img_path in tqdm(train_image_paths, desc="提取训练图像面部嵌入"):face_representation = DeepFace.represent(img_path, detector_backend="ssd",model_name="GhostFaceNet",enforce_detection=False)if face_representation:embedding = face_representation[0]['embedding']train_emb_list.append(embedding)if len(train_emb_list) == 0:print("⚠️ 未能提取到任何训练图像的面部嵌入。")train_emb = torch.tensor([]).to(DEVICE)else:train_emb = torch.tensor(train_emb_list).to(DEVICE)# 生成图像generated_images = generate_images(pipeline=pipeline,prompts=validation_prompts,num_inference_steps=30,guidance_scale=7.5,output_folder=inference_path,# generator=generator)# 评估生成的图像,mis记录无法检测到面部的图像数量face_score, clip_score, mis = 0, 0, 0  # 初始化评估分数和计数valid_emb = []print("📊 正在计算评估分数...")for i, image in enumerate(tqdm(generated_images, desc="评估图像中")):# 使用 DeepFace 检测面部特征opencvImage = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)emb = DeepFace.represent(opencvImage,detector_backend="ssd",model_name="GhostFaceNet",enforce_detection=False,)if not emb or emb[0].get('face_confidence', 0) == 0:mis += 1  # 无法检测到面部的图像数量continue# 计算 CLIP 分数current_prompt = validation_prompts[i]inputs = clip_processor(text=current_prompt, images=image, return_tensors="pt").to(DEVICE)with torch.no_grad():outputs = clip_model(**inputs)sim = outputs.logits_per_imageclip_score += sim.item()# 收集有效的面部嵌入valid_emb.append(emb[0]['embedding'])# 如果没有有效的面部嵌入,则返回默认分数if len(valid_emb) == 0:print("⚠️ 无法检测到面部嵌入!")return 0, 0, mis# 计算面部相似度分数(使用欧氏距离)valid_emb = torch.tensor(valid_emb).to(DEVICE)valid_emb = valid_emb / valid_emb.norm(p=2, dim=-1, keepdim=True)train_emb = train_emb / train_emb.norm(p=2, dim=-1, keepdim=True)face_distance = torch.cdist(valid_emb, train_emb, p=2).mean().item()face_score = face_distance  # 平均欧氏距离作为面部相似性分数clip_score /= (len(validation_prompts) - mis) if (len(validation_prompts) - mis) > 0 else 1print("📈 评估完成!")# 打印评估结果print(f"✅ 面部相似度评分 (平均欧氏距离): {face_score:.4f} (越低越好,表示生成图像与训练图像更相似)")print(f"✅ CLIP 评分 (平均相似度): {clip_score:.4f} (越高越好,表示生成图像与文本提示的相关性更强)")print(f"✅ 无面部图像数量: {mis} (无法检测到面部的生成图像数量)")# 调用函数执行
evaluate(lora_config)

生成的图像会保存在 ./SD/Brad/inference 中。

拓展作业

  1. 当前 prompt 的触发词(trigger words)只是 “a man” 吗?
    仔细观察之前数据集的prompt:
    • a man with a beard and a suit jacket
    • a man in a suit and tie standing in front of a crowd
    • a man with long hair and a tie
  2. 使用当前数据集训练出的模型,如果 prompt 设置为 “a man”,生成的图像应该是什么样的?
  3. 除了之前设置的参数外,探究生成图像相关参数(位于 evaluate())。
    generated_images = generate_images(pipeline=pipeline,prompts=validation_prompts,num_inference_steps=30,  # 修改推理步数guidance_scale=7.5,  # 修改文本提示影响程度output_folder=inference_path,generator=generator  # 注释这一行,看看不传入 generator 时生成的图像是否有变化?尝试运行三次进行对比。)
    

希望你能通过对代码文件的运行,找到它们的答案。

参考链接

  • Learning Transferable Visual Models From Natural Language Supervision
  • DiffusionPipeline 文档, 源码
  • Customize a pipeline - Hugging Face

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/442401.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计组与体系软题1-数据表示与校验码

一、数的编码方式 题1-0的表示 题2-补码的补码原码 1. 这道题涉及到数的编码范围和进制转换2. 题3-采用补码的目的 二、编码范围 题1-补码的表示范围(-2^(n-1)~2 ^(n-1)-1) n是字长/位数&#xff0c;2^7128&#xff0c;范围为-128~127题2-原码范围&#xff08;-2^&#xff0…

LORD-GX5-45 ROS安装

1、驱动安装 https://github.com/LORD-MicroStrain/MSCL 上述下载 x64:C&#xff0c;在下载完的deb文件下执行 sudo dpkg -i <PACKAGE_NAME>.deb #install MSCL sudo apt install -f #install dependencies2、源码安装 #新建工作空间 mkdir -p ~…

【C++】认识匿名对象

文章目录 目录 文章目录前言一、对匿名对象的解读二、匿名对象的对象类型三、匿名对象的使用总结 前言 在C中&#xff0c;匿名对象是指在没有呗命名的情况下创建的临时对象。它们通常在单个语句中执行一系列操作或调用某个函数&#xff0c;并且不需要将结果存放进变量中。 匿名…

【STM32单片机_(HAL库)】4-2-1【定时器TIM】定时器输出PWM实现呼吸灯实验

1.硬件 STM32单片机最小系统LED灯模块 2.软件 pwm驱动文件添加定时器HAL驱动层文件添加GPIO常用函数定时器输出PWM配置步骤main.c程序 #include "sys.h" #include "delay.h" #include "led.h" #include "pwm.h"int main(void) {HA…

音视频入门基础:FLV专题(13)——FFmpeg源码中,解析任意Type值的SCRIPTDATAVALUE类型的实现

一、SCRIPTDATAVALUE类型 从《音视频入门基础&#xff1a;FLV专题&#xff08;9&#xff09;——Script Tag简介》中可以知道&#xff0c;根据《video_file_format_spec_v10_1.pdf》第80到81页&#xff0c;SCRIPTDATAVALUE类型由一个8位&#xff08;1字节&#xff09;的Type和…

动态代理有用吗?一文了解靠谱的动态代理有哪些标准

在当今互联网时代中&#xff0c;从网络安全、隐私保护、市场调研和互联网营销到软件测试、缓存管理和数据库连接&#xff0c;用户为了更好地完成此类工作&#xff0c;往往会使用动态代理&#xff0c;那么进一步了解动态代理、明确动态代理的使用场景和选择标准则是十分有必要的…

OJ在线评测系统 后端微服务架构 注册中心 Nacos入门到启动

注册中心 服务架构中的注册中心是一个关键组件&#xff0c;用于管理和协助微服务之间的通信。注册中心的主要职责是服务的注册和发现&#xff0c;确保各个微服务能够相互找到并进行调用。 主要功能&#xff1a; 服务注册&#xff1a;微服务在启动时&#xff0c;将自身信息&am…

vite学习教程01、vite构建vue2

文章目录 前言一、vite初始化项目二、修改配置文件2.1、修改main.js文件2.2、修改App.vue文件2.3、修改helloworld.vue2.4、修改vite.conf.js2.5、修改vue版本--修改package.json文件 三、安装vue2和vite插件四、启动服务资料获取 前言 博主介绍&#xff1a;✌目前全网粉丝3W&…

常见激活函数总结

简介&#xff1a;个人学习分享&#xff0c;如有错误&#xff0c;欢迎批评指正。 一. 激活函数的定义 激活函数&#xff08;Activation Function&#xff09;是人工神经网络中对每个神经元的输入进行非线性变换的函数。神经网络中的每个神经元都会接受来自上一层的输入&#xf…

Windows安装HeidiSQL教程(图文)

一、软件简介 HeidiSQL是一款开源的数据库管理工具&#xff0c;主要用于管理MySQL、MariaDB、SQL Server、PostgreSQL和SQLite等数据库系统。它提供了直观的用户界面&#xff0c;使用户可以轻松地连接到数据库服务器、执行SQL查询、浏览和编辑数据、管理数据库结构等操作。 跨…

力扣hot100--链表

链表 1. 2. 两数相加 给你两个 非空 的链表&#xff0c;表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的&#xff0c;并且每个节点只能存储 一位 数字。 请你将两个数相加&#xff0c;并以相同形式返回一个表示和的链表。 你可以假设除了数字 0 之外&#xff…

【word脚注】双栏设置word脚注,脚注仅位于左栏,右栏不留白

【word脚注】双栏设置word脚注&#xff0c;脚注仅位于左栏&#xff0c;右栏不留白 调整前效果解决方法调整后效果参考文献 调整前效果 调整前&#xff1a;脚注位于左下角&#xff0c;但右栏与左栏内容对其&#xff0c;未填充右下角的空白区域 解决方法 备份源文件复制脚注内…

git创建新分支

git创建新分支 1.先在gitLab上New branch. 2.本地右键git小乌 - /切换/检出-创建新分支&#xff0c;分支名称和上一步创建的一样。 最后记得改个文件提交下&#xff0c;看看gitLab上是否提交成功。

蝶形激光器驱动(温控精度0.002°C 激光电流分辨率5uA)

蝶形半导体激光器驱动电流的稳定性直接决定了其输出波长的稳定性,进而影响检测精度.为了满足气体浓度检测中对激光器输出波长稳定可调的要求,设计了数字与模拟电路混合的恒流驱动电路.STM32为主控芯片数控模块完成扫描AD/DA转换;模拟电路主要由负反馈运算放大、高精度CMOS管和反…

22.第二阶段x86游戏实战2-背包遍历REP指令详解

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 本次游戏没法给 内容参考于&#xff1a;微尘网络安全 本人写的内容纯属胡编乱造&#xff0c;全都是合成造假&#xff0c;仅仅只是为了娱乐&#xff0c;请不要…

rtmp协议转websocketflv的去队列积压

websocket server的优点 websocket server的好处&#xff1a;WebSocket 服务器能够实现实时的数据推送&#xff0c;服务器可以主动向客户端发送数据 1 不需要客户端不断轮询。 2 不需要实现httpserver跨域。 在需要修改协议的时候比较灵活&#xff0c;我们发送数据的时候比较…

【网络安全】利用XSS、OAuth配置错误实现token窃取及账户接管 (ATO)

未经许可,不得转载。 文章目录 正文正文 目标:target.com 在子域sub1.target.com上,我发现了一个XSS漏洞。由于针对该子域的漏洞悬赏较低,我希望通过此漏洞将攻击升级至app.target.com,因为该子域的悬赏更高。 分析认证机制后,我发现: sub1.target.com:使用基于Cook…

微信小程序——音乐播放器

一、界面设计 播放页面&#xff1a; 显示当前播放歌曲的封面图片、歌曲名称、歌手名称。有播放 / 暂停按钮、上一首、下一首按钮。进度条显示播放进度&#xff0c;可以拖动进度条调整播放位置。音量调节滑块。 歌曲列表页面&#xff1a; 展示歌曲列表&#xff0c;包括歌曲名称、…

C++——STL简介

目录 一、什么是STL 二、STL的版本 三、STL的六大组件 没用的话..... 不知不觉两个月没写博客了&#xff0c;暑假后期因为学校的事情在忙&#xff0c;开学又在准备学校的java免修&#xff0c;再然后才继续开始学C&#xff0c;然后最近打算继续写博客沉淀一下最近学到的几周…

构建高效团队,内部CRM系统的益处详解

内部CRM系统的最大优势之一是它能够集中并系统化客户信息&#xff0c;包括联系方式、购买历史、偏好设置、服务记录等。这种集中式的数据管理使企业能够快速响应客户需求&#xff0c;预测客户行为&#xff0c;提供个性化的服务或产品。更重要的是&#xff0c;它有助于建立一个统…