深度学习---卷积神经网络

卷积神经网络概述

卷积神经网络是深度学习在计算机视觉领域的突破性成果。在计算机视觉领域。往往输入的图像都很大,使用全连接网络的话,计算的代价较高。另外图像也很难保留原有的特征,导致图像处理的准确率不高。

卷积神经网络(Convolutional Neural Network)是含有卷积层的神经网络。卷积层的作用就是用来自动学习、提取图像的特征。

CNN网络主要有三部分构成:卷积层、池化层和全连接层构成,其中卷积层负责提取图像中的局部特征;池化层用来大幅降低参数量级(降维);全连接层类似人工神经网络的部分,用来输出想要的结果。

图像概述

图像是由像素点组成的,每个像素点的值范围为:[0,255],像素值越大意味着较亮。比如一张 200x200 的图像,则是由 40000 个像素点组成,如果每个像素点都是 0 的话,意味着这是一张全黑的图像。

彩色图一般都是多通道的图像,所谓多通道可以理解为图像由多个不同的图像层叠加而成,例如看到的彩色图像一般都是由 RGB 三个通道组成的,还有一些图像具有 RGBA 四个通道,最后一个通道为透明通道,该值越小,则图像越透明。

卷积层

卷积计算

Padding

Stride

多通道卷积计算

多卷积核卷积计算

特征图大小

池化层

池化层 (Pooling) 降低维度,缩减模型大小,提高计算速度。即:主要对卷积层学习到的特征图进行下采样(SubSampling)处理。池化层主要有两种:最大池化、平均池化。

池化层计算

Stride

Padding

多通道池化计算

案例-图像分类

CIFAR10 数据集

from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader# 1. 数据集基本信息
def test01():# 加载数据集train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))valid = CIFAR10(root='data', train=False, transform=Compose([ToTensor()]))# 数据集数量print('训练集数量:', len(train.targets))print('测试集数量:', len(valid.targets))# 数据集形状print("数据集形状:", train[0][0].shape)# 数据集类别print("数据集类别:", train.class_to_idx)# 2. 数据加载器
def test02():train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))dataloader = DataLoader(train, batch_size=8, shuffle=True)for x, y in dataloader:print(x.shape)print(y)breakif __name__ == '__main__':test01()test02()

搭建图像分类网络

class ImageClassification(nn.Module):def __init__(self):super(ImageClassification, self).__init__()self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.linear1 = nn.Linear(576, 120)self.linear2 = nn.Linear(120, 84)self.out = nn.Linear(84, 10)def forward(self, x):x = F.relu(self.conv1(x))x = self.pool1(x)x = F.relu(self.conv2(x))x = self.pool2(x)# 由于最后一个批次可能不够 32,所以需要根据批次数量来 flattenx = x.reshape(x.size(0), -1)x = F.relu(self.linear1(x))x = F.relu(self.linear2(x))return self.out(x)

编写训练函数

使用多分类交叉熵损失函数,Adam 优化器:

def train():# 加载 CIFAR10 训练集, 并将其转换为张量transgform = Compose([ToTensor()])cifar10 = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transgform)# 构建图像分类模型model = ImageClassification()# 构建损失函数criterion = nn.CrossEntropyLoss()# 构建优化方法optimizer = optim.Adam(model.parameters(), lr=1e-3)# 训练轮数epoch = 100for epoch_idx in range(epoch):# 构建数据加载器dataloader = DataLoader(cifar10, batch_size=BATCH_SIZE, shuffle=True)# 样本数量sam_num = 0# 损失总和total_loss = 0.0# 开始时间start = time.time()correct = 0for x, y in dataloader:# 送入模型output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()correct += (torch.argmax(output, dim=-1) == y).sum()total_loss += (loss.item() * len(y))sam_num += len(y)print('epoch:%2s loss:%.5f acc:%.2f time:%.2fs' %(epoch_idx + 1,total_loss / sam_num,correct / sam_num,time.time() - start))# 序列化模型torch.save(model.state_dict(), 'model/image_classification.bin')

编写预测函数

def test():# 加载 CIFAR10 测试集, 并将其转换为张量transgform = Compose([ToTensor()])cifar10 = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transgform)# 构建数据加载器dataloader = DataLoader(cifar10, batch_size=BATCH_SIZE, shuffle=True)# 加载模型model = ImageClassification()model.load_state_dict(torch.load('model/image_classification.bin'))model.eval()total_correct = 0total_samples = 0for x, y in dataloader:output = model(x)total_correct += (torch.argmax(output, dim=-1) == y).sum()total_samples += len(y)print('Acc: %.2f' % (total_correct / total_samples))

总结

可以从以下几个方面来调整网络:

  1. 增加卷积核输出通道数
  2. 增加全连接层的参数量
  3. 调整学习率
  4. 调整优化方法
  5. 修改激活函数
  6. 等等...

把学习率由 1e-3 修改为 1e-4,并网络参数量增加如下代码所示:

class ImageClassification(nn.Module):def __init__(self):super(ImageClassification, self).__init__()self.conv1 = nn.Conv2d(3, 32, stride=1, kernel_size=3)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(32, 128, stride=1, kernel_size=3)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.linear1 = nn.Linear(128 * 6 * 6, 2048)self.linear2 = nn.Linear(2048, 2048)self.out = nn.Linear(2048, 10)def forward(self, x):x = F.relu(self.conv1(x))x = self.pool1(x)x = F.relu(self.conv2(x))x = self.pool2(x)# 由于最后一个批次可能不够 32,所以需要根据批次数量来 flattenx = x.reshape(x.size(0), -1)x = F.relu(self.linear1(x))x = F.dropout(x, p=0.5)x = F.relu(self.linear2(x))x = F.dropout(x, p=0.5)return self.out(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/168222.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

嵌入式linux系统设备树实例分析

前言 我们可以从LED程序中榨取很多知识:基本的驱动框架、驱动的简单分层、驱动的分层分离思想、总线设备驱动模型、设备树等。这大多都是结合韦老师的教程学的。 这篇笔记结合第6个demo(基于设备树)来学习、分析: 框图 下面是L…

JMeter添加插件

一、前言 ​ 在我们的工作中,我们可以利用一些插件来帮助我们更好的进行性能测试。今天我们来介绍下Jmeter怎么添加插件? 二、插件管理器 ​ 首先我们需要下载插件管理器jar包 下载地址:Install :: JMeter-Plugins.org 然后我们将下载下来…

《红蓝攻防对抗实战》一. 隧道穿透技术详解

一.隧道穿透技术详解 从技术层面来讲,隧道是一种通过互联网的基础设施在网络之间传递数据的方式,其中包括数据封装、传输和解包在内的全过程,使用隧道传递的数据(或负载)可以使用不同协议的数据帧或包。 假设我们获取到一台内网主机的权限,…

如何制作.exe免安装绿色单文件程序,将源代码打包成可独立运行的exe文件

环境: rustdesk编译文件和文件夹 文件程序制作工具 问题描述: 如何制作.exe免安装绿色单文件程序,将源代码打包成可独立运行的exe文件,像官网那种呢? 将下面编译好的rustdesk文件夹制作成一个.exe免安装绿色单文件程序,点击exe就可以运行 在github上找了半天也没有…

AIGC笔记--基于DDPM实现图片生成

目录 1--扩散模型 2--训练过程 3--损失函数 4--生成过程 5--参考 1--扩散模型 完整代码:ljf69/DDPM 扩散模型包含两个过程,前向扩散过程和反向生成过程。 前向扩散过程对一张图像逐渐添加高斯噪声,直至图像变为随机噪声。 反向生成过程…

推荐微软的开源课程《AI-For-Beginners》

今天给大家推荐一个对新手非常友好的AI入门课程《AI-For-Beginners》。 该课程由微软推出,为期12周,共24课时,对比Google的AI入门课更通俗易懂一些,强烈推荐刚入门的AI小白们学习!而且是免费!课程资源看文…

SQL UPDATE 语句(更新表中的记录)

SQL UPDATE 语句 UPDATE 语句用于更新表中已存在的记录。 还可以使用AND或OR运算符组合多个条件。 SQL UPDATE 语法 具有WHERE子句的UPDATE查询的基本语法如下所示: UPDATE table_name SET column1 value1, column2 value2, ... WHERE conditi…

【第三天】C++类和对象进阶指南:从堆区空间操作到友元的深度掌握

一、new和delete 堆区空间操作 1、new和delete操作基本类型的空间 new与C语言中malloc、delete和C语言中free 作用基本相同 区别: new 不用强制类型转换 new在申请空间的时候可以 初始化空间内容 2、 new申请基本类型的数组 3、new和delete操作类的空间 4、new申请…

数据可视化在行业解决方案中的实践应用 ——华为云Astro Canvas大屏开发研究及指南

本文主要探讨华为云Astro Canvas在数据可视化大屏开发中的应用及效果。首先阐述Astro Canvas的基本概念、功能和特性说明,接着集中分析展示其在教育、金融、交通行业等不同领域实际应用案例;之后,详细介绍使用该工具进行大屏图表创建的开发指…

从零开始 Spring Cloud 15:多级缓存

从零开始 Spring Cloud 15:多级缓存 多级缓存架构 传统的缓存使用 Redis,大致架构如下: 这个架构存在一些问题: 请求要经过Tomcat处理,Tomcat的性能成为整个系统的瓶颈 Redis缓存失效时,会对数据库产生冲…

MySQL实践——分页查询优化

问题现象 一个客户业务系统带有分页查询功能,但是随着查询页数的增加,越往后查询性能越差,有时一个查询可能需要1分钟左右的时间。分页查询的写法类似于: select * from employees limit 250000,5000;这是最传统的一种分页查询写…

Amazon图片下载器:利用Scrapy库完成图像下载任务

概述 本文介绍了如何使用Python的Scrapy库编写一个简单的爬虫程序,实现从Amazon网站下载商品图片的功能。Scrapy是一个强大的爬虫框架,提供了许多方便的特性,如选择器、管道、中间件、代理等。本文将重点介绍如何使用Scrapy的图片管道和代理…

Python爬虫:ad广告引擎的模拟登录

⭐️⭐️⭐️⭐️⭐️欢迎来到我的博客⭐️⭐️⭐️⭐️⭐️ 🐴作者:秋无之地 🐴简介:CSDN爬虫、后端、大数据领域创作者。目前从事python爬虫、后端和大数据等相关工作,主要擅长领域有:爬虫、后端、大数据…

神器抓包工具 HTTP Analyzer v7.5 的下载,安装,使用,破解说明以及可能遇到的问题

文章目录 1、HTTP Analyzer 工具能干什么?2、HTTP Analyzer 如何下载?3、如何安装?4、如何使用?5、如何破解?6、Http AnalyzerStd V7可能遇到的问题 1、HTTP Analyzer 工具能干什么? A1:HTTP A…

Linux:命令行参数和环境变量

文章目录 命令行参数环境变量环境变量的概念常见的环境变量PATH 环境变量表本地变量和环境变量命令分类 本篇主要解决以下问题: 什么是命令行参数命令行参数有什么用环境变量是什么环境变量存在的意义 命令行参数 在学习C语言中,对于main函数当初的写…

(二)docker:建立oracle数据库mount startup

这章其实我想试一下startup部分做mount,因为前一章在建完数据库容器后,需要手动创建用户,授权,建表等,好像正好这部分可以放到startup里,在创建容器时直接做好;因为setup部分我实在没想出来能做…

订水商城H5实战教程-02系统登录

目录 1 创建数据源2 创建自定义应用3 创建全局变量4 实现登录功能5 控制弹窗是否显示6 最终的效果 上一篇我们分析了订水商城的功能,功能分析好了之后,就需要开发功能。用户登录商城的第一步就是进行登录,登录的时候需要同意用户协议&#xf…

SpringBoot AOP + Redis 延时双删功能实战

一、业务场景 在多线程并发情况下,假设有两个数据库修改请求,为保证数据库与redis的数据一致性,修改请求的实现中需要修改数据库后,级联修改Redis中的数据。 请求一:A修改数据库数据 B修改Redis数据 请求二&#xff…

修炼k8s+flink+hdfs+dlink(六:学习k8s-pod)

一:增(创建)。 直接进行创建。 kubectl run nginx --imagenginx使用yaml清单方式进行创建。 直接创建方式,并建立pod。 kubectl create deployment my-nginx-deployment --imagenginx:latest 先创建employment,不…

CSS页面基本布局

前提回顾 1. 超文本标记语言(HTML)是一种标记语言,用来结构化我们的网页内容并赋予内容含义; (超文本标记语言(英语:HyperText Markup Language /ˈhaɪpətekst ˈmɑːkʌp ˈlŋɡwɪdʒ /…