【扩散模型(六)】IP-Adapter 是如何训练的?2 源码篇(IP-Adapter Plus)

系列文章目录

  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
  • 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
  • 【可控图像生成系列论文(四)】IP-Adapter 具体是如何训练的?1公式篇
  • 本文则以 IP-Adapter Plus 训练代码为例,进行详细介绍。

文章目录

  • 系列文章目录
  • 整体训练框架
  • 一、训了哪些部分?
      • 第一块 - image_proj_model
      • 第二块 - adapter_modules
  • 二、训练目标


整体训练框架

在这里插入图片描述

一、训了哪些部分?

本文以原仓库 1 的 /path/IP-Adapter/tutorial_train_plus.py 为例,该文件为 SD1.5 IP-Adapter Plus 的训练代码。

从以下代码可以看出,IPAdapter 主要由 unet, image_proj_model, adapter_modules 3 个部分组成,而权重需要被优化的(训练到的)只有 ip_adapter.image_proj_model.parameters(), 和 ip_adapter.adapter_modules.parameters() 。

	ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)# optimizerparams_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(),  ip_adapter.adapter_modules.parameters())optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)...# Prepare everything with our `accelerator`.ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

第一块 - image_proj_model

在 IP-Adapter Plus 中,采用的是 Resampler 作为img embedding 到 ip_tokens 的映射网络,对图像(image prompt)中信息的抽取更加细粒度。其他模块都不需要梯度下降,如下代码所示。

	# freeze parameters of models to save more memoryunet.requires_grad_(False)vae.requires_grad_(False)text_encoder.requires_grad_(False)image_encoder.requires_grad_(False)#ip-adapter-plusimage_proj_model = Resampler(dim=unet.config.cross_attention_dim,depth=4,dim_head=64,heads=12,num_queries=args.num_tokens,embedding_dim=image_encoder.config.hidden_size,output_dim=unet.config.cross_attention_dim,ff_mult=4)...

第二块 - adapter_modules

Decoupled cross-attention 则在以下代码中进行初始化,关键是在特定的 unet 层中进行替换,详细位置可以参考前文中的图片,本文的重点是后续训练的实现。

	# init adapter modulesattn_procs = {}unet_sd = unet.state_dict()for name in unet.attn_processors.keys():cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dimif name.startswith("mid_block"):hidden_size = unet.config.block_out_channels[-1]elif name.startswith("up_blocks"):block_id = int(name[len("up_blocks.")])hidden_size = list(reversed(unet.config.block_out_channels))[block_id]elif name.startswith("down_blocks"):block_id = int(name[len("down_blocks.")])hidden_size = unet.config.block_out_channels[block_id]if cross_attention_dim is None:attn_procs[name] = AttnProcessor()else:layer_name = name.split(".processor")[0]weights = {"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],}attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens)attn_procs[name].load_state_dict(weights)unet.set_attn_processor(attn_procs)adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

二、训练目标

每个 epoch 是遍历完一整个 dataset,我们直接从每个训练步的循环中来看:

  • latents 是通过 vae 将输入的 image prompt 压到了隐空间(latent space)中。
  • 准备相应的 noise 和 timesteps ,再通过 noise_scheduler 来制作出 noisy_latents。
        for step, batch in enumerate(train_dataloader):load_data_time = time.perf_counter() - beginwith accelerator.accumulate(ip_adapter):# Convert images to latent spacewith torch.no_grad():latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()latents = latents * vae.config.scaling_factor# Sample noise that we'll add to the latentsnoise = torch.randn_like(latents)bsz = latents.shape[0]# Sample a random timestep for each imagetimesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)timesteps = timesteps.long()# Add noise to the latents according to the noise magnitude at each timestep# (this is the forward diffusion process)noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
  • clip_images 和 drop_image_embed 是在准备数据的过程中,做了一个随机 drop 的方式进行数据增强,提升模型鲁棒性。
    • 数据增强:通过随机丢弃一些图像,模型被迫学习从剩余的图像中提取信息,这可以增加模型的泛化能力。
    • 模型鲁棒性:训练模型以处理不完整的数据,使其在实际应用中对缺失数据更加鲁棒。
     clip_images = []for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):if drop_image_embed == 1:clip_images.append(torch.zeros_like(clip_image))else:clip_images.append(clip_image)clip_images = torch.stack(clip_images, dim=0)with torch.no_grad():image_embeds = image_encoder(clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True).hidden_states[-2]with torch.no_grad():encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]

  1. https://github.com/tencent-ailab/IP-Adapter/tree/main ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/411240.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Verilog 数字系统设计教程】Verilog 基础:硬件描述语言入门指南

目录 摘要 1. 引言 2. Verilog 历史与发展 3. Verilog 基本语法 4. Verilog 模块与端口 5. 组合逻辑与时序逻辑 6. 时钟域与同步设计 7. 测试与仿真 8. Verilog 高级特性 任务(Tasks) 函数(Functions) 多维数组 结构体…

【二叉树】OJ题目

🌟个人主页:落叶 目录 单值⼆叉树 【单值二叉树】代码 相同的树 【相同二叉树】代码 对称⼆叉树 【对称二叉树】代码 另一颗树的子树 【另一颗树的子树】代码 二叉树的前序遍历 【二叉树前序遍历】代码 二叉树的中序遍历 【二叉树中序遍历】…

【大模型】llama系列模型基础

前言:llama基于transformer架构,与GPT相似,只用了transformer的解码器部分。本文主要是关于llama,llama2和llama3的结构解读。 目录 1. llama1.1 整体结构1.2 RoPE1.3 SwiGLU 激活函数 2. llama22.2 GQA架构2.3 RLHF3. llama3 参考…

CAD中命令和系统变量

屏幕去除菜单全屏显示: ThisDrawing.SendCommand ("CLEANSCREENON ") 恢复原始:ThisDrawing.SendCommand ("CLEANSCREENOFF ") CAD中系统变量决定图形的基本设置。 第一个系统变量:uscicon vba代码如下: …

【Linux】——Rocky Linux配置静态IP

Rocky Linux配置静态IP Rocky Linux Rocky Linux 进入官网进行下载,下载版本自定义 官网link 获取ip地址 ip addr 获取服务器ip地址 进入网络配置文件目录: cd /etc/NetworkManager/system-connections/vi打开ens33.nmconnection 在IPv4下输入配置信…

Ubuntu美化为类Windows风格

博主的系统为 Ubuntu22.04 参考文献:How to Make Ubuntu Look Like Windows 11 | 22.04 GNOME 43 / 42 | Linux AF Tech 可能遇到的bug的解决方法:如何在 Linux 中安装和更改 GNOME 主题 先来一下视频演示: 下面正式开始安装。在主文件夹下打…

DWF 支持的 TON 链 Telegram 免费宠物游戏 Gatto_game,推出 “Paws Up! 世界锦标赛”

TON 链在这轮牛市里无疑是一匹脱缰的黑马,创造了一个又一个爆款,为持有者带来了不菲的收益。 Gatto_game 是一款 TON链 Tamagotchi 电子宠物风格的 P2E web3 游戏。可以通过喂养升级,参加比赛赚取 $TON 或者 $GTON ,或许就是下一个…

python解释器[源代码层面]

1 PyDictObject 在c中STL中的map是基于 RB-tree平衡二元树实现,搜索的时间复杂度为O(log2n) Python中PyDictObject是基于散列表(散列函数)实现,搜索时间最优为O(1) 1.1 散列列表 问题:散列冲突:多个元素计算得到相同的哈希值 …

软件设计原则之依赖倒置原则

依赖倒置原则(Dependency Inversion Principle, DIP)是软件设计中一个非常重要的原则,它属于面向对象设计的SOLID原则之一。这个原则的核心在于通过抽象来降低模块间的耦合度,使得系统更加灵活和可维护。 目录 依赖倒置原则的基本…

「软件测试」最全面试问题和回答,全文背熟不拿下offer算我输

3.公司这边测试人员分配比例 4.进入公司,我这边大概的工作安排 5,公司这么后续发展机会还有培养 6,有没有培训 7,面试没有回答上的问题,再去请教 2.5 你的职业发展规划和职业目标 根据公司况,个人原因…

【Spring Boot 3】自定义拦截器

【Spring Boot 3】自定义拦截器 背景介绍开发环境开发步骤及源码工程目录结构总结背景 软件开发是一门实践性科学,对大多数人来说,学习一种新技术不是一开始就去深究其原理,而是先从做出一个可工作的DEMO入手。但在我个人学习和工作经历中,每次学习新技术总是要花费或多或…

机器学习 之 DBSCAN算法 及实现

1.K-means 与 DBSCAN 的比较 K-means 和 DBSCAN 都是聚类算法,但它们之间有显著的区别: K-means: 基于中心点的方法,要求用户提前指定簇的数量。适用于球形簇,且簇大小相近。无法处理噪声数据和任意形状的簇。 DBSCAN…

SQLi-LABS靶场36-40通过攻略

less-36 这一关是转义函数换成了mysql_real_escape_string,绕过方法与35关一致 1.判断注入点 2.判断闭合方式 id1A0 -- 3.查看页面回显点 ?id-1%A0%27%20%20union%20select%201,2,3-- 4.查询数据库名 ?id-1%A0%27%20%20union%20select%201,database(),3-- 5.查询数据库的…

音视频封装格式之FLV

FLV(Flash Video)是一种常见的视频文件格式,FLV 格式最初是由 Adobe 公司开发的,旨在为网络视频提供一种高效、可扩展且易于流式传输的解决方案。随着在线视频的迅速发展,FLV 因其良好的兼容性和流式传输性能&#xff…

喜羊羊做Python二级(模拟考试--易错点)

今天距离Python二级考试,还有28天左右。坚持每天做几套试卷,保持记忆和手感。 个人在做题的过程中是先不断练习选择题。当你选择题不达标的时候,系统不会看大题(大概是觉得选择题都做的那么差,大题也不会那么好&#…

mac 虚拟机PD19运行E-prime实验遇到E-prime unable to set display mode:0*80004001问题解决

作者:50% unable to set display mode问题 总结: 1. 修改该Experiment的Devices中的Dispaly为640*680,Color Bit Depth设置为32。(这个分辨率仅限于学习用,实际实验应该还是在真机上) 2. 右键开始菜单中的E…

hadoop生态圈(四)- MapReduce

目录 MapReduce的基本原理 MapReduce流程图 Map阶段执行流程 Reduce阶段执行流程 Shuffle机制 MapReduce解决的是海量数据计算 MapReduce的思想核心是“分而治之”。就是把一个复杂的问题按一定的“分解”方法分为规模较小的若干部分,然后逐个解决,…

Meta AI动画生成功能的规模化部署与优化策略

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

数据可视化大屏模板-美化图表

Axure作为一款强大的原型设计软件,不仅擅长构建交互式界面,更在数据可视化方面展现出了非凡的创意与实用性。今天,就让我们一起探索Axure设计的几款精美数据可视化大屏模板,感受数据之美。 立体图表的视觉冲击力 Axure的数据可视…

基于ROM的VGA显示

前言 在早期计算机和嵌入式系统中,图形显示和用户界面的实现主要依赖于硬件技术。VGA(视频图形阵列)标准在1980年代中期成为主流图形显示技术,其高分辨率和良好的兼容性使其在计算机显示领域中占据了重要地位。VGA标准支持640x480…