深度学习之卷积神经网络框架模型搭建

卷积神经网络框架模型搭建

目录

  • 卷积神经网络框架模型搭建
    • 1 卷积神经网络模型
      • 1.1 卷积神经网络
      • 1.2 卷积层(Convolutional Layer)
        • 1.2.1 输出特征图
      • 1.3 激活函数
      • 1.4 池化层(Pooling Layer)
      • 1.5 全连接层(Fully Connected Layer)
    • 2 框架模型搭建
      • 2.1 框架确定
      • 2.2 框架函数定义
    • 3 代码测试

1 卷积神经网络模型


1.1 卷积神经网络

卷积神经网络(Convolutional Neural Network,CNN)是一种专门用于处理具有网格结构数据(如图像、视频)的深度学习模型卷积神经网络通过卷积层、池化层和全连接层的组合,能够高效地提取图像特征,并在计算机视觉任务中表现出色。

1.2 卷积层(Convolutional Layer)

卷积操作是 CNN 的核心,通过卷积核(Filter)提取局部特征,卷积核是一个小的权重矩阵,在输入数据上滑动并计算点积,每个卷积核会生成一个输出通道,多个卷积核可以提取多种特征。。

1.2.1 输出特征图

输出特征图尺寸由输入尺寸(h,w)、卷积核尺寸(k,k)、步长(s)、填充§决定
输出长度:H1 = (h - k + 2p) / s + 1
输出宽度:W1 = (w - k + 2p) / s + 1
其中当k,s,p为512时输出的特征图尺寸与原图相同

1.3 激活函数

在卷积操作后,通常会使用激活函数引入非线性,进行非线性映射。
常用的激活函数包括:
ReLU:f(x) = max(0, x)
Sigmoid:f(x) = 1 / (1 + e^(-x))

1.4 池化层(Pooling Layer)

池化操作用于降采样,减少特征图的尺寸,减小数据的空间大小,因此参数的数量和计算量也会下降,这在一定程度上也控制了过拟合,同时保留重要信息,池化层通常不引入额外的参数。

  • 常用的池化方法包括:
    • 最大池化(Max Pooling):取局部区域的最大值。
    • 平均池化(Average Pooling):取局部区域的平均值。

1.5 全连接层(Fully Connected Layer)

在卷积和池化操作后,特征图会被展平并输入到全连接层,全连接层用于将提取的特征映射到最终的输出(如分类结果),
卷积和池化操作可以多次操作。

2 框架模型搭建


2.1 框架确定

包含 3 个卷积块和 1 个全连接层,用于 MNIST 手写数字分类任务。卷积块包括卷积、激活以及池化操作。
卷积块1:卷积、激活、池化
卷积块2:卷积、激活、卷积、激活、池化
卷积块3:卷积、激活

2.2 框架函数定义

  • 卷积块:nn.Sequential(),括号内可以进行卷积、激活以及池化操作
  • 卷积:nn.Conv2d(in_channels=1,out_channels=16,kernel_size=5,stride=1,padding=2)
    • in_channels=1,输入通道数灰度图像为 1,rgb为3
    • out_channels=16,输出通道数
    • kernel_size=5,卷积核大小(5x5)
    • stride=1,步幅
    • padding=2 ,边缘填充数,为 2,保证输出尺寸与输入尺寸相同
  • 激活函数:nn.ReLU(),ReLU 激活函数
  • 池化:nn.MaxPool2d(kernel_size=2), 最大池化层
    • kernel_size=2,池化核大小为 2x2
  • 全连接:nn.Linear(64 * 7 * 7, 10)(输入个数,输出个数)
  • 展平操作: x.view(x.size(0), -1),将特征图展平,为一维向量
    • x.size(0):批次大小
    • -1:自动计算展平后的维度
class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=16,kernel_size=5,stride=1,padding=2,),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),)self.out = nn.Linear(64*7*7,10)def forward(self,x):x =self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0),-1)output = self.out(x)return output

3 代码测试

代码展示:

import torchprint(torch.__version__)
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensortrain_data = datasets.MNIST(root = 'data',train = True,download = True,transform = ToTensor()
)
test_data = datasets.MNIST(root = 'data',train = False,download = True,transform = ToTensor()
)
print(len(train_data))
print(len(test_data))train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader= DataLoader(test_data, batch_size=64)
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=16,kernel_size=5,stride=1,padding=2,),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),nn.ReLU(),)self.out = nn.Linear(64*7*7,10)def forward(self,x):x =self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0),-1)output = self.out(x)return outputmodel = CNN().to(device)
print(model)
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
loss_fn = nn.CrossEntropyLoss()
def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num = 1for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num %100 ==0:print(f'loss: {loss_value:>7f}  [number: {batch_size_num}]')batch_size_num +=1def test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct +=(pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1)==y)b = (pred.argmax(1)==y).type(torch.float)test_loss /=num_batchescorrect /= sizeprint(f'test result: \n Accuracy: {(100*correct)}%, Avg loss:{test_loss}')
e = 8
for i in range(e):print(f'e: {i+1}\n------------------')train(train_dataloader, model, loss_fn, optimizer)
print('done')test(test_dataloader, model, loss_fn)

运行结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/18812.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【深度强化学习】Actor-Critic 算法

本书之前的章节讲解了基于值函数的方法(DQN)和基于策略的方法(REINFORCE),其中基于值函数的方法只学习一个价值函数,而基于策略的方法只学习一个策略函数。那么,一个很自然的问题是,…

数据结构——二叉树(2025.2.12)

目录 一、树 1.定义 (1)树的构成 (2)度 2.二叉树 (1)定义 (2)二叉树的遍历 (3)遍历特性 二、练习 1.二叉树 (1)创建二叉树…

安科瑞光伏发电防逆流解决方案——守护电网安全,提升能源效率

安科瑞 华楠 18706163979 在当今大力发展清洁能源的时代背景下,光伏发电作为一种可持续的能源解决方案, 正得到越来越广泛的应用。然而,光伏发电过程中出现的逆流问题,给电网的安全稳定 运行带来了诸多挑战。若不能有效解决&…

3、树莓派5 安装VNC查看器 开启VNC服务器

在前序文章中( 2、树莓派5第一次开机),可以使用三种方式开机,其中使用网线及wifi的方式均需要使用到VNC查看器进行远程桌面控制,本文将介绍如何下载安装并配置及使用VNC查看器及服务器,对前序文章做一些补充…

牛客周赛 Round 80

前言 这场比赛是很有意思的,紧跟时事IG杯,大卞"神之举手",0胜拿下比赛,我当时也是完整的看完三场比赛,在第二次说直接两次罚下的时候我真是直接暴起了,然后第三场当时我正在吃饭,看到…

文档格式转换引擎开发:支持PDF与OFD的技术实现

最新技术资源(建议收藏) https://www.grapecity.com.cn/resources/ 前言 近年来,中国在信息技术领域持续追求自主创新和供应链安全,伴随信创上升为国家战略,一些行业也开始明确要求文件导出的格式必须为 OFD 格式。OF…

VSCode Error Lens插件介绍(代码静态检查与提示工具)(vscode插件)

文章目录 VSCode Error Lens 插件介绍**功能概述****开发背景****使用方法****适用场景** VSCode Error Lens 插件介绍 功能概述 Error Lens 是一款增强 VS Code 错误提示的扩展工具,通过 内联显示错误和警告信息,直接定位代码问题,提升开发…

快速幂(算法)的原理

快速幂算法 快速幂数学原理算法实现OJ题展示不用高精度计算二进制指数的高精度计算数学题等差数列和等比数列计数原理 快速幂 求 ( a b ) % n (a^b)\%n (ab)%n的结果(即 a a a的 b b b次方,再除以 n n n得到的余数)。 利用程序求解时&#…

无人机遥感在农林信息提取中的实现方法与GIS融合应用

在新一轮互联网信息技术大发展的现今,无人机、大数据、人工智能、物联网等新兴技术在各行各业都处于大爆发的前夜。为了将人工智能方法引入农业生产领域。首先在种植、养护等生产作业环节,逐步摆脱人力依赖;在施肥灌溉环节构建智慧节能系统&a…

Android设备 网络安全检测

八、网络与安全机制 6.1 网络框架对比 volley: 功能 基于HttpUrlConnection;封装了UIL图片加载框架,支持图片加载;网络请求的排序、优先级处理缓存;多级别取消请求;Activity和生命周期的联动(Activity结束生命周期同时取消所有网络请求 …

【油猴脚本/Tampermonkey】DeepSeek 服务器繁忙无限重试(20250214优化)

目录 一、 引言 二、 逻辑 三、 源代码 四、 添加新脚本 五、 使用 六、 BUG 七、 优化日志 1.获取最后消息内容报错 2.对话框切换无法正常使用 一、 引言 deepseek演都不演了,每次第一次提问就正常,后面就开始繁忙了,有一点阴招全…

C++ Primer 函数重载

欢迎阅读我的 【CPrimer】专栏 专栏简介:本专栏主要面向C初学者,解释C的一些基本概念和基础语言特性,涉及C标准库的用法,面向对象特性,泛型特性高级用法。通过使用标准库中定义的抽象设施,使你更加适应高级…

【c++初阶】类和对象②默认成员函数以及运算符重载初识

目录 ​编辑 默认成员函数: 构造函数 构造函数的特性: 析构函数: 拷贝构造函数: 1. 拷贝构造函数是构造函数的一个重载形式。 2. 拷贝构造函数的参数只有一个且必须是类类型对象的引用,使用传值方式编译器直接报…

基于AIOHTTP、Websocket和Vue3一步步实现web部署平台,无延迟控制台输出,接近原生SSH连接

背景:笔者是一名Javaer,但是最近因为某些原因迷上了Python和它的Asyncio,至于什么原因?请往下看。在着迷”犯浑“的过程中,也接触到了一些高并发高性能的组件,通过简单的学习和了解,aiohttp这个…

【鸿蒙HarmonyOS Next实战开发】lottie动画库

简介 lottie是一个适用于OpenHarmony和HarmonyOS的动画库,它可以解析Adobe After Effects软件通过Bodymovin插件导出的json格式的动画,并在移动设备上进行本地渲染。 下载安裝 ohpm install ohos/lottieOpenHarmony ohpm 环境配置等更多内容&#xff0c…

UE_C++ —— UObject Instance Creation

目录 一,UObject Instance Creation NewObject NewNamedObject ConstructObject Object Flags 二,Unreal Object Handling Automatic Property Initialization Automatic Updating of References Serialization Updating of Property Values …

PHP本地商家卡券管理系统

本地商家卡券管理系统 —— 引领智慧消费新时代 本地商家卡券管理系统,是基于ThinkPHPUni-appuView尖端技术匠心打造的一款微信小程序,它彻底颠覆了传统优惠方式,开创了多商家联合发行优惠卡、折扣券的全新模式,发卡类型灵活多变…

什么是HTTP Error 429以及如何修复

为了有效管理服务器资源并确保所有用户都可以访问,主机提供商一般都会对主机的请求发送速度上做限制,一旦用户在规定时间内向服务器发送的请求超过了允许的限额,就可能会出现429错误。 例如,一个API允许每个用户每小时发送100个请…

无人机不等同轴旋翼架构设计应用探究

“结果显示,对于不等组合,用户应将较小的螺旋桨置于上游以提高能效,但若追求最大推力,则两个相等的螺旋桨更为理想。” 在近期的研究《不等同轴旋翼性能特性探究》中,Max Miles和Stephen D. Prior博士深入探讨了不同螺…

节目选择器安卓软件编写(针对老年人)

文章目录 需求来源软件界面演示效果源码获取 对爬虫、逆向感兴趣的同学可以查看文章,一对一小班教学:https://blog.csdn.net/weixin_35770067/article/details/142514698 需求来源 由于现在的视频软件过于复杂,某客户想开发一个针对老年人、…