使用Pytorch构建神经网络

构建神经网络的典型流程

  • 定义一个拥有可学习参数的神经网络
  • 遍历训练数据集
  • 处理输入数据使其流经神经网络
  • 计算损失值
  • 将网络参数的梯度进行反向传播
  • 以一定的规则更新网络的权重

我们首先定义一个Pytorch实现的神经网络:

# 导入若干工具包
import torch
import torch.nn as nn
import torch.nn.functional as F# 定义一个简单的网络类
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 定义第一层卷积神经网络, 输入通道维度=1, 输出通道维度=6, 卷积核大小3*3self.conv1 = nn.Conv2d(1, 6, 3)# 定义第二层卷积神经网络, 输入通道维度=6, 输出通道维度=16, 卷积核大小3*3self.conv2 = nn.Conv2d(6, 16, 3)# 定义三层全连接网络self.fc1 = nn.Linear(16 * 6 * 6, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):# (2, 2)的池化窗口下执行最大池化操作x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, self.num_flat_features(x))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):# 计算size, 除了第0个维度上的batch_sizesize = x.size()[1:]num_features = 1for s in size:num_features *= sreturn num_featuresnet = Net()
print(net)

运行结果
在这里插入图片描述
注意:
模型中所有的可训练参数, 可以通过net.parameters()来获得.

params = list(net.parameters())
print(len(params))
print(params[0].size())

运行结果:
在这里插入图片描述

  • 假设图像的输入尺寸为32 * 32:
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)

运行结果
在这里插入图片描述

  • 有了输出张量后, 就可以执行梯度归零和反向传播的操作了.
net.zero_grad()
out.backward(torch.randn(1, 10))
  • 注意
    - torch.nn构建的神经网络只支持mini-batches的输入, 不支持单一样本的输入.
    - 比如: nn.Conv2d 需要一个4D Tensor, 形状为(nSamples, nChannels, Height, Width). 如果你的输入只有单一样本形式, 则需要执行input.unsqueeze(0), 主动将3D Tensor扩充成4D Tensor.

损失函数

  • 损失函数的输入是一个输入的pair: (output, target), 然后计算出一个数值来评估output和target之间的差距大小.
  • 在torch.nn中有若干不同的损失函数可供使用, 比如nn.MSELoss就是通过计算均方差损失来评估输入和目标值之间的差距
  • 应用nn.MSELoss计算损失的一个例子:
output = net(input)
target = torch.randn(10)# 改变target的形状为二维张量, 为了和output匹配
target = target.view(1, -1)
criterion = nn.MSELoss()loss = criterion(output, target)
print(loss)

运行结果:
在这里插入图片描述

  • 关于方向传播的链条: 如果我们跟踪loss反向传播的方向, 使用.grad_fn属性打印, 将可以看到一张完整的计算图如下:
input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d-> view -> linear -> relu -> linear -> relu -> linear-> MSELoss-> loss
  • 当调用loss.backward()时, 整张计算图将对loss进行自动求导, 所有属性requires_grad=True的Tensors都将参与梯度求导的运算, 并将梯度累加到Tensors中的.grad属性中.
print(loss.grad_fn)  # MSELoss
print(loss.grad_fn.next_functions[0][0])  # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # ReLU

运行结果:
在这里插入图片描述
反向传播(backpropagation)

  • 在Pytorch中执行反向传播非常简便, 全部的操作就是loss.backward().
  • 在执行反向传播之前, 要先将梯度清零,否则梯度会在不同的批次数据之间被累加.
    执行一个反向传播的小例子:
# Pytorch中执行梯度清零的代码
net.zero_grad()print('conv1.bias.grad before backward')
print(net.conv1.bias.grad)# Pytorch中执行反向传播的代码
loss.backward()print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)

运行结果:
在这里插入图片描述
更新网络参数

  • 更新参数最简单的算法就是SGD(随机梯度下降).
  • 具体的算法公式表达式为: weight = weight - learning_rate
    gradient 首先用传统的Python代码来实现SGD如下:
learning_rate = 0.01
for f in net.parameters():f.data.sub_(f.grad.data * learning_rate)

然后使用Pytorch官方推荐的标准代码如下:

# 首先导入优化器的包, optim中包含若干常用的优化算法, 比如SGD, Adam等
import torch.optim as optim# 通过optim创建优化器对象
optimizer = optim.SGD(net.parameters(), lr=0.01)# 将优化器执行梯度清零的操作
optimizer.zero_grad()output = net(input)
loss = criterion(output, target)# 对损失值执行反向传播的操作
loss.backward()
# 参数的更新通过一行标准代码来执行
optimizer.step()

小节总结
学习了构建一个神经网络的典型流程:

  • 定义一个拥有可学习参数的神经网络
  • 遍历训练数据集
  • 处理输入数据使其流经神经网络
  • 计算损失值
  • 将网络参数的梯度进行反向传播
  • 以一定的规则更新网络的权重

学习了损失函数的定义:

  • 采用torch.nn.MSELoss()计算均方误差.
  • 通过loss.backward()进行反向传播计算时, 整张计算图将对loss进行自动求导,
    所有属性requires_grad=True的Tensors都将参与梯度求导的运算, 并将梯度累加到Tensors中的.grad属性中.

学习了反向传播的计算方法:

  • 在Pytorch中执行反向传播非常简便, 全部的操作就是loss.backward().
  • 在执行反向传播之前, 要先将梯度清零, 否则梯度会在不同的批次数据之间被累加.
  • net.zero_grad()
  • loss.backward()

学习了参数的更新方法:

  • 定义优化器来执行参数的优化与更新.

    optimizer = optim.SGD(net.parameters(), lr=0.01)

  • 通过优化器来执行具体的参数更新.

    optimizer.step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/148435.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习笔记(一)

1.线性回归模型 2. 损失函数 3.梯度下降算法 多元特征的线性回归 当有多个影响因素的时候,公式可以改写为: 当有多个影响因素的时候为了方便计算,可以使用 Numpy下面的点积方法, np.dot(w,x) 最后再加个b 就省略了很多书写步骤,这叫做矢量化 多元回归的梯度下降 左边是一…

小团队内部资料共享协作:有效实施策略与方法

在高效率的办公节奏下,传统的文件共享方式无法匹配许多团队的需求,并且在现实使用过程中往往存在许多问题,如版本混乱、权限管理困难等。那么小团队的内部资料共享协作应该怎么做呢? 小型团队可以借助专业的协作工具实现高效内部…

Lucene学习总结之Lucene的索引文件格式

四、具体格式 上面曾经交代过,Lucene保存了从Index到Segment到Document到Field一直到Term的正向信息,也包括了从Term到Document映射的反向信息,还有其他一些Lucene特有的信息。下面对这三种信息一一介绍。 4.1. 正向信息 Index –> Seg…

【网络安全---ICMP报文分析】Wireshark教程----Wireshark 分析ICMP报文数据试验

一,试验环境搭建 1-1 试验环境示例图 1-2 环境准备 两台kali主机(虚拟机) kali2022 192.168.220.129/24 kali2022 192.168.220.3/27 1-2-1 网关配置: 编辑-------- 虚拟网路编辑器 更改设置进来以后 ,先选择N…

设计模式12、代理模式 Proxy

解释说明:代理模式(Proxy Pattern)为其他对象提供了一种代理,以控制对这个对象的访问。在某些情况下,一个对象不适合或者不能直接引用另一个对象,而代理对象可以在客户端和目标对象之间起到中介的作用。 抽…

【2023年11月第四版教材】第18章《项目绩效域》(第一部分)

第18章《项目绩效域》(第一部分) 1 章节内容2 干系人绩效域2.1 绩效要点2.2 执行效果检查2.3 与其他绩效域的相互作用 3 团队绩效域3.1 绩效要点3.2 与其他绩效域的相互作用3.3 执行效果检查3.4 开发方法和生命周期绩效域 4 绩效要点4.1 与其他绩效域的相…

架构案例2022(四十二)

促销管理系统 某电子商务公司拟升级其会员与促销管理系统,向用户提供个性化服务,提高用户的粘性。在项目立项之初,公司领导层一致认为本次升级的主要目标是提升会员管理方式的灵活性,由于当前用户规模不大,业务也相对…

《数字图像处理-OpenCV/Python》连载(10)图像属性与数据类型

《数字图像处理-OpenCV/Python》连载(10)图像属性与数据类型 本书京东优惠购书链接:https://item.jd.com/14098452.html 本书CSDN独家连载专栏:https://blog.csdn.net/youcans/category_12418787.html 第2章 图像的数据格式 在P…

MyBatisPlus(九)模糊查询

说明 模糊查询&#xff0c;对应SQL语句中的 like 语句&#xff0c;模糊匹配“要查询的内容”。 like /*** 查询用户列表&#xff0c; 查询条件&#xff1a;姓名包含 "J"*/Testvoid like() {String name "J";LambdaQueryWrapper<User> wrapper ne…

mysql面试题17:MySQL引擎InnoDB与MyISAM的区别

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:MySQL引擎InnoDB与MyISAM的区别 InnoDB和MyISAM是MySQL中两种常见的存储引擎,它们在功能和性能方面有一些区别。下面将详细介绍它们之间的差异。…

基于安卓android微信小程序的校园维修平台

项目介绍 随着社会的发展&#xff0c;社会的方方面面都在利用信息化时代的优势。互联网的优势和普及使得各种系统的开发成为必需。 本文以实际运用为开发背景&#xff0c;运用软件工程原理和开发方法&#xff0c;它主要是采用java语言技术和mysql数据库来完成对系统的设计。整…

OpenCV 15(SIFT/SURF算法)

一、SIFT Harris和Shi-Tomasi角点检测算法&#xff0c;这两种算法具有旋转不变性&#xff0c;但不具有尺度不变性&#xff0c;以下图为例&#xff0c;在左侧小图中可以检测到角点&#xff0c;但是图像被放大后&#xff0c;在使用同样的窗口&#xff0c;就检测不到角点了。 尺度…

如何在springboot2中利用mybatis-plus进行分页查询操作。

1.创建配置mp的配置类 在mp的拦截器中加入分页拦截器 package com.example.config;import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor; import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor; import org.springfra…

10月4日作业

server #include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this);//实例化一个服务器server new QTcpServer(this);connect(server, &QTcpServer::newConnection, …

数仓使用SQL脚本在数据库中添加初始数据示例

文章目录 需要在虚拟机上开启数据库 点击确定后&#xff0c;可以点开这个连接&#xff0c;查看数据库信息 运行 init_mysql.sql 创建mall 数据库 -- 设置sql_mode set sql_mode NO_ENGINE_SUBSTITUTION,STRICT_TRANS_TABLES;-- 创建数据库mall create database mall;-- 切换数…

iMazing 2.17.10官方中文版含2023最新激活许可证码

iMazing 2.17.10官方中文版是一款iOS设备管理软件&#xff0c;该软件支持对基于iOS系统的设备进行数据传输与备份&#xff0c;用户可以将包括&#xff1a;照片、音乐、铃声、视频、电子书及通讯录等在内的众多信息在Windows/Mac电脑中传输/备份/管理。 iMazing 2.17.10官方中文…

多层神经网络和激活函数

多层神经网络的结构 多层神经网络就是由单层神经网络进行叠加之后得到的&#xff0c;所以就形成了层的概念&#xff0c;常见的多层神经网络有如下结构&#xff1a; 1&#xff09;输入层&#xff08;Input layer&#xff09;&#xff0c;众多神经元&#xff08;Neuron&#xff…

【多线程进阶】死锁问题

文章目录 前言1. 什么是死锁1.1 死锁的三种典型情况 2. 死锁产生的必要条件3.如何解决死锁问题总结 前言 上文锁策略中, 当谈到可重入锁和不可重入锁时, 我们引入了一个 “死锁” 的概念, 当针对一把不可重入锁进行连续两次的加锁行为时, 就会产生死锁. 本文就重点来讲解一下…

【自动化测试】测试开发工具大合集

收集和整理各种测试工具&#xff0c;自动化测试工具&#xff0c;自动化测试框架&#xff0c;觉得有帮助记得三连一下。 欢迎提交各类测试工具到本博客。 通用测试框架 JUnit: 最著名的xUnit类的单元测试框架&#xff0c;但是不仅仅可以做单元测试。TestNG: 更强大的Java测试框…

深度学习(3)---PyTorch中的张量

文章目录 一、张量简介与创建1.1 简介1.2 张量的创建 二、张量的操作2.1 张量的拼接与切分2.2 张量索引 三、张量的数学运算 一、张量简介与创建 1.1 简介 1. 张量是一个多维数组&#xff0c;它是标量、向量、矩阵的高维拓展。 2. 在张量的定义中&#xff0c;方括号用于表示张…