深度学习(15)--PyTorch构建卷积神经网络

目录

一.PyTorch构建卷积神经网络(CNN)详细流程

二.graphviz + torchviz使PyTorch网络可视化

2.1.可视化经典网络vgg16

2.2.可视化自己定义的网络


一.PyTorch构建卷积神经网络(CNN)详细流程

卷积神经网络(Convolutional Neural Networks)是一种深度学习模型或类似于人工神经网络的多层感知器,常用来分析视觉图像。

卷积神经网络的详细介绍可以参考博主写的文章:

深度学习(2)--卷积神经网络(CNN)-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/GodFishhh/article/details/135668789?spm=1001.2014.3001.5501

PyTorch构建神经网络的第一步均为引入神经网络包

import torch.nn as nn

卷积神经网络的构建: 

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()# 卷积层->激活函数->池化层self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)  pytorch中是channel_first的,颜色通道写在第一个位置nn.Conv2d(                      # 1d对结构化数据 2d对图像数据 3d对视频数据in_channels=1,              # 灰度图   输入的特征图数out_channels=16,            # 要得到几多少个特征图  输出的特征图数,也就是卷积核的个数(一个卷积核进行卷积可以得到一个特征图,所以卷积核的个数与特征图的数量相同)kernel_size=5,              # 卷积核大小stride=1,                   # 步长padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1  卷积后的图像大小: (h - Kernel_size + 2*p) / s + 1),                              # 输出的特征图为 (16, 28, 28)nn.ReLU(),                      # relu层nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14)  池化后特征数变少)self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)  nn.ReLU(),                      # relu层nn.MaxPool2d(2),                # 输出 (32, 7, 7))self.out = nn.Linear(32 * 7 * 7, 10)   # 全连接层得到的结果  最终数据的大小以及分类的数量def forward(self, x):# 调用卷积层x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7),分类无法对三维的数据进行处理,所以需要将三维图像拉长成一维数据再来进行分类.# -1是自动计算,只需给出一个维度的大小,会自动计算另外个维度.eg.5x4 -> x.view(2,-1),-1对应的就是10. 2x5x10 -> x.view(2,-1),-1对应的就是5x10# 在此处,给出的第一个参数x.size(0)的值为batch,所以-1对应的值就是32x7x7# 调用全连接层(全连接层的输入必须是二维的矩阵,上述的flattern操作将参数x变成了一个二维矩阵)output = self.out(x)return output

详解:

1.创建的神经网络构建类一定要继承nn.Module,后续要调用Module包里面的方法构建神经网络。

2.构造函数的第一步永远是调用父类的构造函数,利用super()进行调用:

super(CNN, self).__init__()

3.卷积神经网络的层次顺序一般为:卷积层-> 激活函数做非线性变换 ->池化层,并在输出之前设置一层全连接层。

4.上述代码构建的卷积神经网络是顺序Sequential的,设置有两个卷积层,两个激活函数,两个池化层,以及输出前的一个全连接层。(一般卷积一次就要池化一次)

nn.Sequential()

5.卷积层的构造:通过Module模块中的Conv2d来构造卷积层,其中参数分别为:输入图片数据的颜色通道数(第一个卷积层)/输入的特征图数(之后的卷积层)、输出的特征图数、卷积核的大小、步长、padding值。(其中Conv1d用来处理结构化数据,Conv2d用来处理图片数据,Conv3d用来处理视频数据)

此处设置的卷积层由输入的1个特征图数得到最后的32个特征图数

nn.Conv2d(1, 32, 5, 1, 2)
nn.Conv2d(16, 32, 5, 1, 2)

值得注意的是,如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1  卷积后的图像大小: (h - Kernel_size + 2*p) / s + 1。

6.此处激活函数设置的是ReLU,可以根据自己的需求设置不同的激活函数。

nn.ReLU()

7.池化层的构造: 只需要设置一个参数,即为进行池化操作的区域大小。

nn.MaxPool2d(kernel_size=2)

8.全连接层的构造:输入的数据最后经过全连接层得到输出数据,参数分别为输入数据的大小,以及最后进行分类的类别数。

self.out = nn.Linear(32 * 7 * 7, 10)

9.前向传播:PyTorch构建的神经网络,前向传播需要手动设置,此处先调用conv1和conv2两层,再将数据拉成二维的传入全连接层,得到最后的输出值。

二.graphviz + torchviz使PyTorch网络可视化

事先需要先安装graphviz库和torchviz库,graphviz具体安装步骤可以参考博主写的文章:

深度学习(9)--pydot库和graphviz库安装流程详解_pydot 怎么安装-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/GodFishhh/article/details/135929146?spm=1001.2014.3001.5501torchviz库可以直接再编译器中进行安装,也可也在cmd中对应环境中使用pip指令安装:

上述两个库安装完之后,导入网络可视化需要用到的头文件:

from torchviz import make_dot
from torchvision.models import vgg16  # 导入vgg16模型用于演示

2.1.可视化经典网络vgg16

# 随机生成一个tensor张量(对应的数据为图片有十张,图片的大小为3x32x32)
x = torch.randn(10, 3, 32, 32)
# 实例化 vgg16
model = vgg16()
# 将 x 输入网络
vgg16_out = model(x)
# 实例化 make_dot
vgg16_result = make_dot(vgg16_out)
# result.view()  直接在当前路径下保存 pdf 并打开
# 保存文件为pdf到指定路径并不打开
vgg16_result.render(filename='vgg16_net_Structure', view=False, format='pdf')

生成如下两个文件 

 

2.2.可视化自己定义的网络

# 随机生成一个tensor张量(对应的数据为图片有四张,图片的大小为1x28x28)
x = torch.randn(4, 1, 28, 28)
# 实例化 vgg16
model = CNN()
# 将 x 输入网络
CNN_out = model(x)
# 实例化 make_dot
CNN_result = make_dot(CNN_out)
# result.view()  直接在当前路径下保存 pdf 并打开
# 保存文件为pdf到指定路径并不打开
CNN_result.render(filename='CNN_net_Structure', view=False, format='pdf')

生成如下两个文件  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/255961.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MySQL基础】:深入探索DQL数据库查询语言的精髓(上)

🎥 屿小夏 : 个人主页 🔥个人专栏 : MySQL从入门到进阶 🌄 莫道桑榆晚,为霞尚满天! 文章目录 📑前言一. DQL1.1 基本语法1.2 基础查询1.3 条件查询1.3 聚合函数 🌤️ 全篇…

【51单片机】串口通信实验(包括波特率如何计算)

目录 串口通信实验通信的基本概念串行通信与并行通信异步通信与同步通信单工、 半双工与全双工通信通信速率 51单片机串口介绍串口介绍串口通信简介串口相关寄存器串口工作方式方式0方式1方式 2 和方式 3 串口的使用方法(计算波特率) 硬件设计软件设计1、…

Akamai 如何揪出微软 RPC 服务中的漏洞

近日,Akamai研究人员在微软Windows RPC服务中发现了两个重要漏洞:严重程度分值为4.3的CVE-2022-38034,以及分值为8.8的CVE-2022-38045。这些漏洞可以利用设计上的瑕疵,通过缓存机制绕过MS-RPC安全回调。我们已经确认,所…

clickhouse计算前后两点间经纬度距离

问题 计算如图所示前后两点经纬度的距离? 方法 1、用开窗函数将如图所示数据下移一行 selectlongitude lon1,latitude lat1,min(longitude) over(order by time1 asc rows between 1 PRECEDING and 1 PRECEDING) lon2,min(latitude) over(order by time1 asc row…

小游戏和GUI编程(5) | SVG图像格式简介

小游戏和GUI编程(5) | SVG图像格式简介 0. 问题 Q1: SVG 是什么的缩写?Q2: SVG 是一种图像格式吗?Q3: SVG 相对于其他图像格式的优点和缺点是什么?Q4: 哪些工具可以查看 SVG 图像?Q5: SVG 图像格式的规范是怎样的?Q6…

CPP项目:Boost搜索引擎

1.项目背景 对于Boost库来说,它是没有搜索功能的,所以我们可以实现一个Boost搜索引擎来实现一个简单的搜索功能,可以更快速的实现Boost库的查找,在这里,我们实现的是站内搜索,而不是全网搜索。 2.对于搜索…

11.0 Zookeeper watcher 事件机制原理剖析

zookeeper 的 watcher 机制,可以分为四个过程: 客户端注册 watcher。服务端处理 watcher。服务端触发 watcher 事件。客户端回调 watcher。 其中客户端注册 watcher 有三种方式,调用客户端 API 可以分别通过 getData、exists、getChildren …

ELAdmin 前端启动

开发工具 官方指导的是使用WebStorm,但是本人后端开发一枚,最终还是继续使用了 idea,主打一个能用就行。 idea正式版激活方式: 访问这个查找可用链接:https://3.jetbra.in/进入任意一个能用的里面,顶部提…

kafka教程

Kafka 中,Producer采用push模型,而Consumer采用pull模型。 Topic Topic(主题)是消息的逻辑分类或通道。它是Kafka中用于组织和存储消息的基本单元。一个Topic可以被看作是一个消息发布的地方,生产者将消息发布到一个…

Axios设置token到请求头的三种方式

1、为什么要携带token? 用户登录时,后端会返回一个token,并且保存到浏览器的localstorage中,可以根据localstorage中的token判断用户是否登录,登录后才有权限访问相关的页面,所以当发送请求时,都要携带to…

ubuntu彻底卸载cuda 重新安装cuda

sudo apt-get --purge remove "*cublas*" "*cufft*" "*curand*" \"*cusolver*" "*cusparse*" "*npp*" "*nvjpeg*" "cuda*" "nsight*" cuda10以上 cd /usr/local/cuda-xx.x/bin/ s…

fast.ai 机器学习笔记(四)

机器学习 1:第 11 课 原文:medium.com/hiromi_suenaga/machine-learning-1-lesson-11-7564c3c18bbb 译者:飞龙 协议:CC BY-NC-SA 4.0 来自机器学习课程的个人笔记。随着我继续复习课程以“真正”理解它,这些笔记将继续…

idea自带的HttpClient使用

1. 全局变量配置 {"local":{"baseUrl": "http://localhost:9001/"},"test": {"baseUrl": "http://localhost:9002/"} }2. 登录并将结果设置到全局变量 PostMapping("/login")public JSONObject login(H…

华为机考入门python3--(9)牛客9-提取不重复的整数

分类:列表 知识点: 从右往左遍历每一个字符 my_str[::-1] 题目来自【牛客】 def reverse_unique(n): # 将输入的整数转换为字符串,这样可以从右向左遍历每一位 str_n str(n) # 创建一个空列表来保存不重复的数字 unique_digits []…

【C++】STL之string 超详解

目录 1.string概述 2.string使用 1.构造初始化 2.成员函数 1.迭代器 2.容量操作 1.size和length 返回字符串长度 2.resize 调整字符串大小 3.capacity 获得字符串容量 4.reserve 调整容量 5.clear 清除 6.empty 判空 3.string插入、追加 、拼接 1.运算…

16 亚稳态原理和解决方案

1. 亚稳态原理 亚稳态是指触发器无法在某个规定的时间段内到达一个可以确认的状态。在同步系统中,输入总是与时钟同步,因此寄存器的setup time和hold time是满足的,一般情况下是不会发生亚稳态情况的。在异步信号采集中,由于异步…

寒假9-蓝桥杯训练

//轨道炮 #include<iostream> using namespace std; #include<algorithm> int logs[100010]; int main() {int n;cin >> n;for (int i 1;i < n;i){cin >> logs[i];}sort(logs 1, logs n 1);int ans 1000000000;for (int i 2;i < n;i){if (…

【数学建模】【2024年】【第40届】【MCM/ICM】【F题 减少非法野生动物贸易】【解题思路】

一、题目 &#xff08;一&#xff09; 赛题原文 2024 ICM Problem F: Reducing Illegal Wildlife Trade Illegal wildlife trade negatively impacts our environment and threatens global biodiversity. It is estimated to involve up to 26.5 billion US dollars per y…

分布式系统架构介绍

1、为什么需要分布式架构&#xff1f; 增大系统容量&#xff1a;单台系统的性能瓶颈&#xff0c;多台机器才能应对大规模的应用场景&#xff0c;所以就需要我们的应用支撑平台具备分布式架构。 加强系统的可用&#xff1a;为了满足业务的SLA要求&#xff0c;需要通过分布式架构…

深度学习(13)--PyTorch搭建神经网络进行气温预测

一.搭建神经网络进行气温预测流程详解 1.1.导入所需的工具包 import numpy as np # 矩阵计算 import pandas as pd # 数据读取 import matplotlib.pyplot as plt # 画图处理 import torch # 构建神经网络 import torch.optim as optim # 设置优化器 1.2.读取并处理数据…