目录
- 1、cond_fn()函数代码
- 2、softmax与log_softmax函数
1、cond_fn()函数代码
def cond_fn(x, t, y=None):assert y is not Nonewith th.enable_grad():x_in = x.detach().requires_grad_(True)logits = classifier(x_in, t)log_probs = F.log_softmax(logits, dim=-1)selected = log_probs[range(len(logits)), y.view(-1)]return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale
cond_fn 的函数接受三个参数:x、t 和一个可选的 y。这个函数的主要目的是计算一个关于输入 x 的梯度,这个梯度是基于通过某个分类器 classifier 对 x 和 t 进行分类时,针对特定标签 y 的对数概率的梯度。
参数检查: assert y is not None 确保 y 不为 None。这是必要的,因为后续的操作依赖于 y 来选择对数概率。
启用梯度计算: with torch.enable_grad(): 确保在这个代码块内,所有需要梯度的操作都会被记录,以便后续可以计算梯度。不过,在 PyTorch 中,更常见的做法是直接设置张量的 .requires_grad 属性,因为 torch.enable_grad() 主要用于全局控制梯度记录,而在这个函数中,我们只需要对 x_in 进行这样的设置。
准备输入: x_in = x.detach().requires_grad_(True) 通过 detach() 创建一个 x 的新副本,并从计算图中分离出来,然后通过 requires_grad_(True) 允许 PyTorch 对这个副本的操作进行梯度追踪。
前向传播: 通过 classifier(x_in, t) 获取分类器的输出(logits),然后使用 F.log_softmax(logits, dim=-1) 计算对数概率。
选择特定标签的对数概率: selected = log_probs[range(len(logits)), y.view(-1)] 这行代码通过索引选择每个样本对应标签 y 的对数概率。y.view(-1) 确保 y 的形状与 logits 的最后一维相匹配。log_probs[range(len(logits)), y.view(-1)]:这行代码使用高级索引(advanced indexing)来从log_probs中选择元素,range(len(logits))值是行索引, y.view(-1)是列索引。具体来说,它首先通过range(len(logits))生成一个与样本数量相等的索引序列,然后使用y.view(-1)来提供每个样本对应真实类别的索引。因此,这行代码实际上是在选择每个样本对应其真实类别的对数概率值。
计算梯度: th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale 计算 selected.sum()(即所有选中对数概率的和)关于 x_in 的梯度,并将这个梯度乘以一个缩放因子 args.classifier_scale。th.autograd.grad 返回的是一个元组,其中包含所有需要梯度的张量的梯度,这里我们只关心 x_in 的梯度,所以通过 [0] 索引获取。
总的来说,这个函数计算了分类器对于输入 x 和条件 t,在给定标签 y 下的对数概率梯度,并对这个梯度进行了缩放。这样的梯度可以用于各种优化或学习算法中,特别是在需要基于条件梯度的场景下。
2、softmax与log_softmax函数
当Softmax的输入比较大的时候,可能会产生上溢出,超出float的能表示范围;同理,当输入为负值且绝对值比较大的时候,分子分母会极小,接近0,从而导致下溢出。log_Softmax能够很好的解决溢出问题,且可以加快运算速度,提升数据稳定性。
softmax
log_softmax