系列文章目录
- 【权重小技巧(1)】.pt文件无法打开或乱码?如何查看.pt文件的具体内容?
- 【权重小技巧(2)】模型权重文件总结: .bin、.safetensors、.pt的保存、加载方法一览
- 本文则总结权重的结构化读取和替换方法,以实现在框架 1 中训练后的部分模型 A 的权重,去替换掉框架 2 中推理时模型 B 中对应的权重。需要模型 A 和模型 B 是相同的结构。
- 本文的参考代码和 json 案例在 repo:https://github.com/wendashi/Read_weights/tree/main
文章目录
- 系列文章目录
- 背景
- 一. 结构化权重读取
- 二. 权重替换
- 总结
背景
- 目标:对 AnyText 的 base model (上图 UNet 部分)进行 custom diffusion 训练(只训 kv)。
- 难点:AnyText 的框架是 ModelScope 的,需要重新写 custom diffusion 的训练代码。
- 🔥简易解决方案:
- 由于 AnyText 的 base model 是 SD1.5,那么可以通过 diffusers 框架中已有的 custom diffusion 训练代码对 SD1.5 进行训练。
- 将训练后获得的权重去替换掉 AnyText 推理代码中所读取的 base model 中相对应的权重即可!
- 相关代码:
- AnyText: https://github.com/tyxsspa/AnyText (提供了 GUI 简易替换 base mode,搜Change base model)
- Custom Diffusion 训练示例 (diffusers):https://github.com/huggingface/diffusers/tree/main/examples/custom_diffusion
一. 结构化权重读取
- 读取 Custom Diffusion 训后的权重,获得权重的命名以及对应的维度(形状)。这里 diffusers 训练代码自动存的权重是 SD1.5 中 UNet 的 transformer 中的 cross attention 里的 k 和 v。
import torch
import json
from safetensors.torch import load_file # 导入 safetensors 库def get_weight_key_shape_pairs(data, prefix=""):pairs = []if isinstance(data, dict):for key, value in data.items():new_prefix = f"{prefix}.{key}" if prefix else keyif isinstance(value, torch.Tensor):shape = list(value.shape)pairs.append({new_prefix: shape})else:pairs.extend(get_weight_key_shape_pairs(value, new_prefix))elif isinstance(data, torch.Tensor):shape = list(data.shape)pairs.append({prefix: shape})return pairs# 加载 .ckpt 文件
# checkpoint_path = "/path/to/anytext_v1.1.ckpt"
# checkpoint = torch.load(checkpoint_path)# 加载 .safetensors 文件
checkpoint_path = '/path/to/pytorch_custom_diffusion_weights.safetensors'
checkpoint = load_file(checkpoint_path) # 获取所有权重的键和形状对
key_shape_pairs = get_weight_key_shape_pairs(checkpoint)# 将结果保存为 JSON 文件
output_json_path = "/path/to/pytorch_custom_diffusion_key_weight_pair.json"
with open(output_json_path, 'w') as f:json.dump(key_shape_pairs, f, indent=4)print(f"结果已保存到 {output_json_path}")
- 以上代码可以结构化地读取 .ckpt 和 .safetensors 权重,以及权重 weight /偏置 bias 命名和对应的形状。
- pytorch_custom_diffusion_key_weight_pair.json 得到的结果如下所示, 从命名也可以看出,是SD1.5 中 UNet 中 3 种 block (down/mid/up)的 transformer 中的 cross attention 里的 k 和 v。
[{"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_custom_diffusion.weight": [320,768]},{"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_v_custom_diffusion.weight": [320,768]},
...{"mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_k_custom_diffusion.weight": [1280,768]},{"mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_v_custom_diffusion.weight": [1280,768]},{"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_custom_diffusion.weight": [1280,768]},{"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_custom_diffusion.weight": [1280,768]},{
...
- 而 AnyText 中 SD1.5 的权重为 https://github.com/wendashi/Read_weights/tree/main 中的 anytext_key_weight_pair.json 文件可以找到,通过搜索 to_k,即可找到对应的权重。
- 可以看出 AnyText 中 SD1.5 中的 UNet 命名方式有一定差异,为 input_blocks/middle_block/output_blocks。
...{"model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight": [320,768]},{"model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight": [320,768]},
...{"model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": [1280,768]},{"model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": [1280,768]},
...{"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": [1280,768]},{"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": [1280,768]},
...
二. 权重替换
-
通过观察两个 json 文件的命名以及维度对应关系,目前作者只是通过手动对应的方式来进行了权重替换,如果读者有更好的想法欢迎评论区留言!
-
下图中,左边为 AnyText 的 SD1.5 原始权重,右边为训后的 SD1.5 中保存下来的 k和v 权重。
-
按照右边去找左边中的权重命名,发现左边 AnyText 一共是 23 个 transformer_blocks.0.attn2.to_k.weight
- 为什么不一样呢?
- 通过观察发现,第17个开始是 control net 的,而非原本的 SD1.5
- 说明看来二者是一一对应的(两边都是 16个 Unet 中的 to_k )
-
最终,手动写一个 maping 字典,让训好的权重去替换掉 AnyText 中 SD1.5 的相应权重即可。
-
完整代码在 repo:https://github.com/wendashi/Read_weights/tree/main 的 change_weights.py 中。
import torch
from safetensors.torch import load_file# 定义对应关系
mapping = {"model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_custom_diffusion.weight","model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_v_custom_diffusion.weight",
...}# 加载 anytext_v1.1.ckpt
original_weights = torch.load('/path/to/anytext_v1.1.ckpt')# 加载 pytorch_custom_diffusion_weights.safetensors
custom_diffusion_weights = load_file('/path/to/pytorch_custom_diffusion_weights.safetensors')# 进行权重替换
for original_key, custom_key in mapping.items():if original_key in original_weights and custom_key in custom_diffusion_weights:original_weights[original_key] = custom_diffusion_weights[custom_key]else:print(f"Key {original_key} in original weights or {custom_key} in custom weights not found.")# 保存新的权重文件
new_ckpt_path = '/path/to/anytext_v1.1_cd.ckpt'
torch.save(original_weights, new_ckpt_path)
print(f"新的权重文件已保存到 {new_ckpt_path}")
总结
提示:这里对文章进行总结:
例如:以上就是今天要讲的内容,本文仅仅简单介绍了pandas的使用,而pandas提供了大量能使我们快速便捷地处理数据的函数和方法。