【构建卷积神经网络】

构建卷积神经网络

  • 卷积网络中的输入和层与传统神经网络有些区别,需重新设计,训练模块基本一致

全连接层:batch784,各个像素点之间都是没有联系的。
卷积层:batch
12828,各个像素点之间是有联系的。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms 
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

首先读取数据

  • 分别构建训练集和测试集(验证集)
  • DataLoader来迭代取数据
# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片# 训练集
train_dataset = datasets.MNIST(root='./data',  train=True,   transform=transforms.ToTensor(),  download=True) # 测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

卷积网络模块构建

  • 一般卷积层,relu层,池化层可以写成一个套餐
  • 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务

图像是二维卷积 conv2
视频是三维卷积 conv3
单向量是一维卷积 conv1
官网有关conv2d的输出宽度和长度的计算公式
在这里插入图片描述

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1,              # 1:灰度图;3:RGBout_channels=16,            # 要得到几多少个特征图,即是卷积核的个数 kernel_size=5,              # 卷积核大小stride=1,                   # 步长padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1),                              # 输出的特征图为 (16, 28, 28)nn.ReLU(),                      # relu层nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),                      # relu层nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),                # 输出 (32, 7, 7))self.conv3 = nn.Sequential(         # 下一个套餐的输入 (32, 7, 7)nn.Conv2d(32, 64, 5, 1, 2),     # 输出 (64, 7, 7)nn.ReLU(),             # 输出 (64, 7, 7))self.out = nn.Linear(64 * 7 * 7, 10)   # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 64 * 7 * 7)output = self.out(x)return output

准确率作为评估标准

def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels) 

训练网络模型

# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = [] for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环net.train()                             output = net(data) loss = criterion(output, target) optimizer.zero_grad() loss.backward() optimizer.step() right = accuracy(output, target) train_rights.append(right) if batch_idx % 100 == 0: net.eval() val_rights = [] for (data, target) in test_loader:output = net(data) right = accuracy(output, target) val_rights.append(right)#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1], 100. * val_r[0].numpy() / val_r[1]))

在这里插入图片描述

练习

  • 再加入一层卷积,效果怎么样?
  • 当前任务中为什么全连接层是3277 其中每一个数字代表什么含义

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/86097.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

将vsCode 打开的多个文件分行(栏)排列,实现全部显示,便于切换文件

目录 1. 前言 2. 设置VsCode 多文件分行(栏)排列显示 1. 前言 主流编程IDE几乎都有排列切换选择所要查看的文件功能,如下为Visual Studio 2022的该功能界面: 图 1 图 2 当在Visual Studio 2022打开很多文件时,可以按照图1、图2所示找到自…

Golang struct 结构体指针类型 / 结构体值类型

struct类型的内存分配机制 结构体变量之间的赋值是值拷贝。 type stu struct {Name stringSlice []stringMap1 map[string]string }func main() {s : stu{}s.Slice make([]string, 6)s.Slice[1] "ssss"s.Slice[2] "xxxx"s.Map1 make(map[string]stri…

基础堆排序

目录 基础堆排序 一、概念及其介绍 二、适用说明 三、过程图示 基础堆排序

【Opencv入门到项目实战】(十):项目实战|文档扫描|OCR识别

所有订阅专栏的同学可以私信博主获取源码文件 文章目录 1.引言1.1 什么是光学字符识别 (OCR)1.2 应用领域 2.项目背景介绍3.边缘检测3.1 原始图像读取3.2 预处理3.3 结果展示 3.轮廓检测4.透视变换5.OCR识别5.1 tesseract安装5.2 字符识别 1.引言 今天我们来看一个OCR相关的文…

桂林小程序https证书

现在很多APP都相继推出了小程序,比如微信小程序、百度小程序等,这些小程序的功能也越来越复杂,不可避免的和网站一样会传输数据,因此小程序想要上线就要保证信息传输的安全性,也就是说各种类型的小程序也需要部署https…

怎么用PS的魔术棒抠图?PS魔术棒抠图的操作方法

使用PS的魔术棒抠图教程: 1、首先,在ps界面上方点击“文件”选项,再在其弹出的选项栏中选择“打开”选项。然后,打开你所需要的图片。 2、然后,单击左侧的“魔术棒”工具。 3、然后,用鼠标点击图片的背景&…

【C++11】类的新功能 | 可变参数模板

文章目录 一.类的新功能1.默认成员函数2.类成员变量初始化3.强制生成默认函数的关键字default4.禁止生成默认函数的关键字delete5.继承和多态中final与override关键字 二.可变参数模板1.可变参数模板的概念2.可变参数模板的定义方式3.参数包的展开方式①递归展开参数包②逗号表…

Vue过滤器(时间戳转时间)

目录 过滤器 HTML写法: 定义过滤器: 定义全局过滤器: 过滤器串联: 带参数过滤器: 时间戳转时间 过滤器 官方地址:过滤器 — Vue.js (vuejs.org) 过滤器是指Vue.js支持在{{}}插值的尾部添加一个管道符“&#xff0…

vue3 + vite + ts 封装 SvgIcon组件

环境 vite vue3 ts "vue": "^3.3.4", "vite": "^4.4.0", "typescript": "^5.0.2",# 需要下载的依赖 "vite-plugin-svg-icons": "^2.0.1",不同版本可能存在一定差异, 这篇文章不可能对应所…

通达OA SQL注入漏洞【CVE-2023-4166】

通达OA SQL注入漏洞【CVE-2023-4166】 一、产品简介二、漏洞概述三、影响范围四、复现环境POC小龙POC检测工具: 五、修复建议 免责声明:请勿利用文章内的相关技术从事非法测试,由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损…

宋浩线性代数笔记(五)矩阵的对角化

本章的知识点难度和重要程度都是线代中当之无愧的T0级,对于各种杂碎的知识点,多做题复盘才能良好的掌握,良好掌握的关键点在于:所谓的性质A与性质B,是谁推导得谁~

ApiPost的使用

1. 设计接口 请求参数的介绍 Query:相当于get请求,写的参数在地址栏中可以看到 Body: 相当于 post请求,请求参数不在地址栏中显示。 请求表单类型,用form-data json文件类型,用row 2. 预期响应期望 设置完每一项点一下生成响应…

设计实现数据库表扩展的7种方式

设计实现数据库表扩展的7种方式 在软件开发过程中,数据库是一项关键技术,用于存储、管理和检索数据。数据库表设计是构建健壮数据库系统的核心环节之一。然而,随着业务需求的不断演变和扩展,数据库表中的字段扩展变得至关重要。 …

neo4j的CQL命令实例演示

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…

92. 反转链表 II

92. 反转链表 II 题目-中等难度示例1. 获取头 反转中间 获取尾 -> 拼接2. 链表转换列表 -> 计算 -> 转换回链表 题目-中等难度 给你单链表的头指针 head 和两个整数 left 和 right &#xff0c;其中 left < right 。请你反转从位置 left 到位置 right 的链表节点…

elk开启组件监控

elk开启组件监控 效果&#xff1a; logstash配置 /etc/logstash/logstash.yml rootnode1:~# grep -Ev "^#|^$" /etc/logstash/logstash.yml path.data: /var/lib/logstash path.logs: /var/log/logstash xpack.monitoring.enabled: true xpack.monitoring.elasti…

深兰科技熊猫汽车牵手首恒出行,人工智能技术提升商用车运营服务

8月8日&#xff0c;深兰科技集团旗下熊猫新能源汽车(上海)有限公司(下称熊猫新能源汽车)与河南首恒出行服务有限公司(下称首恒出行)在深兰科技总部举行签约仪式&#xff0c;首恒出行将向熊猫新能源汽车年定向采购10000台商用车&#xff0c;双方将在汽车后市场领域进行技术合作。…

【ElasticSearch入门】

目录 1.ElasticSearch的简介 2.用数据库实现搜素的功能 3.ES的核心概念 3.1 NRT(Near Realtime)近实时 3.2 cluster集群&#xff0c;ES是一个分布式的系统 3.3 Node节点&#xff0c;就是集群中的一台服务器 3.4 index 索引&#xff08;索引库&#xff09; 3.5 type类型 3.6 doc…

Linux学习————redis服务

目录 一、redis主从服务 一、redis主从服务概念 二、redis主从服务作用 三、缺点 四、主从复制流程 五、搭建主从服务 配置基础环境 下载epel源&#xff0c;下载redis​编辑 二、哨兵模式 一、概念 二、作用 三、缺点 四、结构 五、搭建 修改哨兵配置文件 启动服务…

IntelliJ中文乱码问题

1、控制台乱码 运行时控制台输出的中文为乱码&#xff0c;解决方法&#xff1a;帮助 > 编辑自定义虚拟机选项… > 此时会自动创建出一个新文件&#xff0c;输入&#xff1a;-Dfile.encodingUTF-8&#xff0c;然后重启IDE即可&#xff0c;操作截图如下&#xff1a; 2、…