计算机视觉-LeNet

目录

LeNet

LeNet在手写数字识别上的应用

LeNet在眼疾识别数据集iChallenge-PM上的应用


LeNet

LeNet是最早的卷积神经网络之一。1998年,Yann LeCun第一次将LeNet卷积神经网络应用到图像分类上,在手写数字识别任务中取得了巨大成功。LeNet通过连续使用卷积和池化层的组合提取图像特征,其架构如 图1 所示,这里展示的是用于MNIST手写体数字识别任务中的LeNet-5模型:
 


图1:LeNet模型网络结构示意图


 

  • 第一模块:包含5×5的6通道卷积和2×2的池化。卷积提取图像中包含的特征模式(激活函数使用Sigmoid),图像尺寸从28减小到24。经过池化层可以降低输出特征图对空间位置的敏感性,图像尺寸减到12。

  • 第二模块:和第一模块尺寸相同,通道数由6增加为16。卷积操作使图像尺寸减小到8,经过池化后变成4。

  • 第三模块:包含4×4的120通道卷积。卷积之后的图像尺寸减小到1,但是通道数增加为120。将经过第3次卷积提取到的特征图输入到全连接层。第一个全连接层的输出神经元的个数是64,第二个全连接层的输出神经元个数是分类标签的类别数,对于手写数字识别的类别数是10。然后使用Softmax激活函数即可计算出每个类别的预测概率。


【提示】:

卷积层的输出特征图如何当作全连接层的输入使用呢?

卷积层的输出数据格式是[N,C,H,W][N, C, H, W][N,C,H,W],在输入全连接层的时候,会自动将数据拉平,

也就是对每个样本,自动将其转化为长度为KKK的向量,


LeNet在手写数字识别上的应用

LeNet网络的实现代码如下:

# 导入需要的包
import paddle
import numpy as np
from paddle.nn import Conv2D, MaxPool2D, Linear## 组网
import paddle.nn.functional as F# 定义 LeNet 网络结构
class LeNet(paddle.nn.Layer):def __init__(self, num_classes=1):super(LeNet, self).__init__()# 创建卷积和池化层# 创建第1个卷积层self.conv1 = Conv2D(in_channels=1, out_channels=6, kernel_size=5)self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)# 尺寸的逻辑:池化层未改变通道数;当前通道数为6# 创建第2个卷积层self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5)self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)# 创建第3个卷积层self.conv3 = Conv2D(in_channels=16, out_channels=120, kernel_size=4)# 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]# 输入size是[28,28],经过三次卷积和两次池化之后,C*H*W等于120self.fc1 = Linear(in_features=120, out_features=64)# 创建全连接层,第一个全连接层的输出神经元个数为64, 第二个全连接层输出神经元个数为分类标签的类别数self.fc2 = Linear(in_features=64, out_features=num_classes)# 网络的前向计算过程def forward(self, x):x = self.conv1(x)# 每个卷积层使用Sigmoid激活函数,后面跟着一个2x2的池化x = F.sigmoid(x)x = self.max_pool1(x)x = F.sigmoid(x)x = self.conv2(x)x = self.max_pool2(x)x = self.conv3(x)# 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]x = paddle.reshape(x, [x.shape[0], -1])x = self.fc1(x)x = F.sigmoid(x)x = self.fc2(x)return x

飞桨会根据实际图像数据的尺寸和卷积核参数自动推断中间层数据的W和H等,只需要用户表达通道数即可。下面的程序使用随机数作为输入,查看经过LeNet-5的每一层作用之后,输出数据的形状。

# 输入数据形状是 [N, 1, H, W]
# 这里用np.random创建一个随机数组作为输入数据
x = np.random.randn(*[3,1,28,28])
x = x.astype('float32')# 创建LeNet类的实例,指定模型名称和分类的类别数目
model = LeNet(num_classes=10)
# 通过调用LeNet从基类继承的sublayers()函数,
# 查看LeNet中所包含的子层
print(model.sublayers())
x = paddle.to_tensor(x)
for item in model.sublayers():# item是LeNet类中的一个子层# 查看经过子层之后的输出数据形状try:x = item(x)except:x = paddle.reshape(x, [x.shape[0], -1])x = item(x)if len(item.parameters())==2:# 查看卷积和全连接层的数据和参数的形状,# 其中item.parameters()[0]是权重参数w,item.parameters()[1]是偏置参数bprint(item.full_name(), x.shape, item.parameters()[0].shape, item.parameters()[1].shape)else:# 池化层没有参数print(item.full_name(), x.shape)

卷积Conv2D的padding参数默认为0,stride参数默认为1,当输入形状为[Bx1x28x28]时,B是batch_size,经过第一层卷积(kernel_size=5, out_channels=6)和maxpool之后,得到形状为[Bx6x12x12]的特征图;经过第二层卷积(kernel_size=5, out_channels=16)和maxpool之后,得到形状为[Bx16x4x4]的特征图;经过第三层卷积(out_channels=120, kernel_size=4)之后,得到形状为[Bx120x1x1]的特征图,在FC层计算之前,将输入特征从卷积得到的四维特征reshape到格式为[B, 120x1x1]的特征,这也是LeNet中第一层全连接层输入shape为120的原因。

# -*- coding: utf-8 -*-
# LeNet 识别手写数字
import os
import random
import paddle
import numpy as np
import paddle
from paddle.vision.transforms import ToTensor
from paddle.vision.datasets import MNIST# 定义训练过程
def train(model, opt, train_loader, valid_loader):# 开启0号GPU训练use_gpu = Truepaddle.device.set_device('gpu:0') if use_gpu else paddle.device.set_device('cpu')print('start training ... ')model.train()for epoch in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):img = data[0]label = data[1] # 计算模型输出logits = model(img)# 计算损失函数loss_func = paddle.nn.CrossEntropyLoss(reduction='none')loss = loss_func(logits, label)avg_loss = paddle.mean(loss)if batch_id % 2000 == 0:print("epoch: {}, batch_id: {}, loss is: {:.4f}".format(epoch, batch_id, float(avg_loss.numpy())))avg_loss.backward()opt.step()opt.clear_grad()model.eval()accuracies = []losses = []for batch_id, data in enumerate(valid_loader()):img = data[0]label = data[1] # 计算模型输出logits = model(img)pred = F.softmax(logits)# 计算损失函数loss_func = paddle.nn.CrossEntropyLoss(reduction='none')loss = loss_func(logits, label)acc = paddle.metric.accuracy(pred, label)accuracies.append(acc.numpy())losses.append(loss.numpy())print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(np.mean(accuracies), np.mean(losses)))model.train()# 保存模型参数paddle.save(model.state_dict(), 'mnist.pdparams')# 创建模型
model = LeNet(num_classes=10)
# 设置迭代轮数
EPOCH_NUM = 5
# 设置优化器为Momentum,学习率为0.001
opt = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9, parameters=model.parameters())
# 定义数据读取器
train_loader = paddle.io.DataLoader(MNIST(mode='train', transform=ToTensor()), batch_size=10, shuffle=True)
valid_loader = paddle.io.DataLoader(MNIST(mode='test', transform=ToTensor()), batch_size=10)
# 启动训练过程
train(model, opt, train_loader, valid_loader)

通过运行结果可以看出,LeNet在手写数字识别MNIST验证数据集上的准确率高达92%以上。那么对于其它数据集效果如何呢?我们通过眼疾识别数据集iChallenge-PM验证一下。

LeNet在眼疾识别数据集iChallenge-PM上的应用

iChallenge-PM是百度大脑和中山大学中山眼科中心联合举办的iChallenge比赛中,提供的关于病理性近视(Pathologic Myopia,PM)的医疗类数据集,包含1200个受试者的眼底视网膜图片,训练、验证和测试数据集各400张。下面我们详细介绍LeNet在iChallenge-PM上的训练过程。


说明:

如今近视已经成为困扰人们健康的一项全球性负担,在近视人群中,有超过35%的人患有重度近视。近视会拉长眼睛的光轴,也可能引起视网膜或者络网膜的病变。随着近视度数的不断加深,高度近视有可能引发病理性病变,这将会导致以下几种症状:视网膜或者络网膜发生退化、视盘区域萎缩、漆裂样纹损害、Fuchs斑等。因此,及早发现近视患者眼睛的病变并采取治疗,显得非常重要。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/115609.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Dubbo 应用切换 ZooKeeper 注册中心实例,流量无损迁移

首先思考一个问题:如果 Dubbo 应用使用 ZooKeeper 作为注册中心,现在需要切换到新的 ZooKeeper 实例,如何做到流量无损? 本文提供解决这个问题的一种方案。 场景 有两个基于 Dubbo 的微服务应用,一个是服务提供者&…

CXL 内存交织(Memory Interleaving)

🔥点击查看精选 CXL 系列文章🔥 🔥点击进入【芯片设计验证】社区,查看更多精彩内容🔥 📢 声明: 🥭 作者主页:【MangoPapa的CSDN主页】。⚠️ 本文首发于CSDN&#xff0c…

Java 读取TIFF JPEG GIF PNG PDF

Java 读取TIFF JPEG GIF PNG PDF 本文解决方法基于开源 tesseract 下载适合自己系统版本的tesseract ,官网链接:https://digi.bib.uni-mannheim.de/tesseract/ 2. 下载之后安装,安装的时候选择选择语言包,我选择了中文和英文 3.…

恒运资本:股票跌100%后怎么办?

在股票市场里,股票价格跌涨是日常的现象,有时候涨到令人惊喜,有时候却一路跌向谷底。股价跌到零的情况并不常见,可是,假如是跌了100%怎么办呢? 在探究该问题前,咱们需要了解股票跌100%是怎样的…

微服务之Nacos

1 版本说明 官网地址: https://github.com/alibaba/spring-cloud-alibaba/wiki/%E7%89%88%E6%9C%AC%E8%AF%B4%E6%98%8E 1.1 2021.x 分支 适配 SpringBoot 2.4, Spring Cloud 2021.x 版本及以上的Spring Cloud Alibaba 版本如下表(最新版本用*标记&am…

腾讯音乐如何基于大模型 + OLAP 构建智能数据服务平台

本文导读: 当前,大语言模型的应用正在全球范围内引发新一轮的技术革命与商业浪潮。腾讯音乐作为中国领先在线音乐娱乐平台,利用庞大用户群与多元场景的优势,持续探索大模型赛道的多元应用。本文将详细介绍腾讯音乐如何基于 Apach…

LeetCode-455-分发饼干-贪心算法

题目描述: 假设你是一位很棒的家长,想要给你的孩子们一些小饼干。但是,每个孩子最多只能给一块饼干。 对每个孩子 i,都有一个胃口值 g[i],这是能让孩子们满足胃口的饼干的最小尺寸;并且每块饼干 j&#xff…

企业级数据共享规模化模式

数据共享正在成为企业数据战略的重要元素。对于公司而言,Amazon Data Exchange 这样的亚马逊云科技服务提供了与其他公司共享增值数据或从这些数据获利的途径。一些企业希望有一个数据共享平台,他们可以在该平台上建立协作和战略方法,在封闭、…

联邦学习FedAvg-基于去中心化数据的深度网络高效通信学习

随着计算机算力的提升,机器学习作为海量数据的分析处理技术,已经广泛服务于人类社会。 然而,机器学习技术的发展过程中面临两大挑战:一是数据安全难以得到保障,隐私泄露问题亟待解决;二是网络安全隔离和行业…

使用飞桨实现的第一个AI项目——波士顿的房价预测

part1.首先引入相应的函数库: 值得说明的地方: (1)首先,numpy是一个python库,主要用于提供线性代数中的矩阵或者多维数组的运算函数,利用import numpy as np引入numpy,并将np作为它的别名 part…

linux字符串处理

目录 1. C 截取字符串,截取两个子串中间的字符串linux串口AT指令 2. 获取该字符串后面的字符串用 strstr() 函数查找需要提取的特定字符串,然后通过指针运算获取该字符串后面的字符串用 strtok() 函数分割字符串,找到需要提取的特定字符串后,…

如何在小程序中给会员设置备注

给会员设置备注是一项非常有用的功能,它可以帮助商家更好地管理和了解自己的会员。下面是一个简单的教程,告诉商家如何在小程序中给会员设置备注。 1. 找到指定的会员卡。在管理员后台->会员管理处,找到需要设置备注的会员卡。也支持对会…

宠物赛道,用AI定制宠物头像搞钱项目教程

今天给大家介绍一个非常有趣,而粉丝价值又极高,用AI去定制宠物头像或合照的AI项目。 接触过宠物行业应该知道,获取1位铲屎官到私域,这类用户的价值是极高的,一个宠物粉,是连铲个屎都要花钱的,每…

USRP 简介,对于NI软件无线电你所需要了解的一切

什么是 USRP 通用软件无线电外设( USRP ) 是由 Ettus Research 及其母公司National Instruments设计和销售的一系列软件定义无线电。USRP 产品系列由Matt Ettus领导的团队开发,被研究实验室、大学和业余爱好者广泛使用。 大多数 USRP 通过以太网线连接到主机&…

本地部署 Stable Diffusion(Mac 系统)

在 Mac 系统本地部署 Stable Diffusion 与在 Windows 系统下本地部署的方法本质上是差不多的。 一、安装 Homebrew Homebrew 是一个流行的 macOS (或 Linux)软件包管理器,用于自动下载、编译和安装各种命令行工具和应用程序。有关说明请访问官…

【分享】PDF如何拆分成2个或多个文件呢?

当我们需要把一个多页的PDF文件拆分成2个或多个独立的PDF文件,可以怎么操作呢?这种情况需要使用相关工具,下面小编就来分享两个常用的工具。 1. PDF编辑器 PDF编辑器不仅可以用来编辑PDF文件,还具备多种功能,拆分PDF文…

GPT-4.0技术大比拼:New Bing与ChatGPT,哪个更适合你

随着GPT-4.0技术的普及和发展,越来越多的平台开始将其应用于各种场景。New Bing已经成功接入GPT-4.0,并将其融入搜索和问答等功能。同样,在ChatGPT官网上,用户只需开通Plus账号,即可体验到GPT-4.0带来的智能交流和信息…

使用flink sqlserver cdc 同步数据到StarRocks

前沿: flink cdc功能越发强大,支持的数据源也越多,本篇介绍使用flink cdc实现: sqlserver-》(using flink cdc)-〉flink -》(using flink starrocks connector)-〉starrocks整个流程…

小游戏分发平台如何以技术拓流?

2023年,小游戏的发展将受到多方面的影响,例如新技术的引入、参与小游戏的新玩家以及游戏市场的激烈竞争等。首先,新技术如虚拟现实(VR)、增强现实(AR)和机器人技术都可以带来新颖的游戏体验。其…

嘉泰实业和您共创未来财富生活

每一次暖心的沟通都是一次公益,真诚不会因为它的渺小而被忽略;每一声问候都是一次公益,善意不会因为它的普通而被埋没。熟悉嘉泰实业的人都知道,这家企业不但擅长在金融理财领域里面呼风唤雨,同时也非常擅长在公益事业当中践行,属于企业的责任心,为更多有困难的群体带来大爱的传…