【权重小技巧(3) 】权重替换—训练 A 模型去替换 B 模型中的对应权重

系列文章目录

  • 【权重小技巧(1)】.pt文件无法打开或乱码?如何查看.pt文件的具体内容?
  • 【权重小技巧(2)】模型权重文件总结: .bin、.safetensors、.pt的保存、加载方法一览
  • 本文则总结权重的结构化读取替换方法,以实现在框架 1 中训练后的部分模型 A 的权重,去替换掉框架 2 中推理时模型 B 中对应的权重。需要模型 A 和模型 B 是相同的结构。
  • 本文的参考代码和 json 案例在 repo:https://github.com/wendashi/Read_weights/tree/main

文章目录

  • 系列文章目录
  • 背景
    • 一. 结构化权重读取
    • 二. 权重替换
  • 总结


背景

在这里插入图片描述

  • 目标:对 AnyText 的 base model (上图 UNet 部分)进行 custom diffusion 训练(只训 kv)。
  • 难点:AnyText 的框架是 ModelScope 的,需要重新写 custom diffusion 的训练代码。
  • 🔥简易解决方案:
    • 由于 AnyText 的 base model 是 SD1.5,那么可以通过 diffusers 框架中已有的 custom diffusion 训练代码对 SD1.5 进行训练。
    • 将训练后获得的权重去替换掉 AnyText 推理代码中所读取的 base model 中相对应的权重即可!
  • 相关代码:
    • AnyText: https://github.com/tyxsspa/AnyText (提供了 GUI 简易替换 base mode,搜Change base model)
    • Custom Diffusion 训练示例 (diffusers):https://github.com/huggingface/diffusers/tree/main/examples/custom_diffusion

一. 结构化权重读取

  1. 读取 Custom Diffusion 训后的权重,获得权重的命名以及对应的维度(形状)。这里 diffusers 训练代码自动存的权重是 SD1.5 中 UNet 的 transformer 中的 cross attention 里的 k 和 v。
import torch
import json
from safetensors.torch import load_file  # 导入 safetensors 库def get_weight_key_shape_pairs(data, prefix=""):pairs = []if isinstance(data, dict):for key, value in data.items():new_prefix = f"{prefix}.{key}" if prefix else keyif isinstance(value, torch.Tensor):shape = list(value.shape)pairs.append({new_prefix: shape})else:pairs.extend(get_weight_key_shape_pairs(value, new_prefix))elif isinstance(data, torch.Tensor):shape = list(data.shape)pairs.append({prefix: shape})return pairs# 加载 .ckpt 文件
# checkpoint_path = "/path/to/anytext_v1.1.ckpt"
# checkpoint = torch.load(checkpoint_path)# 加载 .safetensors 文件
checkpoint_path = '/path/to/pytorch_custom_diffusion_weights.safetensors'
checkpoint = load_file(checkpoint_path)  # 获取所有权重的键和形状对
key_shape_pairs = get_weight_key_shape_pairs(checkpoint)# 将结果保存为 JSON 文件
output_json_path = "/path/to/pytorch_custom_diffusion_key_weight_pair.json"
with open(output_json_path, 'w') as f:json.dump(key_shape_pairs, f, indent=4)print(f"结果已保存到 {output_json_path}")
  • 以上代码可以结构化地读取 .ckpt 和 .safetensors 权重,以及权重 weight /偏置 bias 命名和对应的形状。
    • pytorch_custom_diffusion_key_weight_pair.json 得到的结果如下所示, 从命名也可以看出,是SD1.5 中 UNet 中 3 种 block (down/mid/up)的 transformer 中的 cross attention 里的 k 和 v。
[{"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_custom_diffusion.weight": [320,768]},{"down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_v_custom_diffusion.weight": [320,768]},
...{"mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_k_custom_diffusion.weight": [1280,768]},{"mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_v_custom_diffusion.weight": [1280,768]},{"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_custom_diffusion.weight": [1280,768]},{"up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_custom_diffusion.weight": [1280,768]},{
...
  • 而 AnyText 中 SD1.5 的权重为 https://github.com/wendashi/Read_weights/tree/main 中的 anytext_key_weight_pair.json 文件可以找到,通过搜索 to_k,即可找到对应的权重。
  • 可以看出 AnyText 中 SD1.5 中的 UNet 命名方式有一定差异,为 input_blocks/middle_block/output_blocks。
...{"model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_k.weight": [320,768]},{"model.diffusion_model.output_blocks.9.1.transformer_blocks.0.attn2.to_v.weight": [320,768]},
...{"model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": [1280,768]},{"model.diffusion_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": [1280,768]},
...{"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_k.weight": [1280,768]},{"model.diffusion_model.output_blocks.3.1.transformer_blocks.0.attn2.to_v.weight": [1280,768]},
...

二. 权重替换

  • 通过观察两个 json 文件的命名以及维度对应关系,目前作者只是通过手动对应的方式来进行了权重替换,如果读者有更好的想法欢迎评论区留言!

  • 下图中,左边为 AnyText 的 SD1.5 原始权重,右边为训后的 SD1.5 中保存下来的 k和v 权重。
    在这里插入图片描述

  • 按照右边去找左边中的权重命名,发现左边 AnyText 一共是 23 个 transformer_blocks.0.attn2.to_k.weight

    • 为什么不一样呢?
    • 通过观察发现,第17个开始是 control net 的,而非原本的 SD1.5
    • 说明看来二者是一一对应的(两边都是 16个 Unet 中的 to_k )
      在这里插入图片描述
  • 最终,手动写一个 maping 字典,让训好的权重去替换掉 AnyText 中 SD1.5 的相应权重即可。

  • 完整代码在 repo:https://github.com/wendashi/Read_weights/tree/main 的 change_weights.py 中。

import torch
from safetensors.torch import load_file# 定义对应关系
mapping = {"model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_custom_diffusion.weight","model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_v_custom_diffusion.weight",
...}# 加载 anytext_v1.1.ckpt
original_weights = torch.load('/path/to/anytext_v1.1.ckpt')# 加载 pytorch_custom_diffusion_weights.safetensors
custom_diffusion_weights = load_file('/path/to/pytorch_custom_diffusion_weights.safetensors')# 进行权重替换
for original_key, custom_key in mapping.items():if original_key in original_weights and custom_key in custom_diffusion_weights:original_weights[original_key] = custom_diffusion_weights[custom_key]else:print(f"Key {original_key} in original weights or {custom_key} in custom weights not found.")# 保存新的权重文件
new_ckpt_path = '/path/to/anytext_v1.1_cd.ckpt'
torch.save(original_weights, new_ckpt_path)
print(f"新的权重文件已保存到 {new_ckpt_path}")

总结

提示:这里对文章进行总结:

例如:以上就是今天要讲的内容,本文仅仅简单介绍了pandas的使用,而pandas提供了大量能使我们快速便捷地处理数据的函数和方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/13769.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VSCode中使用EmmyLua插件对Unity的tolua断点调试

一.VSCode中搜索安装EmmyLua插件 二.创建和编辑launch.json文件 初始的launch.json是这样的 手动编辑加上一段内容如下图所示: 三.启动调试模式,并选择附加的进程

k8sollama部署deepseek-R1模型,内网无坑

这是目录 linux下载ollama模型文件下载到本地,打包迁移到k8s等无网络环境使用下载打包ollama镜像非k8s环境使用k8s部署访问方式非ollama运行deepseek模型linux下载ollama 下载后可存放其他服务器 curl -L https://ollama.com/download/ollama-linux-amd64.tgz -o ollama-linu…

2025年Android NDK超全版本下载地址

Unity3D特效百例案例项目实战源码Android-Unity实战问题汇总游戏脚本-辅助自动化Android控件全解手册再战Android系列Scratch编程案例软考全系列Unity3D学习专栏蓝桥系列ChatGPT和AIGC 👉关于作者 专注于Android/Unity和各种游戏开发技巧,以及各种资源分…

通信易懂唠唠SOME/IP——SOME/IP-SD服务发现阶段和应答行为

一 SOME/IP-SD服务发现阶划分 服务发现应该包含3个阶段 1.1 Initial Wait Phase初始等待阶段 初始等待阶段的作用 初始等待阶段是服务发现过程中的一个阶段。在这个阶段,服务发现模块等待服务实例的相关条件满足,以便继续后续的发现和注册过程。 对…

1. Kubernetes组成及常用命令

Pods(k8s最小操作单元)ReplicaSet & Label(k8s副本集和标签)Deployments(声明式配置)Services(服务)k8s常用命令Kubernetes(简称K8s)是一个开源的容器编排系统,用于自动化应用程序的部署、扩展和管理。自2014年发布以来,K8s迅速成为容器编排领域的行业标准,被…

Vue全流程--Vue2组件的理解第二部分

组件命名规则 好的命名规则可以省去很多不必要的麻烦,这个好习惯还是要养成的 一个单词组成: 第一种写法(首字母小写):school 第二种写法(首字母大写):School 多个单词组成: 第一种写法(kebab-case命名)&#xf…

【OS】AUTOSAR架构下的Interrupt详解(上篇)

目录 前言 正文 1.中断概念分析 1.1 中断处理API 1.2 中断级别 1.3 中断向量表 1.4 二类中断的嵌套 1.4.1概述 1.4.2激活 1.5一类中断 1.5.1一类中断的实现 1.5.2一类中断的嵌套 1.5.3在StartOS之前的1类ISR 1.5.4使用1类中断时的注意事项 1.6中断源的初始化 1.…

红包雨项目前端部分

创建项目 pnpm i -g vue/cli vue create red_pakage pnpm i sass sass-locader -D pnpm i --save normalize.css pnpm i --save-dev postcss-px-to-viewportpnpm i vantlatest-v2 -S pnpm i babel-plugin-import -Dhttps://vant.pro/vant/v2/#/zh-CN/<van-button click&…

深入理解k8s中的容器存储接口(CSI)

CSI出现的原因 K8s原生支持一些存储类型的PV&#xff0c;像iSCSI、NFS等。但这种方式让K8s代码与三方存储厂商代码紧密相连&#xff0c;带来不少麻烦。比如更改存储代码就得更新K8s组件&#xff0c;成本高&#xff1b;存储代码的bug还会影响K8s稳定性&#xff1b;K8s社区维护和…

DeepSeek回答禅宗三重境界重构交易认知

人都是活在各自心境里&#xff0c;有些话通过语言去交流&#xff0c;还是要回归自己心境内在的&#xff0c;而不是靠外在映射到股票和技术方法&#xff1b;比如说明天市场阶段是不修复不接力节点&#xff0c;这就是最高视角看整个市场&#xff0c;还有哪一句话能概括&#xff1…

简单说一下CAP理论和Base理论

CAP理论 什么是CAP 一致性 可用性 分区容错性&#xff1a;系统如果不能再时限内达成数据一致性&#xff0c;就说明发生了分区的情况 然后当前操作在C和A之间做出选择 例如我的网络出现问题了&#xff0c;但是我们的系统不能因为网络问题就直接崩溃 只要我们的分布式系统没…

13.PPT:诺贝尔奖【28】

目录 NO1234 NO567 NO8/9/10 NO11/12 NO1234 设计→变体→字体→自定义字体 SmartArt超链接新增加节 NO567 版式删除图片中的白色背景&#xff1a;选中图片→格式→删除背景→拖拉整个图片→保留更改插入→图表→散点图 &#xff1a;图表图例、网格线、坐标轴和图表标题…

RabbitMQ的安装

1、官网地址 下载地址&#xff1a;Installing RabbitMQ | RabbitMQhttp://www.rabbitmq.com/download.htmlhttp://www.rabbitmq.com/download.html RabbitMQ Documentation | RabbitMQhttps://www.rabbitmq.com/docshttps://www.rabbitmq.com/docs 2、Windows上安装 2.1 安装…

【LeetCode】152、乘积最大子数组

【LeetCode】152、乘积最大子数组 文章目录 一、dp1.1 dp1.2 简化代码 二、多语言解法 一、dp 1.1 dp 从前向后遍历, 当遍历到 nums[i] 时, 有如下三种情况 能得到最大值: 只使用 nums[i], 例如 [0.1, 0.3, 0.2, 100] 则 [100] 是最大值使用 max(nums[0…i-1]) * nums[i], 例…

【分布式理论六】分布式调用(4):服务间的远程调用(RPC)

文章目录 一、RPC 调用过程二、RPC 动态代理&#xff1a;屏蔽远程通讯细节1. 动态代理示例2. 如何将动态代理应用于 RPC 三、RPC 序列化四、RPC 协议编码1. 协议编码的作用2. RPC 协议消息组成 五、RPC 网络传输1. 网络传输流程2. 关键优化点 一、RPC 调用过程 RPC&#xff08…

Spring Task之Cron表达式

&#x1f31f; Spring Task高能预警&#xff1a;你以为的Cron表达式可能都是错的&#xff01;【附实战避坑指南】 开篇暴击&#xff1a;为什么你的定时任务总在凌晨3点翻车&#xff1f; “明明设置了0 0 2 * * ?&#xff0c;为什么任务每天凌晨3点执行&#xff1f;” —— 来…

第16章 Single Thread Execution设计模式(Java高并发编程详解:多线程与系统设计)

简单来说&#xff0c; Single Thread Execution就是采用排他式的操作保证在同一时刻只能有一个线程访问共享资源。 1.机场过安检 1.1非线程安全 先模拟一个非线程安全的安检口类&#xff0c;旅客(线程)分别手持登机牌和身份证接受工作人员的检查&#xff0c;示例代码如所示。…

OSPF基础(2):数据包详解

OSPF数据包(可抓包) OSPF报文直接封装在IP报文中&#xff0c;协议号89 头部数据包内容&#xff1a; 版本(Version):对于OSPFv2&#xff0c;该字段值恒为2(使用在IPV4中)&#xff1b;对于OSPFv3&#xff0c;该字段值恒为3(使用在IPV6中)。类型(Message Type):该OSPF报文的类型。…

MAC 安装mysql全过程记录

4.然后等待下载吧&#xff0c;&#xff08;下载中。。。。&#xff09;&#xff0c;好了&#xff0c;网速的问题&#xff0c;半个小时终于下载好了&#xff0c;开始安装吧。 5.得到如下安装包&#xff0c;mac下也是双击直接下载&#xff0c;来&#xff0c;我们来看看下载的过程…

神经网络常见激活函数 1-sigmoid函数

sigmoid 1 函数求导 sigmoid函数 σ ( x ) 1 1 e ( − x ) \sigma(x) \frac{1}{1e^{(-x)}} σ(x)1e(−x)1​ sigmoid函数求导 d d x σ ( x ) d d x ( 1 1 e − x ) e − x ( 1 e − x ) 2 ( 1 e − x ) − 1 ( 1 e − x ) 2 1 1 e − x − 1 ( 1 e − x ) 2 …