[Deep Learning] 深度学习中常用函数的整理与介绍(pytorch为例)

文章目录

  • 深度学习中常用函数的整理与介绍
    • 常见损失函数
      • 1. L2_loss | nn.MSELoss()
          • 公式表示:
          • 特点:
          • 应用:
          • 缺点:
          • 主要参数:
          • 示例用法:
          • 注意事项:
      • 2. L1 Loss | nn.L1Loss
          • 数学定义:
          • 特点:
          • 应用:
          • PyTorch中的实现:
          • 注意事项:
      • 3. Huber's Robust Loss | torch.nn.SmoothL1Loss
          • 数学定义:
          • 特点:
          • 应用:
          • PyTorch中的实现:
          • 注意事项:
      • 4. 交叉熵损失函数 | nn.CrossEntropyLoss
    • 常见激活函数
      • 1. softmax()
        • 数学定义:
          • 特点:
          • 应用:
          • 损失函数结合:
          • 示例:
          • 注意事项:
    • 数据处理
      • 1. data.DataLoader
      • 2. next()
      • 3. trans = transforms.ToTensor() | 数据预处理工具
          • 功能:
          • 使用场景:
          • 示例用法:
          • 注意事项:
    • 网络结构和模型训练
      • 1. torch的reshape函数及其特例reshape((-1,1))
      • 2. torch的detach函数
      • 3. torch的matual函数
      • 4. torch的no_grad()和grad.zero()
      • 5. nn.Sequential()
      • 6. nn.linear
          • 关键参数:
          • 主要方法:
          • 示例用法:
          • 权重和偏置:
          • 权重初始化:
          • 训练过程中的注意事项:
      • 7. 模型权重参数:data.normal_()与data.fill_()
      • 8. 【简述神经网络模型的训练过程】trainer = torch.optim.SGD(net.parameters(), lr=0.03)
    • 查缺补漏的小细节
      • 1. d2l.use_svg_display()
      • 2. 图片样本可视化示例函数
          • 参数说明:
          • 函数工作流程:
          • 代码示例:
          • 注意事项:
          • ax是什么?

  • 本文主要用于整理和介绍在使用pytorch进行深度学习编程的过程中所遇到的常用的函数的解析。
  • 本文也涉及部分和实现一个完整的深度学习的流程相关的其它函数的整理与解析。
  • 本文所涉及到的相关的代码材料等,主要来自于李沐老师的在线课程《动手学深度学习》。

深度学习中常用函数的整理与介绍

常见损失函数

1. L2_loss | nn.MSELoss()

在PyTorch中,nn.MSELoss()是均方误差(Mean Squared Error Loss)的实现,它是一种常用的损失函数,特别是在回归问题中。均方误差损失计算预测值与真实值之间差异的平方的平均值。这种损失函数的目标是最小化预测值与真实值之间的平方差。

公式表示:

假设有一组预测值hat-y和相应的真实值y,MSELoss的计算公式为:


其中,N是样本的数量。

特点:
  1. 非负性:均方损失总是非负的,因为差的平方不可能是负数。
  2. 可微性:均方损失是可微的,这使得它适合于梯度下降等优化算法。
  3. 惩罚大误差:均方损失对大误差的惩罚比小误差更重,因为误差的平方会放大误差的影响。
  4. 易于计算:均方损失的计算相对简单,易于实现。
应用:

均方损失广泛用于线性回归问题,以及任何需要预测连续值的场景。它是许多机器学习算法的默认损失函数,例如线性回归和一些深度学习模型。

缺点:
  • 对异常值敏感:由于均方损失会平方误差,因此它对异常值(outliers)非常敏感。异常值的存在可能会对模型训练产生较大影响。
  • 需要数据是正态分布:在某些情况下,均方损失假设数据是正态分布的,这可能不适用于所有类型的数据。

在实践中,根据问题的性质和数据的特点,可能会选择其他类型的损失函数,如平均绝对误差(Mean Absolute Error, MAE)或Huber损失等,以解决均方损失的一些局限性。

主要参数:
  • size_average (bool, 可选): 指定是否将损失除以总元素数量以获得平均损失。在PyTorch 1.0.0之后,默认行为是True,即损失会被平均化。
  • reduce (bool, 可选): 指定是否将损失减少到一个标量值。如果设置为False,将返回每个元素的损失。
  • reduction (string, 可选): 指定损失的减少策略。可以是'none''mean''sum''none'不进行任何减少,'mean'计算平均损失,'sum'计算总和。
示例用法:
import torch
import torch.nn as nn# 创建MSELoss实例
criterion = nn.MSELoss()# 假设我们有预测值和真实值
y_pred = torch.tensor([0.5, 2.0, 1.5], dtype=torch.float32)
y_true = torch.tensor([1.0, 2.0, 1.0], dtype=torch.float32)# 计算损失
loss = criterion(y_pred, y_true)
print(loss)  # 输出损失值

在这个例子中,criterion是一个均方误差损失函数的实例。我们用它来计算预测值y_pred和真实值y_true之间的损失。

注意事项:
  • 当使用nn.MSELoss()时,确保预测值和真实值的维度是匹配的。
  • 损失函数通常在模型训练的反向传播阶段使用,以计算梯度并更新模型的权重。
  • 在某些情况下,你可能需要根据特定的需求调整reduction参数,例如,如果你想要对每个样本单独计算损失而不是计算平均损失。

nn.MSELoss是评估模型性能和指导模型训练的重要工具,特别是在需要预测连续值的任务中。

2. L1 Loss | nn.L1Loss

L1损失函数,也称为曼哈顿距离损失或平均绝对偏差损失,是一种常用的损失函数,特别是在回归问题中。L1损失测量的是模型预测值与实际观测值之间差的绝对值的总和。

数学定义:

在这里插入图片描述

特点:
  1. 对异常值不敏感:与均方误差(MSE)损失相比,L1损失对异常值(outliers)的影响较小,因为它不平方误差。
  2. 稀疏性:L1损失可以导致模型参数的稀疏性,即许多参数可能会变为零,这在某些情况下有助于特征选择。
  3. 非光滑:L1损失函数在y_i=hat{y_i} 时不可微,因此在优化时可能需要使用特殊的技术,如次梯度方法。
应用:

L1损失在以下领域有广泛应用:

  • 回归问题:用于预测连续值,特别是在异常值可能影响模型性能时
  • 特征选择:由于其稀疏性,L1损失可以用于促进特征选择
PyTorch中的实现:

在PyTorch中,可以使用nn.L1Loss来实现L1损失函数。以下是如何使用nn.L1Loss的一个示例:

import torch
import torch.nn as nn# 创建L1损失函数实例
criterion = nn.L1Loss()# 假设有以下预测值和真实值
y_pred = torch.tensor([0.5, 2.0, 1.5], dtype=torch.float32)
y_true = torch.tensor([1.0, 2.0, 1.0], dtype=torch.float32)# 计算L1损失
loss = criterion(y_pred, y_true)
print(loss)  # 输出损失值

在这个例子中,criterion是一个L1损失函数的实例,用于计算预测值y_pred和真实值y_true之间的L1损失。

注意事项:
  • L1损失在某些情况下可能不是凸函数,这可能导致局部最小值的问题。
  • 由于L1损失不可微的特性,在优化过程中可能需要使用特殊的优化算法,如坐标下降法或使用次梯度的随机梯度下降。

L1损失是一种简单而有效的损失函数,特别是在对异常值不敏感或需要特征选择的应用中。然而,由于其非光滑性,在实际应用中可能需要特别注意优化策略的选择。

3. Huber’s Robust Loss | torch.nn.SmoothL1Loss

Huber损失(Huber Loss),也称为Huber损失函数或最小化加权残差,是一种在回归问题中使用的损失函数。它是一种结合了L1损失(绝对误差)和L2损失(平方误差)的特点的损失函数,旨在减少异常值对模型训练的影响,同时保持平方损失的可微性和凸性。

数学定义:

在这里插入图片描述

特点:
  1. 对异常值的鲁棒性:Huber损失对异常值具有较好的鲁棒性,因为它在误差较大时转变为线性损失。
  2. 可微性和凸性:在误差较小时,Huber损失是可微的,这使得它可以使用梯度下降方法进行优化。
  3. 平滑性:Huber损失在 ( \delta ) 附近是平滑的,这有助于避免优化过程中的不稳定性。
应用:

Huber损失在以下领域有广泛应用:

  • 回归问题:特别是当数据中存在异常值或噪声较大时。
  • 机器学习:在需要平衡模型对异常值的敏感度和优化效率时。
PyTorch中的实现:

在PyTorch中,可以通过自定义损失函数或使用torch.nn.SmoothL1Loss来实现Huber损失。torch.nn.SmoothL1Loss是PyTorch中的一个损失函数,它在小于某个阈值时表现为L2损失,在大于阈值时表现为L1损失,与Huber损失类似。

以下是如何使用torch.nn.SmoothL1Loss的一个示例:

import torch
import torch.nn as nn# 创建Smooth L1损失函数实例,可以设置beta参数来控制阈值
criterion = nn.SmoothL1Loss(beta=1.0)# 假设有以下预测值和真实值
y_pred = torch.tensor([0.5, 2.0, 1.5], dtype=torch.float32)
y_true = torch.tensor([1.0, 2.0, 1.0], dtype=torch.float32)# 计算Huber损失
loss = criterion(y_pred, y_true)
print(loss)  # 输出损失值

在这个例子中,criterion是一个Smooth L1损失函数的实例,它类似于Huber损失函数,用于计算预测值y_pred和真实值y_true之间的损失。

注意事项:
  • Huber损失的 ( \delta ) 参数需要根据具体问题进行调整,以平衡对异常值的敏感度和模型的优化效率。
  • 在PyTorch中,nn.SmoothL1Lossbeta参数与Huber损失的 ( \delta ) 参数作用相似,但具体实现可能略有不同。

Huber损失是一种在特定情况下比L1或L2损失更优的损失函数,特别是在数据包含异常值或需要对模型的稳健性有较高要求的场景中。

4. 交叉熵损失函数 | nn.CrossEntropyLoss

交叉熵损失函数(Cross-Entropy Loss)是深度学习中用于分类问题的一种损失函数,特别是在多分类问题中。它衡量的是模型预测的概率分布与真实标签的概率分布之间的差异。

在PyTorch中,nn.CrossEntropyLoss是一个实现交叉熵损失的类,它内部结合了log_softmaxnll_loss(负对数似然损失),使得使用时更加方便和高效。这个类接受的输入是未经softmax转换的原始分数(logits),然后它会应用softmax函数将这些分数转换成概率分布,再计算交叉熵损失。目标标签(target)应该是类别的索引,而不是one-hot编码的形式 。

使用nn.CrossEntropyLoss时,可以指定一些参数,例如:

  • weight:可以给不同类别赋予不同的权重,以解决类别不平衡问题。
  • ignore_index:在训练时忽略的特定类别索引,该类别的数据不会被计算在损失函数中。
  • reduction:指定损失计算方式,可以是'mean'(默认,损失的平均值)、'sum'(所有损失的总和)或'none'(不进行任何减少操作)。

在实际应用中,PyTorch还提供了F.cross_entropy函数,它是nn.CrossEntropyLoss的函数式接口,使用方式如下:

import torch
import torch.nn.functional as F# 假设有以下logits和target
logits = torch.randn(4, 3)  # 4个样本,3个类别
target = torch.tensor([0, 2, 1, 0])# 计算交叉熵损失
loss = F.cross_entropy(logits, target)

此外,交叉熵损失函数在二分类问题中也有应用,此时通常使用nn.BCELossnn.BCEWithLogitsLossnn.BCEWithLogitsLoss是将Sigmoid激活函数和BCELoss结合在一起的损失函数,它在数值上更稳定 。

在多分类问题中,交叉熵损失函数的计算可以表示为:
在这里插入图片描述
总结来说,交叉熵损失函数是深度学习中用于分类问题的关键损失函数之一,它通过最小化预测概率分布和真实标签分布之间的差异来训练模型。在PyTorch中,nn.CrossEntropyLoss提供了一个简单而强大的接口来实现这一功能 。

常见激活函数

1. softmax()

Softmax 是一个在机器学习和深度学习中常用的函数,特别是在处理分类问题时它通常用于多分类模型的输出层,将模型输出的原始分数(也称为 logits)转换为概率分布。

数学定义:

在这里插入图片描述

特点:
  1. 归一化:Softmax 函数的输出是一个概率分布,所有输出值都是非负的,并且总和为 1。
  2. 不同类别间差异放大:Softmax 函数通过指数化操作放大了原始分数之间的差异,使得最高分数的类别与其他类别之间的差异更加明显。
  3. 可微性:Softmax 函数是可微的,这使得它适用于梯度下降和其他基于梯度的优化算法。
应用:

Softmax 函数在以下领域有广泛应用:

  • 多分类问题:在神经网络的输出层,用于将原始分数转换为概率分布,以进行类别预测。
  • 概率模型:在概率模型中,Softmax 可以用来从一组潜在类别中选择概率最高的类别。
损失函数结合:

Softmax 函数经常与交叉熵损失(Cross-Entropy Loss)结合使用。交叉熵损失衡量的是模型预测的概率分布与真实标签的概率分布之间的差异。在多分类问题中,这种组合是非常常见和有效的。

示例:

在这里插入图片描述

注意事项:
  • 当 logits 中有非常大的正值或负值时,Softmax 函数可能会遇到数值稳定性问题。通常通过减去最大 logits 来解决这个问题,这被称为稳定 softmax。
  • 在某些情况下,如果类别的个数非常多,Softmax 函数的计算可能会变得低效。这时可以考虑使用其他近似方法或算法。

数据处理

1. data.DataLoader

在深度学习中,DataLoader是PyTorch中用于加载数据集的一个类,它提供了一种简便的方式来迭代数据集。DataLoader可以与PyTorch的Dataset类一起使用,以实现高效的数据加载和批处理

以下是DataLoader的一些关键特性:

  1. 批处理DataLoader可以将数据集分割成多个批次(batches),每个批次包含固定数量的样本。这有助于利用GPU的并行计算能力。

  2. 数据打乱DataLoader提供了shuffle参数,可以在每个epoch开始时对数据进行随机打乱,这有助于提高模型训练的泛化能力。

  3. 多线程加载:通过设置num_workers参数,DataLoader可以在多个进程中并行加载数据,从而加快数据加载速度。

  4. 数据采样:可以使用sampler参数指定如何从数据集中抽取样本,例如,可以定义一个自定义的采样策略。

  5. 数据预处理:在使用DataLoader之前,通常需要对数据进行预处理,如归一化、数据增强等。这些预处理步骤可以在Dataset类中实现。

  6. 持久化工作器:从PyTorch 1.8开始,DataLoader支持persistent_workers参数,允许工作器在数据加载过程中持续运行,而不是在每次迭代后重新启动。

下面是一个使用DataLoader的简单示例:

from torch.utils.data import DataLoader, Datasetclass MyDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):# 假设数据是一个简单的数字列表return self.data[idx]# 创建数据集实例
dataset = MyDataset([1, 2, 3, 4, 5])# 创建DataLoader实例
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)# 使用DataLoader迭代数据
for batch in dataloader:print(batch)

在这个例子中,我们首先定义了一个简单的MyDataset类,它继承自Dataset。然后,我们创建了一个DataLoader实例,指定了批量大小为2,并设置了shuffle=True来打乱数据。最后,我们通过迭代DataLoader来获取每个批次的数据。

DataLoader是深度学习训练过程中不可或缺的工具,它提供了一种高效、灵活的方式来处理数据。

2. next()

在Python中,next() 函数用于从一个迭代器中获取下一个元素。迭代器是一个可以记住遍历位置的对象,它有一个 __next__() 方法,每次调用这个方法都会返回迭代器的下一个元素。当迭代器中没有更多元素时,__next__() 方法会抛出一个 StopIteration 异常。

以下是 next() 函数的一些用法:

  1. 获取下一个元素:

    my_list = [1, 2, 3, 4]
    my_iterator = iter(my_list)  # 创建迭代器
    print(next(my_iterator))  # 输出 1
    print(next(my_iterator))  # 输出 2
    
  2. 指定默认值:
    next() 函数还可以接受一个可选的默认值参数,当迭代器中没有更多元素时,将返回这个默认值而不是抛出异常。

    print(next(my_iterator, "迭代器已结束"))  # 输出 3
    print(next(my_iterator, "迭代器已结束"))  # 输出 4
    print(next(my_iterator, "迭代器已结束"))  # 输出 "迭代器已结束"
    
  3. 在生成器中使用:
    当使用生成器时,next() 可以用来获取生成器的下一个值。

    def my_generator():yield 1yield 2yield 3gen = my_generator()
    print(next(gen))  # 输出 1
    print(next(gen))  # 输出 2
    print(next(gen))  # 输出 3
    # 下面这行将抛出 StopIteration 异常
    # print(next(gen))
    
  4. 在数据加载中使用:
    在深度学习中,next() 经常与 DataLoader 结合使用,用于从一个批次到另一个批次的迭代。

    for data in dataloader:# 处理数据pass
    # 如果需要手动获取下一个批次
    try:next_batch = next(dataloader)# 处理 next_batch
    except StopIteration:print("所有批次已处理完毕")
    

next() 是一个非常有用的内置函数,它在Python编程中广泛应用,特别是在处理迭代器和生成器时。

3. trans = transforms.ToTensor() | 数据预处理工具

在PyTorch中,transforms.ToTensor()torchvision.transforms模块提供的一个变换函数,它用于将PIL图像或Numpy数组转换为torch.FloatTensor类型。这种转换通常在图像数据加载和预处理步骤中使用,以确保数据格式与深度学习模型的输入要求一致。

功能:
  • 将PIL图像或Numpy数组中的数值从[0, 255]归一化到[0.0, 1.0]
  • 【通道前移】 将图像数据从形状(H, W, C)转换为(C, H, W),其中H是高度,W是宽度,C是通道数。这种形状转换是因为PyTorch期望输入数据的通道维度在前
使用场景:

transforms.ToTensor()通常用在数据加载和预处理的流水线中,与其它的transforms一起使用,例如调整大小、裁剪、归一化等。

示例用法:
from torchvision import transforms
from PIL import Image# 定义变换序列
transform = transforms.Compose([transforms.Resize((100, 100)),  # 调整图像大小到100x100transforms.ToTensor()           # 将图像转换为Tensor
])# 加载一个PIL图像
img = Image.open("path_to_image.jpg")# 应用变换
img_tensor = transform(img)
print(img_tensor.size())  # 输出:torch.Size([3, 100, 100])

在这个例子中,首先定义了一个变换序列,其中包括调整图像大小和转换为张量。然后,使用Image.open加载了一个PIL图像,并通过调用transform对图像应用了定义的变换。

注意事项:
  • transforms.ToTensor()假定输入图像的通道顺序是RGB,如果你的数据是灰度图像,可能需要先转换为RGB。
  • 归一化到[0.0, 1.0]是通过除以255实现的,这意味着输入数据类型应该是整型。如果输入数据已经是浮点型,并且像素值在[0.0, 1.0]范围内,则不需要这个变换。

transforms.ToTensor()是PyTorch图像处理中非常基础且常用的一个函数,它帮助将图像数据转换为适合神经网络处理的格式。

网络结构和模型训练

1. torch的reshape函数及其特例reshape((-1,1))

  • 在PyTorch中,reshape函数用于改变张量的形状而不改变其数据。reshape函数的参数可以是整数或者整数的序列,表示新的张量形状。

  • 当你使用reshape((-1, 1))时,-1是一个特殊的参数,它告诉PyTorch自动计算这个维度的大小,以便保持张量的元素总数不变。具体来说:

  • 如果原始张量的形状是(a, b, c, …),使用reshape((-1, 1))后,新的张量形状将变为(a * b * c * …, 1)。
    这种用法通常用于将一维张量转换为两维张量,其中第二维的大小为1。这在某些深度学习模型中是必要的,比如在使用全连接层(nn.Linear)之前,需要将数据转换为二维张量,其中第一维是批次大小,第二维是特征数量。

  • 例如,假设有一个形状为(3,)的一维张量,使用reshape((-1, 1))后,它将变为形状(3, 1)的二维张量。这在处理批处理数据时非常有用,因为大多数深度学习模型都期望输入数据是二维的。

  • 可以简单理解为把一个行向量处理成为了一个列向量

2. torch的detach函数

  • 在PyTorch中,detach函数用于将张量从当前的计算图中分离出来,使其不再参与梯度计算。这对于多种场景非常有用,以下是一些例子来解释detach函数的功能:
  1. 避免梯度计算
    当你在训练神经网络时,通常需要计算梯度以更新模型的权重。但是,有时候你可能需要计算一些中间结果,但不希望这些结果影响梯度计算。使用detach可以创建张量的副本,该副本不会在反向传播中计算梯度。
import torchx = torch.tensor([1.0, 2.0], requires_grad=True)
y = x * 2
z = y.detach()  # z是一个不参与梯度计算的新张量
  • 在上面的例子中,z是从y分离出来的,即使x的requires_grad属性为True,z也不会在反向传播中计算梯度。
  1. 计算图的隔离
  • 在复杂的模型中,你可能需要在不同的计算图中使用相同的数据。使用detach可以确保数据在不同的计算图中是独立的。
x = torch.randn(3, 3, requires_grad=True)
y = x @ x.t()  # 第一个计算图
result = y.detach()  # 从第一个计算图中分离出来
z = result + x  # 第二个计算图
  • 在Python中,@符号是矩阵乘法运算符,它允许你进行两个矩阵或向量的点积(也称为内积或标量积)。当你使用x @ x.t()时,你实际上是在进行矩阵x与其转置x.t()的矩阵乘法。
  • 这里的.t()方法是PyTorch中用来获取张量转置的方法。如果x是一个二维矩阵,x.t()会返回x的转置矩阵,即将x的行变成列,列变成行。
  1. 避免梯度累积
  • 在某些情况下,你可能需要在循环中使用相同的变量,但又不希望梯度累积。detach可以用来重置梯度。
x = torch.tensor([1.0, 2.0], requires_grad=True)
for _ in range(10):y = x * 2# 执行一些操作...x = x.detach()  # 重置x的梯度
  1. 与CPU/GPU交互
  • 当你需要将张量从GPU移动到CPU或反之,并且不希望保留梯度信息时,可以使用detach。
x = torch.tensor([1.0, 2.0], device='cuda', requires_grad=True)
y = x.cpu().detach().numpy()  # 移动到CPU,分离,然后转换为NumPy数组
  • 在上面的例子中,y首先被移动到CPU,然后通过detach分离,最后转换为NumPy数组,此时y不再参与梯度计算。

3. torch的matual函数

在PyTorch中,torch.matmul函数用于计算两个张量的矩阵乘法。当你使用torch.matmul(X, w)时,X是输入张量,w是权重张量。这个操作通常用于线性层的前向传播,其中X代表输入特征,w代表权重参数。

在你给出的表达式torch.matmul(X, w) + b中,b是偏置项,它通常是一个与权重w的输出维度相同的张量。这个表达式代表了线性变换的完整计算过程,即:

这里的Xw需要满足矩阵乘法的维度要求。如果X是一个(N, D)的矩阵,其中N是批量大小,D是特征维度,那么w应该是一个(M, D)的矩阵,其中M是输出维度。b应该是一个(1, M)的向量,这样加法操作可以广播到Xw的乘积上。

下面是一个简单的例子:

import torch# 假设X是一个批量大小为3,特征维度为4的输入张量
X = torch.randn(3, 4)# w是一个输出维度为2,特征维度为4的权重张量
w = torch.randn(2, 4)# b是一个输出维度为2的偏置向量
b = torch.randn(1, 2)# 计算线性变换
output = torch.matmul(X, w) + b

在这个例子中,output将是一个(3, 2)的张量,它代表了输入X通过权重w和偏置b变换后的结果。

在实际的深度学习模型中,这种线性变换通常被封装在nn.Linear模块中,这样你不需要手动编写矩阵乘法和加法操作,而是可以直接使用PyTorch提供的高级API。例如:

import torch.nn as nn# 创建一个线性层,输入特征维度为4,输出特征维度为2
linear_layer = nn.Linear(4, 2)# 假设X是一个批量大小为3,特征维度为4的输入张量
X = torch.randn(3, 4)# 前向传播
output = linear_layer(X)

在这个例子中,linear_layer会自动处理权重w和偏置b的计算,output将是(3, 2)的张量。

4. torch的no_grad()和grad.zero()

def sgd(params, lr, batch_size):  #@save"""小批量随机梯度下降"""with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()

这段代码定义了一个简单的小批量随机梯度下降(SGD)优化函数。让我们逐步分析这个函数的各个部分:

  1. 函数定义: def sgd(params, lr, batch_size): 定义了一个名为sgd的函数,它接受三个参数:

    • params: 一个参数列表,包含了模型中的所有参数(通常是一个模型的.parameters()方法的返回值)。
    • lr: 学习率,这是一个超参数,用于控制每次更新步长的大小。
    • batch_size: 批量大小,表示每次迭代中用于计算梯度的样本数量。
  2. 上下文管理器: with torch.no_grad(): 这个上下文管理器指示PyTorch在代码块内不计算梯度。这是因为在执行优化更新时,我们不希望梯度被累加到现有的梯度上,而是希望使用当前计算的梯度来更新参数。

  3. 参数迭代: for param in params: 这个循环遍历params列表中的每一个参数。

  4. 梯度更新: param -= lr * param.grad / batch_size 这一行是SGD更新的核心。它计算每个参数的梯度,并用学习率乘以梯度,然后从参数中减去这个值以更新参数。这里,param.grad是当前参数的梯度,lr是学习率,batch_size用于调整梯度的大小,以匹配全批量梯度的大小。

  5. 梯度清零: param.grad.zero_() 在每次参数更新后,我们需要将梯度清零,以避免梯度累加。这是因为在torch.no_grad()上下文管理器中,梯度不会自动清零。

这个函数是一个简化版的SGD实现,用于演示梯度下降的基本原理。在实际应用中,通常会使用PyTorch提供的优化器类,如torch.optim.SGD它提供了更完整的功能,包括动量(momentum)和其他优化技术。

使用这个函数的一个例子是:

# 假设model是我们的模型实例
optimizer = sgd(model.parameters(), lr=0.01, batch_size=64)
# 在训练循环中
for inputs, labels in data_loader:optimizer.zero_grad()  # 清零梯度,这是PyTorch的推荐做法outputs = model(inputs)loss = loss_function(outputs, labels)loss.backward()  # 反向传播,计算梯度optimizer(model.parameters(), lr=0.01, batch_size=64)  # 调用自定义的SGD函数

请注意,上面的示例中optimizer.zero_grad()是PyTorch的标准做法,用于在每次迭代开始前清零梯度。而在自定义的SGD函数中,我们通过param.grad.zero_()来手动清零梯度。

5. nn.Sequential()

在PyTorch中,nn.Sequential是一个容器模块,用于包装一个有序的模块列表。它按照列表中定义的顺序,将输入依次传递给每个模块。每个模块可以是任何继承自nn.Module的类实例,比如层、激活函数、损失函数等。

nn.Sequential的主要优点是简化了模型的构建过程,特别是当模型由一系列简单的、顺序的层组成时。使用nn.Sequential,你不需要显式地编写前向传播逻辑,它会自动将输入数据从第一个模块传递到最后一个模块。

以下是nn.Sequential的一些关键特性:

  1. 简单易用:通过简单地将层添加到列表中,你可以快速构建一个顺序模型。

  2. 自动前向传播nn.Sequential会自动将输入数据从第一个模块传递到最后一个模块。

  3. 支持任意模块:可以包含任何nn.Module的子类,包括自定义模块。

  4. 易于扩展:可以很容易地添加或移除模块,或者修改模块的顺序。

下面是一个使用nn.Sequential的例子:

import torch
import torch.nn as nn# 定义模型
model = nn.Sequential(nn.Linear(in_features=10, out_features=5),nn.ReLU(),nn.Linear(in_features=5, out_features=2)
)# 创建一个输入张量
x = torch.randn(1, 10)# 前向传播
output = model(x)
print(output)

在这个例子中,我们首先导入了必要的库,然后创建了一个nn.Sequential模型,其中包含两个线性层和一个ReLU激活层。接着,我们创建了一个随机初始化的输入张量x,并进行了前向传播,得到了模型的输出。

使用nn.Sequential可以大大简化模型定义的代码,使得模型结构清晰易懂。然而,当模型结构更加复杂,比如需要分支、合并或者非顺序的层连接时,你可能需要使用更灵活的模型定义方式。

6. nn.linear

在PyTorch中,nn.Linear是一个实现线性层的模块,也被称为全连接层(fully connected layer)。它对输入数据执行一个线性变换,公式如下:

在这里插入图片描述
这里的weight是模型的参数,bias是偏置项,input是输入数据。

关键参数:
  • in_features: 输入特征的数量,即输入数据的维度。
  • out_features: 输出特征的数量,即模型输出的维度。
主要方法:
  • forward(input): 定义了如何计算前向传播的输出。
示例用法:
import torch
import torch.nn as nn# 创建一个线性层实例,输入特征为3,输出特征为2
linear_layer = nn.Linear(in_features=3, out_features=2)# 创建一个输入张量,假设有一个批量大小为4
input = torch.randn(4, 3)  # 4个样本,每个样本3个特征# 前向传播,计算输出
output = linear_layer(input)print(output)  # 输出的维度将是 [4, 2]

在这个示例中,linear_layer接收一个维度为[4, 3]的输入张量,输出一个维度为[4, 2]的张量。

权重和偏置:

nn.Linear模块拥有两个可学习的参数:

  • weight: 一个形状为[out_features, in_features]的张量。
  • bias: 一个形状为[out_features]的张量,每个输出特征都有一个偏置项。如果初始化nn.Linear时设置bias=False,则不会使用偏置项。
权重初始化:

PyTorch提供了多种权重初始化方法,可以通过weight参数传递给nn.Linear。例如:

linear_layer = nn.Linear(in_features=3, out_features=2, bias=True)
with torch.no_grad():# 使用xavier_uniform_初始化权重torch.nn.init.xavier_uniform_(linear_layer.weight)# 使用zeros_初始化偏置torch.nn.init.zeros_(linear_layer.bias)
训练过程中的注意事项:

在训练过程中,nn.Linear的权重和偏置会自动根据梯度下降或其他优化算法进行更新。此外,确保在每次迭代开始时调用optimizer.zero_grad()来清空之前的梯度,以避免梯度累积。

nn.Linear是构建神经网络时最基础和常用的模块之一,适用于各种类型的网络架构。

7. 模型权重参数:data.normal_()与data.fill_()

net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)

这两行代码是PyTorch中对神经网络中第一个层的权重和偏置参数进行初始化的操作。

  1. net[0].weight.data.normal_(0, 0.01)

    • net[0] 表示访问神经网络 net 的第一个层(假设它是一个 nn.ModuleListnn.Sequential 或者任何其他列表形式的层集合)。
    • .weight 属性访问该层的权重参数。
    • .data 属性提供了对参数数据的直接访问,允许我们对其进行操作。
    • .normal_(0, 0.01) 是一个就地操作(in-place operation),它将权重参数的值从正态分布(均值为0,标准差为0.01)中重新抽样并赋值。这是权重初始化的一种常见方法,有时被称为“Xavier初始化”。
  2. net[0].bias.data.fill_(0)

    • .bias 属性访问该层的偏置参数。
    • .data 同样提供了对偏置参数数据的直接访问。
    • .fill_(0) 是一个就地操作,它将偏置参数的所有值设置为0。这是一种常见的偏置初始化方法,特别是在使用ReLU激活函数时,因为初始化为0的偏置可以保证在训练开始时,激活函数的输出不会太偏离0。

这种手动初始化参数的做法有助于改善模型的训练动态,特别是在模型训练初期。正确的初始化可以减少梯度消失或梯度爆炸的问题,有助于模型更快收敛。

请注意,这些操作是在模型定义之后和训练开始之前进行的。此外,如果你使用的是PyTorch的nn.Module,通常不需要手动初始化每个参数,因为许多模块在初始化时已经包含了默认的参数初始化策略。如果你需要自定义初始化,可以在模块的构造函数中设置。

8. 【简述神经网络模型的训练过程】trainer = torch.optim.SGD(net.parameters(), lr=0.03)

在PyTorch中,torch.optim.SGD是实现随机梯度下降(Stochastic Gradient Descent,SGD)优化算法的类。它用于更新模型的参数,以最小化损失函数。

以下是创建和使用torch.optim.SGD实例的步骤:

  1. 创建SGD实例:
    torch.optim.SGD的第一个参数是模型参数的迭代器,通常通过调用模型的.parameters()方法获得。第二个参数lr是学习率,它控制每次参数更新的步长。

    trainer = torch.optim.SGD(net.parameters(), lr=0.03)
    

    在这个例子中,net是一个神经网络模型,trainer是SGD优化器的实例,学习率设置为0.03。

  2. 零初始化梯度:
    在每次参数更新之前,需要清空之前迭代步骤中累积的梯度。这可以通过调用优化器的zero_grad()方法实现。

    trainer.zero_grad()
    
  3. 前向传播:
    将输入数据传递给模型,并计算输出和损失。

    output = net(input_data)
    loss = loss_function(output, target)
    
  4. 反向传播:
    使用损失值计算所有参数的梯度。

    loss.backward()
    
  5. 参数更新:
    调用优化器的step()方法来根据计算得到的梯度更新模型的参数。

    trainer.step()
    
  6. 训练循环:
    将上述步骤放入训练循环中,迭代多个epoch或直到满足停止条件。

    for epoch in range(num_epochs):for data, target in dataloader:trainer.zero_grad()           # 清空梯度output = net(data)            # 前向传播loss = loss_function(output, target)  # 计算损失loss.backward()               # 反向传播trainer.step()                # 参数更新
    

使用torch.optim.SGD时,你还可以设置其他参数,如momentum(动量)、weight_decay(权重衰减,用于正则化以防止过拟合)、dampening(阻尼)等,以进一步控制优化过程。

请注意,虽然SGD是一种简单且广泛使用的优化算法,但在某些情况下,其他优化算法(如Adam)可能会提供更快的收敛速度或更好的性能。选择哪种优化器取决于具体任务和个人偏好。

查缺补漏的小细节

1. d2l.use_svg_display()

d2l.use_svg_display()是《动手学深度学习》(Dive into Deep Learning, 简称d2l)这本书的配套代码库中的一个函数。这个函数用于设置图表的显示方式,使得在Jupyter笔记本等环境中,图表可以以SVG格式显示,从而提供更清晰的图像质量,尤其是当图表被放大时。

在Jupyter笔记本中,默认情况下,matplotlib生成的图像可能是以PNG格式显示,这可能导致图像在放大查看时出现像素化。使用d2l.use_svg_display()函数后,图像将以SVG格式渲染,SVG是矢量图,可以在不失真的情况下无限放大。

以下是如何使用这个函数的一个示例:

from d2l import use_svg_display
use_svg_display()  # 调用函数以设置SVG显示# 接下来,当你使用matplotlib生成图像时,它们将以SVG格式显示

请注意,d2l.use_svg_display()函数是d2l库的一部分,你可能需要先安装d2l包才能使用这个函数。d2l是一个开源项目,提供了深度学习的教程和代码示例,它基于Python和深度学习框架如MXNet和PyTorch。

如果你正在使用Jupyter笔记本,并且在安装了d2l之后,可以通过上述函数调用来改变图像的显示格式。如果你没有安装d2l,可以通过运行以下命令来安装:

pip install d2l

安装完成后,就可以在你的Python代码中使用d2l提供的各种功能和工具了。

2. 图片样本可视化示例函数

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save"""绘制图像列表"""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes

这段代码定义了一个名为 show_images 的函数,用于在一个图形窗口中显示图像列表。函数使用 matplotlib 库来绘制图像,并且是为《动手学深度学习》(d2l)这本书的配套代码设计的。下面是对函数各部分的详细解释:

参数说明:
  • imgs: 要显示的图像列表,可以是张量列表或PIL图像列表。
  • num_rows: 要显示的图像行数。
  • num_cols: 每行要显示的图像列数。
  • titles: (可选)每个图像的标题列表。如果提供,每个图像将显示相应的标题。
  • scale: 图像显示的缩放比例,默认为1.5。
函数工作流程:
  1. 计算图形窗口的大小,基于行列数和缩放比例。
  2. 使用 d2l.plt.subplots 创建子图网格,num_rows 表示行数,num_cols 表示列数,figsize 指定图形窗口的大小。
  3. 将子图的坐标轴对象扁平化为一维数组,方便迭代。
  4. 遍历图像和坐标轴的迭代器:
    • 根据图像类型(张量或PIL图像),使用 ax.imshow 显示图像。
    • 隐藏坐标轴。
    • 如果提供了标题列表,为每个子图设置标题。
代码示例:

假设你有一些图像张量存储在列表 imgs 中,你可以使用以下代码调用 show_images 函数来显示它们:

from torch import tensor# 假设 imgs 是包含图像张量的列表
imgs = [tensor(some_image_data) for _ in range(9)]# 显示图像
show_images(imgs, num_rows=3, num_cols=3)
注意事项:
  • 函数中的 d2l.plt.subplots 调用实际上是对 matplotlibplt.subplots 的封装。确保已经导入了 d2l 库并可以使用。
  • 如果图像是张量,函数会使用 .numpy() 方法将其转换为 NumPy 数组,以便 imshow 可以处理。
  • 如果提供 titles 参数,它应该是一个与 imgs 长度相同的列表,其中包含每个图像的标题。

这个函数是 d2l 库中用于可视化图像的实用工具,可以方便地在Jupyter笔记本或其他环境中查看和比较图像。

ax是什么?

在这段代码中,ax是matplotlib中用于表示子图(subplot)的坐标轴(Axes)对象。在创建多个子图时,每个子图都有自己的坐标轴对象,允许你对每个子图进行单独的配置和绘图。

当你调用plt.subplots()函数创建一个图形窗口和一组子图时,它会返回一个图形(Figure)对象和一组坐标轴对象。这些坐标轴对象被组织成一个二维数组,其形状由num_rowsnum_cols参数决定。在这个例子中:

_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
  • 第一行代码创建了子图,_表示我们不使用返回的图形对象,只关心坐标轴对象。
  • 第二行代码将坐标轴的二维数组扁平化为一维数组,使得你可以使用一个for循环遍历所有的子图。

在for循环中:

for i, (ax, img) in enumerate(zip(axes, imgs)):# 在这里使用 ax 对象来操作每个子图...
  • axes是扁平化后的坐标轴对象数组。
  • imgs是要显示的图像列表。
  • enumerate(zip(axes, imgs))生成一个包含索引和子图图像对的迭代器。
  • ax是当前遍历到的坐标轴对象,你可以使用它来调用如imshowset_title等方法来在特定子图上绘图或设置标题。

简而言之,ax是每个子图的坐标轴对象,它提供了在子图上绘图和设置视觉属性的方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/482038.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

0017. shell命令--tac

目录 17. shell命令--tac 功能说明 语法格式 选项说明 实践操作 注意事项 17. shell命令--tac 功能说明 Linux 的 tac 命令用于按行反向输出文件内容,与 cat 命令的输出顺序相反。非常有趣,好记。也就是说,当我们使用tac命令查看文件内…

SpringBoot整合Retry详细教程

问题背景 在现代的分布式系统中,服务间的调用往往需要处理各种网络异常、超时等问题。重试机制是一种常见的解决策略,它允许应用程序在网络故障或临时性错误后自动重新尝试失败的操作。Spring Boot 提供了灵活的方式来集成重试机制,这可以通过…

爬取boss直聘上海市人工智能招聘信息+LDA主题建模

爬取boss直聘上海市人工智能招聘信息 import time import tqdm import random import requests import json import pandas as pd import os from selenium import webdriver from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import WebDriv…

项目快过:知识蒸馏 | 目标检测 |FGD | Focal and Global Knowledge Distillation for Detectors

公开时间:2022年3月9号 项目地址:https://github.com/yzd-v/FGD 论文地址:https://arxiv.org/pdf/2111.11837 知识蒸馏已成功地应用于图像分类。然而,目标检测要复杂得多,大多数知识蒸馏方法都失败了。本文指出&#…

【Linux】匿名管道通信场景——进程池

🔥 个人主页:大耳朵土土垚 🔥 所属专栏:Linux系统编程 这里将会不定期更新有关Linux的内容,欢迎大家点赞,收藏,评论🥳🥳🎉🎉🎉 文章目…

Sybase数据恢复—Sybase数据库无法启动,Sybase Central连接报错的处理案例

Sybase数据库数据恢复环境: Sybase数据库版本:SQL Anywhere 8.0。 Sybase数据库故障&分析: Sybase数据库无法启动。 错误提示: 使用Sybase Central连接报错。 数据库数据恢复工程师经过检测,发现Sybase数据库出现…

分布式FastDFS存储的同步方式

目录 一:FatsDFS的结构图 二:FatsDFS文件同步 前言: 1:同步日志所在目录 2:binlog格式 3:同步规则 4:binlog同步过程 1 :获取组内的其他Storage信息 tracker_report_thread_e…

【绘图】数据可视化(python)

对于数据绝对值差异较大(数据离散) 1. 对数坐标直方图(Histogram with Log Scale) import pandas as pd import matplotlib.pyplot as plt import numpy as np# 示例数据 data {count: [10, 20, 55, 90, 15, 5, 45, 80, 1000, …

使用Dify与BGE-M3搭建RAG(检索增强生成)应用-改进一,使用工作流代替Agnet

文章目录 前言Agent vs 工作流编写工作流 前言 在上一篇中,我们实现了一个基本的基于Dify的RAG的示范。 使用Dify与BGE-M3搭建RAG(检索增强生成)应用 这个效果确实很差。 我们一起来看看,该怎么改进。 今天我们就尝试一下&…

【Linux课程学习】:文件第二弹---理解一切皆文件,缓存区

🎁个人主页:我们的五年 🔍系列专栏:Linux课程学习 🌷追光的人,终会万丈光芒 🎉欢迎大家点赞👍评论📝收藏⭐文章 Linux学习笔记: https://blog.csdn.net/d…

【iOS】《Effective Objective-C 2.0》阅读笔记(一)

文章目录 前言了解OC语言的起源在类的头文件中尽量少引入其他头文件多用字面量语法,少用与之等价的方法字面量数值字面量数组字面量字典 多用类型常量,少用#define预处理指令用枚举法表示状态、选项、状态码 总结 前言 最近开始阅读一些iOS开发的相关书籍…

猫狗分类调试过程

一,下载名称为archive数据集 下载方式:机房共享文件夹 二、打开CatDogProject项目 配置环境:选择你所建的环境 三、调试运行 1,报错一:Traceback (most recent call last): File "G:/AI_Project/CatDogPro…

探索Python WebSocket新境界:picows库揭秘

文章目录 探索Python WebSocket新境界:picows库揭秘第一部分:背景介绍第二部分:picows库概述第三部分:安装picows库第四部分:简单库函数使用方法第五部分:场景应用第六部分:常见Bug及解决方案第…

零基础学安全--Burp Suite(4)proxy模块以及漏洞测试理论

目录 学习连接 一些思路 proxy模块 所在位置 功能简介 使用例子 抓包有一个很重要的点,就是我们可以看到一些在浏览器中看不到的传参点,传参点越多就意味着攻击面越广 学习连接 声明! 学习视频来自B站up主 **泷羽sec** 有兴趣的师傅可…

30 基于51单片机的手环设计仿真

目录 一、主要功能 二、硬件资源 三、程序编程 四、实现现象 一、主要功能 基于STC89C52单片机,DHT11温湿度采集温湿度,滑动变阻器连接ADC0832数模转换器模拟水位传感器检测水位,通过LCD1602显示信息,然后在程序里设置好是否…

十一、快速入门go语言之接口和反射

文章目录 接口:one: 接口基础:two: 接口类型断言和空接口:star2: 空接口实现存储不同数据类型的切片/数组:star2: 复制切片到空接口切片:star2: 类型断言 反射 📅 2024年5月9日 📦 使用版本为1.21.5 接口 十、Java类的封装和继承、多态 - 七点半的菜市…

QT6学习第六天 初识QML

QT6学习第六天 创建Qt Quick UI项目使用Qt Quick DesignerQML 语法基础导入语句 import对象 object 和属性 property布局注释表达式和属性绑定QML 编码约定 设置应用程序图标 创建Qt Quick UI项目 如果你有只测试QML相关内容快速显示界面的需求,这时可以创建Qt Qui…

图解RabbitMQ七种工作模式生产者消费者模型的补充

文章目录 1.消费者模型2.生产者-消费者模型注意事项2.1资源释放顺序问题2.2消费者的声明问题2.3虚拟机和用户的权限问题 3.七种工作模式3.1简单模式3.2工作模式3.3发布/订阅模式3.4路由模式3.5通配符模式3.6RPC通信3.7发布确认 1.消费者模型 之前学习的这个消息队列的快速上手…

C-操作符

操作符种类 在C语言中,操作符有以下几种: 算术操作符 移位操作符 位操作符 逻辑操作符 条件操作符 逗号表达式 下标引用,函数调用 拓展:整型提升 我们介绍常用的几个 算术操作符 (加)&#xff…

使用 Spring Boot 和 GraalVM 的原生镜像

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:历代文学,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编程,高并发设计&#xf…