guided-diffusion 相比于improved-diffusion的sample增加的cond_fn()

目录

  • 1、cond_fn()函数代码
  • 2、softmax与log_softmax函数

1、cond_fn()函数代码

def cond_fn(x, t, y=None):assert y is not Nonewith th.enable_grad():x_in = x.detach().requires_grad_(True)logits = classifier(x_in, t)log_probs = F.log_softmax(logits, dim=-1)selected = log_probs[range(len(logits)), y.view(-1)]return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

cond_fn 的函数接受三个参数:x、t 和一个可选的 y。这个函数的主要目的是计算一个关于输入 x 的梯度,这个梯度是基于通过某个分类器 classifier 对 x 和 t 进行分类时,针对特定标签 y 的对数概率的梯度。

参数检查: assert y is not None 确保 y 不为 None。这是必要的,因为后续的操作依赖于 y 来选择对数概率。
启用梯度计算: with torch.enable_grad(): 确保在这个代码块内,所有需要梯度的操作都会被记录,以便后续可以计算梯度。不过,在 PyTorch 中,更常见的做法是直接设置张量的 .requires_grad 属性,因为 torch.enable_grad() 主要用于全局控制梯度记录,而在这个函数中,我们只需要对 x_in 进行这样的设置。
准备输入: x_in = x.detach().requires_grad_(True) 通过 detach() 创建一个 x 的新副本,并从计算图中分离出来,然后通过 requires_grad_(True) 允许 PyTorch 对这个副本的操作进行梯度追踪。
前向传播: 通过 classifier(x_in, t) 获取分类器的输出(logits),然后使用 F.log_softmax(logits, dim=-1) 计算对数概率。
选择特定标签的对数概率: selected = log_probs[range(len(logits)), y.view(-1)] 这行代码通过索引选择每个样本对应标签 y 的对数概率。y.view(-1) 确保 y 的形状与 logits 的最后一维相匹配。log_probs[range(len(logits)), y.view(-1)]:这行代码使用高级索引(advanced indexing)来从log_probs中选择元素,range(len(logits))值是行索引, y.view(-1)是列索引。具体来说,它首先通过range(len(logits))生成一个与样本数量相等的索引序列,然后使用y.view(-1)来提供每个样本对应真实类别的索引。因此,这行代码实际上是在选择每个样本对应其真实类别的对数概率值。
计算梯度: th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale 计算 selected.sum()(即所有选中对数概率的和)关于 x_in 的梯度,并将这个梯度乘以一个缩放因子 args.classifier_scale。th.autograd.grad 返回的是一个元组,其中包含所有需要梯度的张量的梯度,这里我们只关心 x_in 的梯度,所以通过 [0] 索引获取。
总的来说,这个函数计算了分类器对于输入 x 和条件 t,在给定标签 y 下的对数概率梯度,并对这个梯度进行了缩放。这样的梯度可以用于各种优化或学习算法中,特别是在需要基于条件梯度的场景下。

2、softmax与log_softmax函数

当Softmax的输入比较大的时候,可能会产生上溢出,超出float的能表示范围;同理,当输入为负值且绝对值比较大的时候,分子分母会极小,接近0,从而导致下溢出。log_Softmax能够很好的解决溢出问题,且可以加快运算速度,提升数据稳定性。
softmax
在这里插入图片描述

log_softmax
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/374452.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Transformer特辑

https://github.com/LongxingTan/Machine-learning-interview 模型结构 基本单元:token_embedding positional encoding, encoder, token_embedding positional encoding, decoderencoder: (self-attention, skip-connect, ln), (ffn, skip-connect, ln)decoder:…

顶顶通呼叫中心中间件实现随时启动和停止质检(mod_cti基于FreeSWITCH)

文章目录 前言联系我们拨号方案启动停止ASR执行FreeSWITCH 命令接口启动ASR接口停止ASR接口 通知配置cti.json配置质检结果写入数据库 前言 顶顶通呼叫中心中间件的实时质检功能是由两个模块组成:mod_asr 和 mod_qc。 mod_asr:负责调用ASR将用户们在通…

二、Qemu+Vscode调试内核

编译内核、busybox、配置Qemu参考:Qemu调试内核 一、修改启动脚本 1、修改Qemu启动脚本 #! /bin/shqemu-system-aarch64 \-machine virt,virtualizationtrue,gic-version3 \-nographic \-m size1024M \-cpu cortex-a72 \-smp 2 \-kernel Image \-drive formatraw…

写作遇到AI痕迹困扰?这里有降低AI痕迹的实用技巧

请问有没有什么免费的论文降重网站? 副本 一句“知网是什么”,我查重查了千百遍。天临六年五月,大家的论文差不多都到了查重的阶段。好不容易论文写(shui)完了,一看查重报告,满屏的红字让人心心…

Linux--线程ID封装管理原生线程

目录 1.线程的tid(本质是线程属性集合的起始虚拟地址) 1.1pthread库中线程的tid是什么? 1.2理解库 1.3phtread库中做了什么? 1.4线程的tid,和内核中的lwp 1.5线程的局部存储 2.封装管理原生线程库 1.线程的tid…

java设计模式(十五)命令模式(Command Pattern)

1、模式介绍: 命令模式(Command Pattern)是一种行为设计模式,其主要目的是将请求封装成一个对象,从而允许使用不同的请求、队列或者日志来参数化其他对象。这种模式使得命令的请求者和实现者解耦。 2、应用场景&…

服务启动何时触发 Nacos 的注册流程?

前言: 前面的系列文章让我们对 Nacos 有了一个基本了解,并知道了如何去试用 Nacos 作为注册中心和配置中心,本篇我们将从源码层面去分析 Nacos 的服务注册流程。 Nacos 系列文章传送门: Nacos 初步认识和 Nacos 部署细节 Naco…

C++基础学习笔记

1.命名空间(namespace) 1.什么是命名空间&命名空间的作用 1.在C/C中,变量、函数、类都是大量存在的,这些变量等的名称将都存在于全局作用域中,就会导致很多的命名冲突等。使用命名空间的目的就是对标识符的名称进行本地化,以…

短视频矩阵系统全解析:让获客变得更简单

随着数字媒体的迅猛发展,短视频已成为人们生活中不可或缺的一部分。对于企业而言,如何有效利用短视频平台吸引目标用户,实现高效获客,成为了一个亟待解决的问题。本文将全面解析短视频矩阵系统,带您领略其独特魅力&…

广度优先(BFS)

先看一道简单的题&#xff0c;迷宫问题&#xff1a; 洛谷P1746 离开中山路&#xff1a;P1746 离开中山路 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) #include<iostream> #include<cstring> #include<queue> #include <utility> #define N 1002 …

深度学习的数学PDF

链接: https://pan.baidu.com/s/1_jScZ7dcyAWGqbrad6bbCQ?pwd9gj9 提取码: 9gj9 复制这段内容后打开百度网盘手机App&#xff0c;操作更方便哦

最简单的vue3组件之间传值

localStorage 是 HTML5 引入的一个 Web Storage API 的一部分&#xff0c;它允许网页在用户的浏览器上存储数据。localStorage 提供了一种持久化的本地存储方案&#xff0c;数据不会因为浏览器关闭而丢失&#xff0c;除非用户或脚本显式地删除它们。 localStorage 是一种非常实…

VSCode神仙插件——通义灵码 (AI编程助手)

1、安装&登录插件 安装时,右下角会有弹窗,让你登录该软件 同意登录后,会跳转浏览器页面 VSCode右下角出现如下图标即登录成功 2、使用 (1)点击左侧栏中的如下图标,打开通义灵码,可以进行智能问答 (2) 选中代码,右键 但是,上述所有的操作会在左侧问答栏中提供答案,并无法直…

认识并理解webSocket

今天逛牛客&#xff0c;看到有大佬分享说前端面试的时候遇到了关于webSocket的问题&#xff0c;一看自己都没见过这个知识点&#xff0c;赶紧学习一下&#xff0c;在此记录&#xff01; WebSocket 是一种网络通信协议&#xff0c;提供了全双工通信渠道&#xff0c;即客户端和服…

31. 1049. 最后一块石头的重量 II, 494.目标和,474.一和零

class Solution { public:int lastStoneWeightII(vector<int>& stones) {int sum 0;for(int stone : stones) sum stone;int bagSize sum /2;vector<int> dp(bagSize 1, 0);for(int i 0; i < stones.size(); i){ //遍历物品for(int j bagSize; j >…

LLMs的基本组成:向量、Tokens和嵌入

编者按&#xff1a;随着人工智能技术的不断发展&#xff0c;大模型&#xff08;语言、视觉&#xff0c;或多模态模型&#xff09;已成为当今AI应用的核心组成部分。这些模型具有处理和理解自然语言等模态输入的能力&#xff0c;推动了诸如聊天机器人、智能助手、自动文本生成等…

Android初学者书籍推荐

书单 1.《Android应用开发项目式教程》&#xff0c;机械工业出版社&#xff0c;2024年出版2.《第一行代码Android》第二版3.《第一行代码Android》第三版4.《疯狂Android讲义》第四版5.《Android移动应用基础教程&#xff08;Android Studio 第2版&#xff09;》 从学安卓到用安…

Node.js如何在Windows安装?

文章目录 主要特点&#xff1a;使用场景&#xff1a;安装方法验证是否安装成功 Node.js 是一个开源、跨平台的JavaScript运行环境&#xff0c;由Ryan Dahl于2009年创建。它允许开发者在服务器端运行JavaScript代码。Node.js 基于Chrome V8 JavaScript引擎构建&#xff0c;其设计…

项目/代码规范与Apifox介绍使用

目录 目录 一、项目规范&#xff1a; &#xff08;一&#xff09;项目结构&#xff1a; &#xff08;二&#xff09;传送的数据对象体 二、代码规范&#xff1a; &#xff08;一&#xff09;数据库命名规范&#xff1a; &#xff08;二&#xff09;注释规范&#xff1a; …

关于CANNM PassiveMode

Passive Mode的要求 根据上图CANNM的规范可知&#xff1a; 处于Passive Mode的网络节点只能接收网络管理PDU&#xff0c;不能发送网络管理PDU。Passive Mode由CanNmPassiveModeEnable参数静态配置。如果一个ECU包含多个节点&#xff0c;那么所有的节点要么都是Passive Mode要么…