基于深度学习的图片识别系统(下)

文章目录

  • 前言
  • 1.任务描述
  • 2.模型搭建
  • 3.代码解释
    • 3.1模型加载
    • 3.2加载数据
    • 3.3模型权重的保存
    • 3.4学习率
    • 3.5过拟合
    • 3.6训练模型
    • 3.7调试检查
  • 4.结果分析
  • 5. 完整代码
  • 结语

前言

书接上回,我们已经完成数据预处理部分的内容,后续仍需要对表格进行裁剪,此处略去该操作,接着我们需要建立模型,训练模型等操作,此处采用基于密集卷积网络的文本识别模型,并结合CTC损失函数,下面是进一步的解释。

1.任务描述

这里简单说一下我们的任务,以及各文件的作用,以及路径信息。

  • src/train_imgs/samples_images/0_3827351228301812241814.jpg:表示训练集中的第一条数据,大小200*32,其中训练集共计3000张图片, 结果如下:
    在这里插入图片描述
  • src/data_3_train.txt存储相应的图片的类别标签,前半部分表示路径信息,后半部分表示标签在这里插入图片描述
  • src/labels.txt:用来存储标签对应的数字类别
    在这里插入图片描述
    注意:这里是有个错位关系的,因为第零行是空的,后续会做相关解释

同理,测试集也是如此道理,文件结构如下图所示:
在这里插入图片描述

2.模型搭建

密集卷积网络(DenseNet)是深度残差网络(ResNet)的特例,两者主要解决的是深度网络梯度消失和梯度爆炸的问题,随着网络层数的增加,网络回传过程会带来梯度弥散问题,经过几轮后的反传的梯度会彻底消失。

两者区别主要在于:

  • ResNet:通过残差连接(加法)实现特征复用,将输入直接传递到输出。
  • DenseNet:通过密集连接(拼接)实现更高效的特征复用,所有层的特征被重复利用.
    具体网络结构如下:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, Activation, Reshape, Permute
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, ZeroPadding2D
from tensorflow.keras.layers import AveragePooling2D, GlobalAveragePooling2D
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.layers import concatenate
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.regularizers import l2
from tensorflow.keras.layers import TimeDistributed
def conv_block(input, growth_rate, dropout_rate=None, weight_decay=1e-4):x = BatchNormalization(axis=-1, epsilon=1.1e-5)(input)x = Activation('relu')(x)x = Conv2D(growth_rate, (3,3), kernel_initializer='he_normal', padding='same')(x)if(dropout_rate):x = Dropout(dropout_rate)(x)return xdef dense_block(x, nb_layers, nb_filter, growth_rate, droput_rate=0.4,
weight_decay=1e-4):for i in range(nb_layers):cb = conv_block(x, growth_rate, droput_rate, weight_decay)x = concatenate([x, cb], axis=-1)nb_filter += growth_ratereturn x, nb_filterdef transition_block(input, nb_filter, dropout_rate=None, pooltype=1, weight_decay=1e-4):x = BatchNormalization(axis=-1, epsilon=1.1e-5)(input)x = Activation('relu')(x)x = Conv2D(nb_filter, (1, 1), kernel_initializer='he_normal', padding='same', use_bias=False,kernel_regularizer=l2(weight_decay))(x)if(dropout_rate):x = Dropout(dropout_rate)(x)if(pooltype == 2):x = AveragePooling2D((2, 2), strides=(2, 2))(x)elif(pooltype == 1):x = ZeroPadding2D(padding = (0, 1))(x)x = AveragePooling2D((2, 2), strides=(2, 1))(x)elif(pooltype == 3):x = AveragePooling2D((2, 2), strides=(2, 1))(x)return x, nb_filterdef dense_cnn(input, nclass):_dropout_rate = 0.4_weight_decay = 1e-4_nb_filter = 64# conv 64 5*5 s=2x = Conv2D(_nb_filter, (5, 5), strides=(2, 2), kernel_initializer='he_normal',padding='same', use_bias=False, kernel_regularizer=l2(_weight_decay))(input)# 64 + 8 * 8 = 128x, _nb_filter = dense_block(x, 8, _nb_filter, 8, None, _weight_decay)# 128x, _nb_filter = transition_block(x, 128, _dropout_rate, 2, _weight_decay)# 128 + 8 * 8 = 192x, _nb_filter = dense_block(x, 8, _nb_filter, 8, None, _weight_decay)# 192 -> 128x, _nb_filter = transition_block(x, 128, _dropout_rate, 2, _weight_decay)# 128 + 8 * 8 = 192x, _nb_filter = dense_block(x, 8, _nb_filter, 8, None, _weight_decay)x = BatchNormalization(axis=-1, epsilon=1.1e-5)(x)x = Activation('relu')(x)x = Permute((2, 1, 3), name='permute')(x)x = TimeDistributed(Flatten(), name='flatten')(x)y_pred = Dense(nclass, name='out', activation='softmax')(x)# model = Model(inputs=input, outputs=y_pred)return y_pred

该网络主要由以下几个部分组成:

  • 卷积块(Conv Block):执行卷积操作,增加特征图的数量。
  • 密集块(Dense Block):由多个卷积块组成,每个卷积块的输出会与之前所有卷积块的输出连接起来。
  • 过渡块(Transition Block):用于减少特征图的数量,控制模型的复杂度。
  • 全连接层(Fully Connected Layer):将特征图转换为最终的分类结果。

该网络结构十分的复杂,这里我们将其保存在densenet.py后续如果使用该结构,直接import densenet即可。

3.代码解释

因为该网络结构十分复杂,同时个人对keras架构的网络并不是特别了解,这里通过断点调试的方式对代码进行解读.根据每个断点的内容对代码进行理解,最后会附上完整的代码。

3.1模型加载

在这里插入图片描述
首先定义了图片的大小,和每次训练处理图片的数量即batch_sizechar_set用于类别标签的对应关系,结果如上所示,下一步通过列表生成式对char_set做了更新,下述对列表推导式进行了解释。
在这里插入图片描述
得到的列表中,进行了切片操作,因为第一个值为空,又将结果加了一个特殊字符得到新的char_set,此处可以解释为什么第一行应该是空的,因为此处的切片下标从1开始。
在这里插入图片描述
这时通过结果可以看到char_set的值。reload_lib.reload(densenet) 这行代码的功能是重新加载 densenet 模块,个人觉得没太大必要。后续又定义了一个get_model函数用来获取模型。只传入了两个参数。
在这里插入图片描述
首先定义了输入层的数据类型,形状为 (img_h, None, 1),None 表示该维度的长度可变,这里是图像的宽度维度。批量大小也可以进行调整,因为也是None,预测输出的结果,是一个40维的向量。后续又定义了不定长的标签输入层以及长度。损失函数采用自定义的函数ctc_lambda_func
在这里插入图片描述
采用的是CTC损失函数,其中labels表示目标标签,y_pred表示预测标签,input_length表示输入序列的长度,label_length表示目标标签的长度

CTC 损失常用于处理序列数据,特别是在输入序列和目标序列长度不一致且没有明确的时间对齐关系时,比如语音识别和手写文字识别任务。

在这里插入图片描述
后续两步用来定义模型(输入包括图像输入、标签输入、输入长度输入和标签长度输入,输出为 CTC 损失)和编译模型(使用自定义的 CTC 损失函数(lambda y_true, y_pred: y_pred),优化器为 adam,评估指标为准确率),此后将模型进行返回。
这时可以查看模型结果和参数信息输出,结果如下:
在这里插入图片描述
在这里插入图片描述
通过结果可以看到,当前模型参数量很大,而且结构非常复杂

3.2加载数据

在训练前,需要将训练集和测试集进行加载,保存至迭代器中,在pytorch中就十分的很方便,直接Dataloader,而在keras中就很复杂。
在这里插入图片描述
这里首先定义了相关路径信息,加载数据集,此处使用了自定义的gen()函数。
在这里插入图片描述
其中readfile()也是自定义的函数体。
在这里插入图片描述
该函数相对简单,创建一个字典,将文件路径作为键,标签作为值。通过结果可以发现该数据量并不是3000而是2906条数据,中间可能那里少了一部分文件。
在这里插入图片描述
结果如图所示,得到上述所示的字典。
在这里插入图片描述
_imagefile用于存储文件的路径信息,即将上述所得的字典的键存储成列表形式。后续初始化了一些融入批量大小的输入。后面又定义了一个random_uniform_num函数,用来打乱顺序。
在这里插入图片描述
调用get()方法取出第一个批量,get()方法体内容很简单:
在这里插入图片描述
从属性index开始向后取一个批量大小,如果最后一部分不足一个批量,则取出最后的部分,并重新打乱,取出缺少的部分,将两者融合成一个批次大小。
在这里插入图片描述
shufimagefile获得对打乱获取的第一个批次图片的文件名,接着遍历一个批次的所有图片,首先通过os.path.join融合路径信息,获得图片的完整相对路径,并通过Image.open读取该图片,再通过convert('L')将图像转化为灰度图,接着通过一个简单的归一化处理,将图片像素值调整到[-0.5,0.5]之间,再将其保存至x[i]中,并将真实标签长度保存至label_length[i]中,真实的标签信息保存至label[i]中,input_length[i]大小是一个固定值,图片宽度大小除以8取整的结果。

在这里插入图片描述
遍历完成后将输入整合成一个字典inputoutputs输出(即每个图片的损失)表示为0,最后通过yield (inputs, outputs) 这行代码将一个批次的输入数据和输出数据作为一个元组返回,使 gen 函数成为一个生成器。
注:因为 gen 是生成器函数,调用 gen(…) 时不会执行函数体里的所有代码,只是返回一个生成器对象,gen(…)函数内部的代码并不会在此时执行,这里只是在这里解释该段代码的功能
在这里插入图片描述

3.3模型权重的保存

在这里插入图片描述
这段代码创建的 ModelCheckpoint 回调函数会在每一轮训练结束后,检查验证集上的损失值。如果当前的验证集损失值是迄今为止最优的,就将模型的权重保存到指定路径的文件中(即./models/weights_densenet-{epoch:02d}-{val_loss:.2f}.h5)。这样可以确保在训练过程中,即使出现过拟合或其他问题导致模型性能下降,也能保留验证集上表现最好的模型权重。

3.4学习率

这里的学习率并不是一个固定的值,而是随训练轮数不断减小的值。
在这里插入图片描述
该处定义了一个lambda函数,用来根据轮数计算学习率,得到如下的学习率列表。
在这里插入图片描述
接着通过LearningRateScheduler函数创建一个学习率调度器对象 changelr,用于在模型训练过程中根据训练轮数(epoch)动态调整学习率。

3.5过拟合

为了防止在训练中出现过拟合,创建一个EarlyStopping 回调函数实例,用于在模型训练过程中提前停止训练.
在这里插入图片描述
这段代码的意思是,在验证集上的损失值如果在连续 10 个训练轮次(epoch)中没有改善(即没有降低),就会停止训练,并告知用户训练已经提前停止。

3.6训练模型

在训练模型前,对训练集和测试集共有多少条数据进行了统计。
在这里插入图片描述
接着通过fit()对模型进行训练,下面对该函数的参数进行解释:
在这里插入图片描述

  1. train_loader:一个生成器对象,由 gen 函数生成。它会在训练过程中不断生成训练数据的批次,每个批次包含输入数据和对应的标签。
  2. steps_per_epoch:train_num_lines // batch_size 表示每个训练轮次(epoch)中需要执行的步数。train_num_lines 是训练集文件中的行数,即训练样本的总数;batch_size 是每个批次包含的样本数。通过整除运算得到每个轮次需要处理的批次数。
  3. epochs:8 表示模型训练的轮数,即模型会对整个训练集进行 8 次迭代训练。
  4. initial_epoch:0 表示训练开始的轮次编号。这里设置为 0,表示从第 0 轮开始训练。
  5. validation_data:test_loader 是一个生成器对象,由 gen 函数生成,用于提供验证数据。在每个训练轮次结束后,模型会使用验证数据进行评估,以监控模型在未见过的数据上的性能。
  6. validation_steps:test_num_lines // batch_size 表示在验证过程中需要执行的步数。test_num_lines 是测试集文件中的行数,即测试样本的总数;batch_size 是每个批次包含的样本数。通过整除运算得到验证过程中需要处理的批次数。
  7. callbacks:[checkpoint, earlystop, changelr] 是一个回调函数列表,包含了在训练过程中需要执行的额外操作。
  • checkpoint 是 ModelCheckpoint 回调函数,用于在验证集损失值最优时保存模型的- 权重。
  • earlystop 是 EarlyStopping 回调函数,用于在验证集损失值连续 10 个轮次没有改善时提前停止训练。
  • changelr 是 LearningRateScheduler 回调函数,用于根据训练轮数动态调整学习率。

3.7调试检查

在训练过程中,如果想观察具体的值,会遇见下述情况:
在这里插入图片描述
因为 TensorFlow 默认使用图执行模式(Graph Execution),可以通过在 model.compile 里设置 run_eagerly=True,这样模型会以动态图(Eager Execution)模式运行.
此时,观察相关变量的结果。
在这里插入图片描述
这时就能看到结果了,因为默认值labels为1.0e+04,即不是该值的就是正常标签。
在这里插入图片描述
可能因为版本不兼容问题,即可能是tensorflow版本太老了,PyCharm 调试器出现了一些关于进程的报错信息,直接运行不影响结果,因此没必要太担心,直接更新tensorflow又会导致当前环境中其他包不兼容,因此就没有解决该问题。

4.结果分析

这里由于网络结构十分复杂,同时该网络的大部分参数是采用正态分布(高斯分布)来随机初始化权重
在这里插入图片描述
所以效果可能很差,比如如下结果:
在这里插入图片描述
如果你想提高该模型的效果,你可以选择一些与训练好的模型来对网络参数进行初始化,或者调整轮数,自行训练,上述代码的训练轮数最高只能8轮,因为下述代码动态调整学习率的缘由
在这里插入图片描述
可以将上述代码中的changelr做一下调整,可得到最高训练轮数89

    changelr = LearningRateScheduler(lambda epoch: float(learning_rate[epoch//10]))

当然效果可能也不是很好,因为这种复杂的模型,训练轮数可能要成千上万,这里就不再展示,因为我的还是cpu版本的,训练速度慢,如果你的算力足够,可以将轮数调整到一个较大的值,因为检测到过拟合会自动停止训练。

5. 完整代码

可参考上述图片进行编写,调试过程对代码均已截图。如果无数据集,建议自行查找,即查找该书籍对应章节的数据集即可。

本案例参考教材:《Python机器学习实战案例(第二版)》赵卫东、董亮

结语

至此,该项目已经完成,虽然最终的效果可能并不太好,但主要理解其训练过程,撒花撒花!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/39986.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

再学:区块链基础与合约初探 EVM与GAS机制

目录 1.区块链是什么 2.remix ​3.账户​ ​4.以太坊三种交易​ 5.EVM 6.以太坊客户端节点 ​7.Gas费用 8.区块链浏览器 1.区块链是什么 只需要检验根节点 Merkel根是否有更改,就不用检查每个交易是否有更改。方便很多。 2.remix 3.账户 如果交易失败的话&…

Java 中装饰者模式与策略模式在埋点系统中的应用

前言 在软件开发中,装饰者模式和策略模式是两种常用的设计模式,它们在特定的业务场景下能够发挥巨大的作用。本文将通过一个实际的埋点系统案例,探讨如何在 Java 中运用装饰者模式和策略模式,以及如何结合工厂方法模式来优化代码…

HCIP_NOTE03_网络组成

网络组成 LAN MAN WAN 园区网 企业或机构内部的网络,分大中小型 行业园:企业园网 校园网 政务园 商业园 三层交换机 数据大量交换的局域网内部,转发效率高,有简单的路由功能 路由器 进出口网络,适用于复杂的网络环境,选路需求 无线网 信号传输稳定性差---- 电磁波易受干…

简记_单片机硬件最小系统设计

以STM32为例: 一、电源 1.1、数字电源 IO电源:VDD、VSS:1.8~3.6V,常用3.3V,去耦电容1 x 10u N x 100n ; 内核电源:内嵌的稳压器输出:1.2V,给内核、存储器、数字外设…

32.[前端开发-JavaScript基础]Day09-元素操作-window滚动-事件处理-事件委托

JavasScript事件处理 1 认识事件处理 认识事件(Event) 常见的事件列表 认识事件流 2 事件冒泡捕获 事件冒泡和事件捕获 事件捕获和冒泡的过程 3 事件对象event 事件对象 event常见的属性和方法 事件处理中的this 4 EventTarget使用 EventTarget类 5 事件委托模式 事件委托&am…

LeetCode hot 100 每日一题(15)——48.旋转图像

这是一道难度为中等的题目,让我们来看看题目描述: 给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。 你必须在 原地 旋转图像,这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图像。 提示…

图灵300题-21~40-笔记002

图灵300题 图灵面试题视频:https://www.bilibili.com/video/BV17z421B7rB?spm_id_from333.788.videopod.episodes&vd_sourcebe7914db0accdc2315623a7ad0709b85&p20。 本文是学习笔记,如果需要面试没有时间阅读原博文,可以快速浏览笔…

09_从经典论文入手Seq2Seq架构

Sequence to Sequence 架构 Paper链接 Sequence to Sequence Learning with Neural Networks B站课程ShusenWang 核心思想 关键的改进点 In this paper, we show that a straightforward application of the Long Short-Term Memory (LSTM) architecture [16] can solve …

大疆上云api介绍

概述 目前对于 DJI 无人机接入第三方云平台,主要是基于 MSDK 开发定制 App,然后自己定义私有上云通信协议连接到云平台中。这样对于核心业务是开发云平台,无人机只是其中一个接入硬件设备的开发者来说,重新基于 MSDK 开发 App 工作量大、成本高,同时还需要花很多精力在无人…

3、孪生网络/连体网络(Siamese Network)

目的: 用Siamese Network (孪生网络) 解决Few-shot learning (小样本学习)。 Siamese Network并不是Meta Learning最好的方法, 但是通过学习Siamese Network,非常有助于理解其他Meta Learning算法。 这里介绍了两种方法:Siame…

OpenCV图像拼接(7)根据权重图对源图像进行归一化处理函数normalizeUsingWeightMap()

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 cv::detail::normalizeUsingWeightMap 是 OpenCV 中用于图像拼接细节处理的一个函数。它根据权重图对源图像进行归一化处理,通常用于…

卷积神经网络 - AlexNet各层详解

AlexNet的层次化设计,使得 AlexNet 能够逐层提取从简单边缘到复杂图形的特征,同时结合归一化、池化和 Dropout 技术,有效提升了训练速度和泛化能力,成为推动深度学习发展的重要里程碑。本文我们来理解AlexNet各层的参数设置以及对…

【设计模式】工厂模式

首先了解一下什么是工厂方法模式? 工厂方法模式(Factory Method Pattern)是一种创建型设计模式,它提供了一种方法来封装对象的创建逻辑。具体来说,它通过定义一个创建对象的接口(即工厂方法)&a…

centos 7 部署FTP 服务用shell 脚本搭建

#!/bin/bash# 检查是否以root身份运行脚本 if [ "$EUID" -ne 0 ]; thenecho "请以root身份运行此脚本。"exit 1 fi# 安装vsftpd yum install -y vsftpd# 启动vsftpd服务并设置开机自启 systemctl start vsftpd systemctl enable vsftpd# 配置防火墙以允许F…

基于Spring Boot的个性化商铺系统的设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…

AI(DeepSeek、ChatGPT)、Python、ArcGIS Pro多技术融合下的空间数据分析、建模与科研绘图及论文写作

人工智能(AI)与ArcGIS Pro的结合,为空间数据处理和分析开辟了前所未有的创新路径。AI通过强大的数据挖掘、深度学习及自动化能力,可高效处理海量、多源、异构的空间数据,极大提升了分析效率与决策支持能力。而ArcGIS P…

2025最新3个wordpress好用的主题

红色大气的wordpress企业主题,适合服务行业的公司搭建企业官方网站使用。是一款专为中小企业和个人开发者设计的WordPress主题,旨在提供专业的网站构建解决方案。 通过此WordPress主题,用户可以轻松创建和维护一个专业的企业网站&#xff0c…

Spring AI Alibaba AudioModel使用

一、AudioModel简介 1、AudioModel 当前,Spring AI Alibaba 支持以下两种通义语音模型的适配,分别是: 文本生成语音 SpeechModel,对应于 OpenAI 的 Text-To-Speech (TTS) API录音文件生成文字 DashScopeAudioTranscriptionMode…

时隔多年,终于给它换了皮肤,并正式起了名字

时隔多年,终于更新了直播推流软件UI,并正式命名为FlashEncoder。软件仍使用MFC框架,重绘了所有用到的控件,可以有效保证软件性能,也便于后续进一步优化。 下载地址:https://download.csdn.net/download/Xi…

Python备赛笔记2

1.区间求和 题目描述 给定a1……an一共N个整数,有M次查询,每次需要查询区间【L,R】的和。 输入描述: 第一行包含两个数:N,M 第二行输入N个整数 接下来的M行,每行有两个整数,L R,中间用空格隔开&…