深入浅出 diffusion(3):pytorch 实现 diffusion 中的 U-Net

导入python包

import mathimport torch
import torch.nn as nn
import torch.nn.functional as F

 silu激活函数

class SiLU(nn.Module):  # SiLU激活函数@staticmethoddef forward(x):return x * torch.sigmoid(x)

归一化设置

def get_norm(norm, num_channels, num_groups):if norm == "in":return nn.InstanceNorm2d(num_channels, affine=True)elif norm == "bn":return nn.BatchNorm2d(num_channels)elif norm == "gn":return nn.GroupNorm(num_groups, num_channels)elif norm is None:return nn.Identity()else:raise ValueError("unknown normalization type")

 计算时间步长的位置嵌入,一半为sin,一半为cos

class PositionalEmbedding(nn.Module):def __init__(self, dim, scale=1.0):super().__init__()assert dim % 2 == 0self.dim = dimself.scale = scaledef forward(self, x):device      = x.devicehalf_dim    = self.dim // 2emb = math.log(10000) / half_dimemb = torch.exp(torch.arange(half_dim, device=device) * -emb)# x * self.scale和emb外积emb = torch.outer(x * self.scale, emb)emb = torch.cat((emb.sin(), emb.cos()), dim=-1)return emb

 上下采样层设置

class Downsample(nn.Module):def __init__(self, in_channels):super().__init__()self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)def forward(self, x, time_emb, y):if x.shape[2] % 2 == 1:raise ValueError("downsampling tensor height should be even")if x.shape[3] % 2 == 1:raise ValueError("downsampling tensor width should be even")return self.downsample(x)class Upsample(nn.Module):def __init__(self, in_channels):super().__init__()self.upsample = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"),nn.Conv2d(in_channels, in_channels, 3, padding=1),)def forward(self, x, time_emb, y):return self.upsample(x)

 使用Self-Attention注意力机制,做一个全局的Self-Attention

class AttentionBlock(nn.Module):def __init__(self, in_channels, norm="gn", num_groups=32):super().__init__()self.in_channels = in_channelsself.norm = get_norm(norm, in_channels, num_groups)self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)self.to_out = nn.Conv2d(in_channels, in_channels, 1)def forward(self, x):b, c, h, w  = x.shapeq, k, v     = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)q = q.permute(0, 2, 3, 1).view(b, h * w, c)k = k.view(b, c, h * w)v = v.permute(0, 2, 3, 1).view(b, h * w, c)dot_products = torch.bmm(q, k) * (c ** (-0.5))assert dot_products.shape == (b, h * w, h * w)attention   = torch.softmax(dot_products, dim=-1)out         = torch.bmm(attention, v)assert out.shape == (b, h * w, c)out         = out.view(b, h, w, c).permute(0, 3, 1, 2)return self.to_out(out) + x

 用于特征提取的残差结构

class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=F.relu,norm="gn", num_groups=32, use_attention=False,):super().__init__()self.activation = activationself.norm_1 = get_norm(norm, in_channels, num_groups)self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)self.norm_2 = get_norm(norm, out_channels, num_groups)self.conv_2 = nn.Sequential(nn.Dropout(p=dropout), nn.Conv2d(out_channels, out_channels, 3, padding=1),)self.time_bias  = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else Noneself.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else Noneself.residual_connection    = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()self.attention              = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)def forward(self, x, time_emb=None, y=None):out = self.activation(self.norm_1(x))# 第一个卷积out = self.conv_1(out)# 对时间time_emb做一个全连接,施加在通道上if self.time_bias is not None:if time_emb is None:raise ValueError("time conditioning was specified but time_emb is not passed")out += self.time_bias(self.activation(time_emb))[:, :, None, None]# 对种类y_emb做一个全连接,施加在通道上if self.class_bias is not None:if y is None:raise ValueError("class conditioning was specified but y is not passed")out += self.class_bias(y)[:, :, None, None]out = self.activation(self.norm_2(out))# 第二个卷积+残差边out = self.conv_2(out) + self.residual_connection(x)# 最后做个Attentionout = self.attention(out)return out

 U-Net模型设计

class UNet(nn.Module):def __init__(self, img_channels, base_channels=128, channel_mults=(1, 2, 2, 2),num_res_blocks=2, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=F.silu,dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0,):super().__init__()# 使用到的激活函数,一般为SILUself.activation = activation# 是否对输入进行paddingself.initial_pad = initial_pad# 需要去区分的类别数self.num_classes = num_classes# 对时间轴输入的全连接层self.time_mlp = nn.Sequential(PositionalEmbedding(base_channels, time_emb_scale),nn.Linear(base_channels, time_emb_dim),nn.SiLU(),nn.Linear(time_emb_dim, time_emb_dim),) if time_emb_dim is not None else None# 对输入图片的第一个卷积self.init_conv  = nn.Conv2d(img_channels, base_channels, 3, padding=1)# self.downs用于存储下采样用到的层,首先利用ResidualBlock提取特征# 然后利用Downsample降低特征图的高宽self.downs      = nn.ModuleList()self.ups        = nn.ModuleList()# channels指的是每一个模块处理后的通道数# now_channels是一个中间变量,代表中间的通道数channels        = [base_channels]now_channels    = base_channelsfor i, mult in enumerate(channel_mults):out_channels = base_channels * multfor _ in range(num_res_blocks):self.downs.append(ResidualBlock(now_channels, out_channels, dropout,time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,))now_channels = out_channelschannels.append(now_channels)if i != len(channel_mults) - 1:self.downs.append(Downsample(now_channels))channels.append(now_channels)# 可以看作是特征整合,中间的一个特征提取模块self.mid = nn.ModuleList([ResidualBlock(now_channels, now_channels, dropout,time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,norm=norm, num_groups=num_groups, use_attention=True,),ResidualBlock(now_channels, now_channels, dropout,time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups, use_attention=False,),])# 进行上采样,进行特征融合for i, mult in reversed(list(enumerate(channel_mults))):out_channels = base_channels * multfor _ in range(num_res_blocks + 1):self.ups.append(ResidualBlock(channels.pop() + now_channels, out_channels, dropout, time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation, norm=norm, num_groups=num_groups, use_attention=i in attention_resolutions,))now_channels = out_channelsif i != 0:self.ups.append(Upsample(now_channels))assert len(channels) == 0self.out_norm = get_norm(norm, base_channels, num_groups)self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)def forward(self, x, time=None, y=None):# 是否对输入进行paddingip = self.initial_padif ip != 0:x = F.pad(x, (ip,) * 4)# 对时间轴输入的全连接层if self.time_mlp is not None:if time is None:raise ValueError("time conditioning was specified but tim is not passed")time_emb = self.time_mlp(time)else:time_emb = Noneif self.num_classes is not None and y is None:raise ValueError("class conditioning was specified but y is not passed")# 对输入图片的第一个卷积x = self.init_conv(x)# skips用于存放下采样的中间层skips = [x]for layer in self.downs:x = layer(x, time_emb, y)skips.append(x)# 特征整合与提取for layer in self.mid:x = layer(x, time_emb, y)# 上采样并进行特征融合for layer in self.ups:if isinstance(layer, ResidualBlock):x = torch.cat([x, skips.pop()], dim=1)x = layer(x, time_emb, y)# 上采样并进行特征融合x = self.activation(self.out_norm(x))x = self.out_conv(x)if self.initial_pad != 0:return x[:, :, ip:-ip, ip:-ip]else:return x

参考链接:GitCode - 开发者的代码家园icon-default.png?t=N7T8https://gitcode.com/bubbliiiing/ddpm-pytorch/tree/master?utm_source=csdn_github_accelerator&isLogin=1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/246534.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

对话泛能网程路:能源产业互联网,行至中程

泛能网的能源产业互联网的标杆价值还不仅于此。其在产业互联之外,也更大的特殊性在于其也更在成为整个碳市场的“辅助运营商”,包括电力、碳等一系列被泛能网帮助企业改造和沉淀的要素资产,都在构成着碳交易市场的未来底层。 这恰是产业互联…

Spark运行架构以及容错机制

Spark运行架构以及容错机制 1. Spark的角色区分1.1 Driver1.2 Excuter 2. Spark-Cluster模式的任务提交流程2.1 Spark On Yarn的任务提交流程2.1.1 yarn相关概念2.1.2 任务提交流程 2.2 Spark On K8S的任务提交流程2.2.1 k8s相关概念2.2.2 任务提交流程 3. Spark-Cluster模式的…

基于GPT3.5逆向 和 本地Bert-Vits2-2.3 的语音智能助手

文章目录 一、效果演示二、操作步骤三、架构解析 一、效果演示 各位读者你们好,我最近在研究一个语音助手的项目,是基于GPT3.5网页版的逆向和本地BertVits2-2.3 文字转语音,能实现的事情感觉还挺多,目前实现【无需翻墙&#xff0…

IS-IS:07 ISIS缺省路由

IS-IS 有两种缺省路由,第一种缺省路由是由 level-1 路由器在特定条件下自动产生的,它的下一跳是离它最近的 (cost 最小)level-1-2路由器。第二种缺省路由是 IS-IS 路由器上使用 default-route-advertise 命令产生并发布的。 本次实…

海康实时监控预览视频流接入web

我们采取的方案是后端获取视频流返回给前端,然后前端播放 海康开放平台海康威视合作生态致力打造一个能力开放体系、两个生态圈,Hikvision AI Cloud开放平台是能力开放体系的核心内容。它是海康威视基于多年在视频及物联网核心技术积累之上,…

Oracle RAC 集群的安装(保姆级教程)

文章目录 一、安装前的规划1、系统规划2、网络规划3、存储规划 二、主机配置1、Linux主机安装(rac01&rac02)2、配置yum源并安装依赖包(rac01&rac02)3、网络配置(rac01&rac02)4、存储配置&#…

SG-8506CA 可编程晶体振荡器 (SPXO)

输出: LV-PECL频率范围: 50MHz ~ 800MHz电源电压: 2.5V to 3.3V外部尺寸规格: 7.0 5.0 1.5mm (8引脚)特性:用户指定一个起始频率, 7-bit I2C 地址:用户可编程: I2C 接口:基频的高频晶体:低抖动PLL技术应用:OTN, BTS, 测试设备 规格(特征) *1 这包括初…

Linux常见指令汇总

目录 pwd√ ls√ cd√ 对文件的理解√ which√ alias√ touch√ linux系统目录结构√ mkdir / tree √ rmdir / rm √ man√ nano√ cp√ mv√ cat√ echo√ linux设计理念和三种重定向总结√ more/less√ head/tail√ wc√ uniq√ date / cal√ find√…

贾玲新片刚刚发出紧急声明,是什么情况。

♥ 为方便您进行讨论和分享,同时也为能带给您不一样的参与感。请您在阅读本文之前,点击一下“关注”,非常感谢您的支持! 文 |猴哥聊娱乐 编 辑|徐 婷 校 对|侯欢庭 1月22日,一则“多位明星参演的电影涉影视投资诈骗…

C# .Net6搭建灵活的RestApi服务器

1、准备 C# .Net6后支持顶级语句,更简单的RestApi服务支持,可以快速搭建一个极为简洁的Web系统。推荐使用Visual Studio 2022,安装"ASP.NET 和Web开发"组件。 2、创建工程 关键步骤如下: 包添加了“Newtonsoft.Json”&…

从零学习Hession RPC

为什么学习Hessian RPC? 存粹的RPC,只解决PRC的四个核心问题(1.网络通信2.协议 3.序列化 4.代理)Java写的HessianRPC落伍了,但是它的序列化方式还保存着,被Dubbo(Hessian Lite)使用。 被落伍,只…

Go 从标准输入读取数据

fmt.Scan系列 fmt.Scan函数定义如下: // Scan scans text read from standard input, storing successive space-separated values into successive arguments. // Newlines count as space. // It returns the number of items successfully scanned. // If tha…

Python使用pip命令安装外部库-项目内安装外部库-全局安装外部库

一、前言 在进行Python项目开发时需要安装一些外部库来扩展项目功能,因此需要了解pip命令的详细使用。 二、基本语法 1.安装库 pip install 包名 2.安装特定版本 pip install 包名版本号 3.升级库 pip install --upgrade 包名 4.卸载库 pip uninstall 包名 5.查看已…

3 JS类型 值和变量

计算机对value进行操作。 value有不同的类型。每种语言都有其自身的类型集合。编程语言的类型集是该编程语言的基本特性。 value需要保存一个变量中。 变量的工作机制是变成语言的另一个基本特性。 3.1概述和定义 JS类型分为: 原始类型和对象类型。 原始类型&am…

单片机学习笔记---矩阵键盘

目录 矩阵键盘的介绍 独立按键和矩阵按键的相同之处: 矩阵按键的扫描 代码演示 代码模块化移植 Keil自定义模板步骤: 代码编写 矩阵键盘就是开发板上右下角的这个模块 这一节的代码是基于上一节讲的LCD1602液晶显示屏驱动代码进行的 矩阵键盘的介…

阿里云负载均衡对接

1 、开通负载均衡产品 2 、ALB / NLB / CLB ALB: 应用型负载均衡 , 给定对应服务域名与当前实例DNS绑定之后即可使用 支持: HTTP/HTTPS/QUIC等应用层流量协议 NLB: 网络型负载均衡 支持: TCP / UDP / TCPSSL C…

Rabbitmq调用FeignClient接口失败

文章目录 一、框架及逻辑介绍1.背景服务介绍2.问题逻辑介绍 二、代码1.A服务2.B服务3.C服务 三、解决思路1.确认B调用C服务接口是否能正常调通2.确认B服务是否能正常调用A服务3.确认消息能否正常消费4.总结 四、修改代码验证1.B服务异步调用C服务接口——失败2.将消费消息放到C…

分布式id-Leaf算法

一、介绍 由美团开发,开源项目链接:https://github.com/Meituan-Dianping/Leaf Leaf同时支持号段模式和snowflake算法模式,可以切换使用。ID号码是趋势递增的8byte的64位数字,满足上述数据库存储的主键要求。 Leaf的snowflake模…

基于springboot的房屋交易系统

文章目录 项目介绍主要功能截图:部分代码展示设计总结项目获取方式 🍅 作者主页:超级无敌暴龙战士塔塔开 🍅 简介:Java领域优质创作者🏆、 简历模板、学习资料、面试题库【关注我,都给你】 &…

Unity 适配器模式(实例详解)

文章目录 简介1. **Input Adapter 示例**2. **Component Adapter 示例**3. **网络数据解析适配器**4. **物理引擎适配**5. **跨平台服务适配** 简介 Unity中的适配器模式(Adapter Pattern)主要用于将一个类的接口转换为另一个接口,以便于原本…