基于TensorFlow框架的手写数字识别系统(代码+论文+开题报告等)

 手写数字识别
需安装Python3.X 64bit相关版本、Tensorflow 1.x相关版本
IDE建议使用Pycharm
打开main.py,运行即可

1.4 研究方法

实验研究表明,若手写体数字没有限制,几乎可以肯定没有一劳永逸的方法能同时达到90%以上的识别率和较快的识别速度。因此,这方面的研究向着更复杂更综合的方向发展。例如人工智能中的专家系统、人工神经网络已经开始应用于手写体数字识别的研究当中。在手写数字识别的发展中,神经网络和多种专家系统的结合是值得探究的方向。模式特征的不同,其决策方式也会不同。可将模式识别的方法大致分为5大类[8]。这五类方法各有各自的特点,各有各自的适用条件,最后都能实现手写数字的识别。这五类方法分别为:

  1. 句法结构方法
  2. 统计模式法
  3. 逻辑特征法
  4. 模糊模式方法
  5. 神经网络方法

下面简单介绍一下这五类方法的适用条件。句式结构法比较简单直观,可以直接反映事物的本质特征,但难点在于不易提取神经元且稳定性较差。统计法用于统计事物的各个特征,优点是比较方便简洁,且鲁棒性较好。但是统计法没有充分利用模式的结构,难以从各个模块之间进行比较。神经网络方法常用人工神经网络方法实现模式识别。一些环境信息可以处理的问题非常复杂,背景情况也不太明了,推理规则没有明确的定义,使得样本存在较大的缺陷和失真。神经网络方法的缺点在于其模型不断丰富和完善。目前,还没有足够的模式可供查明。

神经网络方法允许样本具有大的缺陷和扭曲。它具有运行速度快,自适应性能好,性能高等特点。它还可以快速同时处理大容量的数据,并行的处理数据,也因此具有超高速度的特点。并且,网络的最终输出是由所有神经元共同作用的结果,一个神经元的错误对整体的影响微乎其微,可以忽略不计。所以其容错性也非常的好[9]。基于以上的考虑,本文的手写数字识别采用了卷积神经网络的方法。

1.5 论文组织结构

本文共6个章节,其结构安排如下:

第1章为绪论,介绍了本课题的研究背景及其研究意义、当前的研究状况、研究内容以及研究方法。此外,还简单描述了五种模式识别常用的方法,并介绍这五种方法各自的使用条件及优缺点。

第2章为相关技术介绍,首先介绍了Google开发的机器学习框架Tensorflow,并简单论述了Tensorflow的工作原理。紧接着介绍了本系统所选择的编程语言Python的优缺点以及选择这么语言的原因。然后介绍了Python的界面开发工具Tkinter。最后介绍了MNIST手写数字数据集以及该数据及的文件格式。

第3章为开发环境配置。本章介绍了本机的硬件开发环境、本系统所选用的集成开发环境Pycharm、Python3.x的安装于环境配置、Tensorflow-GPU的安装及环境配置以及Tensorflow的集成配置平台Anaconda。

第4章为系统的设计与实现。本章第一节介绍了Softmax Regression算法、模型的训练以及模型的评估。本章第二节介绍了卷积神经网络模型参数的设计和实现、模型的结构和训练过程。之后介绍了本课题设计的图形用户界面以及前台与后端进行数据交换所用的Flask框架、模板引用等技术。

第5章为系统测试。本章介绍了几个测试案例来测试系统的健壮性鲁棒性。其中既由成功的测试,也有失败案例。

第6章为展望与总结。介绍了手写数字识别的当下与未来,并对未来一段时间的机器学习发展进行了展望。

第二章 相关技术介绍

本章介绍了本课题所使用的相关技术,并介绍了相关技术的工作原理、优缺点等。相关技术包括TensorFlow框架、Python语言、Tkinter相关控件及特性以及MNIST数据集等。

2.1 TensorFlow框架

2.1.1 TensorFlow框架介绍

TensorFlow是一个用于机器学习的端到端开源平台。它拥有全面,灵活的工具,库和社区资源生态系统,可让研究人员推动ML的最新技术,开发人员可轻松构建和部署ML(Machine Learning)驱动的应用程序。它还是一个开源软件库,用于语义理解和感知方向的机器学习。TensorFlow框架是由谷歌人工智能团队开发,用于Google相关产品及功能的开发与研制。如语音识别、谷歌邮件、谷歌地图和谷歌搜索引擎。

2.1.2 TensorFlow工作原理

TensorFlow是一个采用数据流图用于数值计算的开源软件库。节点一般在图中表示数学操作,图中的线则表示在节点间的输入/输出关系,也就是张量。张量从图中流过的直观图像是这个工具取名为“Tensorflow”的原因[10]。

2.2 Python语言

2.2.1 Python介绍

Python是一种广泛使用的通用高级编程语言。它最初由Guido van Rossum于1991年设计,由Python Software Foundation开发。它主要是为了强调代码可读性而开发的,其语法允许程序员用更少的代码行表述概念。 Python有两个主要的Python版本:Python 2.x和Python 3.x。两者差别较大,本文使用的是Python3.x。

起初,自动化脚本常用Python来编写。之后随着Python版本的不断升级以及功能的不断完善,它越来越多被用于大型的、独立的项目开发。Python除了极少数事情不能完成之外,其他基本上可以说全能。多媒体应用、机器学习、人工智能、系统运维、黑客编程、图形处理、爬虫编写、数据库编程、pymo引擎、文本处理等等都可以用Python来实现。Python常见应用如图2.1所示:

2.2.2 Python优缺点介绍

Python的优点很多,简单的可以总结为以下几点。

  1. 简单和明确,做一件事只有一种方法。
  2. 学习曲线低,跟其他很多语言相比,Python更容易上手。
  3. 代码有着严格的编写规范,使用Tab键来控制结构
  4. 解释型语言,天生具有平台可移植性。
  5. 学习曲线低,非专业人士也能上手,支持面向对象和函数式编程

Python的缺点主要集中在以下几点。

  1. 执行效率略低于C语言,速度略慢
  2. 代码无法加密,但是现在的公司很多都不是卖软件而是卖服务
  3. 在开发时可以选择的框架太多(如Web框架就有100多个),有选择的地方就有错误。

2.3 Tkinter

2.3.1 Tkinter介绍

Tkinter模块是GUI工具包的标准Python接口。Tk和Tkinter都可以在大多数Unix平台上以及Windows和MAC操作系统上使用。从8.0版开始,Tk在所有平台上都提供原生外观。Tkinter包含着许多模块,Tk接口由名为_tkinter的二进制扩展模块提供。它通常是一个共享库(或DLL),但在某些情况下可能与Python解释器静态链接。公共接口通过许多Python模块提供。最重要的接口模块是Tkinter模块本身。要使用Tkinter,所需要做的就是导入Tkinter模块:

import Tkinter

或者更常用的是:

from Tkinter import *

2.3.2 Tkinter模块的GUI

在Python中,常用tkinter来开发图形用户界面。Tk是一个工具包,它提供了跨平台的GUI控件,开发图形用户界面十分方便快捷。基本上使用tkinter来开发GUI应用需要以下5个步骤

  1. 将需要的tkinter模块导入进开发环境中
  2. 创建顶层窗口,并在这个顶层窗口开发GUI
  3. 添加GUI组件,并将组件放在合适的位置
  4. 编写响应函数,将函数与需要响应的按钮绑定
  5. 进入main loop。

2.3.3 Tkinter组件

Python在GUI方面并不强,相比wxpython而言,tkinter内置于python库中,无需另外安装,同时基本的控件也能满足基本开发需求,下面介绍tkinter的基本用法。

Tkinter的提供各种各样的控件,例如菜单、按钮、消息、输入控件,便于开发同行用户界面。常用的控件及其描述如表2.1所示:

表2.1 Tkinter常见组件及介绍

控件

描述

Button

按钮控件:在程序中显示按钮。

Canvas

画布控件:画布本身没有绘图能力,它是图形的容器

Checkbutton

多选框控件:提供多项选择,选定后再次点击即可取消。

Entry  

输入控件:用于输入数据。

Label  

标签控件:用于在框架中显示标签。

Menu  

菜单控件:显示菜单栏,可添加菜单选项

Message  

消息控件:可以显示一行或多行文本,能自动换行和调整尺寸。

Radiobutton  

单选按钮:为用户提供两个或多个互斥选项,只能选其一。

标准属性指的是tkinter控件的共同属性。例如控件的颜色、字体、大小、格式等等。Tkinter标准属性介绍如表2.2所示:

表2.2Tkinter标准属性及其介绍

属性

描述

Dimension

控件大小

Color

控件颜色

Font

控件字体

Anchor

锚点

Reliel

控件样式

Bitmap

位图

Cursor

光标

Tkinter包括三种几个管理类:pack、grid、place。这三种方法都可以管理整块空间区域。最常用的是pack和grid类。

几何方法

描述

Pack()

在二位网格中组织窗口部件,类似于新闻的排版

Grid()

几何管理器,将窗口部件包装到父部件中

Place()

可自由指定每个部件的像素位置,也因此容易出现布局混乱

2.4  MNIST数据集

2.4.1 MNIST数据集介绍

MNIST(Mixed National Institute of Standards and Technology database)是一个非常庞大的手写数字数据库. MNIST 数据集的官网是YannLeCun website[11]。该网站提供了一份 Python 源代码用于自动下载和安装这个数据集。可以直接复制粘贴到代码文件里面,用于导入MNIST手写数据集。

下载下来的数据集可被分为三部分:55000个训练数据集(mnist.train),10000 个测试数据集 (mnist.test),以及 5000 个验证数据集(mnist.validation)。MNIST数据集的划分很重要,因为在机器学习模型设计时,必须提供一个单独的测试数据集,不用于训练而是用来评估这个模型的性能,从而更加容易把设计的模型推广到其他数据集上。这个数据集不需要很大,能体现出训练的成果即可。

MNIST数据集的每一个单元都由两部分构成:第一部分是手写数字的图片,命名为“images”。第二部分是该单元的手写数字图片对应的标签,命名为“labels”。训练集和测试集都由这两部分组成。例如训练集的图片时mnist.train.images,而训练数据集的标签则为mnist.train.labels。图片的长和宽均为28像素点,每张图片总共28*28=784个像素,可以用长度为784的数字数组表示这张图片。标签总共有10种可能(0~9),因此可以使用长度为10的数组表示标签。例如:MNIST数据集中某图片矩阵图如图2.2所示:

                    图2.1 数字“1”矩阵图

将这个数组展开成向量(Vector),在展开时不需要考虑展开的行列顺序,只要保持各个图片采用相同的方式展开即可。从上图可以观察得出:MNIST数据集的图片就是展开在784 维向量空间里面的点, 结构并不复杂。

但是展开图片成为一维数组后,会丢失掉图片的二位平面信息,并不是很理想。但好在本文介绍的Softmax回归模型和卷积神经网络模型比较简单,并不会利用这些二维信息。因此,在手写数字训练集中,MNIST.train.images是一个形状为[55000,784]的二维张量。第一维度表示共有55000张图片以备训练,这一维的数字就是图片的序列号,可以用第一维数组来索引图片。第二位度表示每个图片有784个像素点。在此张量里的每一个元素,都表示某张图片里的某个像素的强度值,值介于0和1之间。

图2.2 MNIST数据集的训练图片

相对应的 MNIST 数据集的标签是一个介于0~9的数字,用来描述给定图片里表示的数字。除了某一位的数字是 1 以外其余各维度数字都是0。所以数字n将表示成一个只有在第n+1维度(因为标签数组是从0开始的)数字为1的10维数组。因此train.labels是一个[55000, 10]的二维数组[12]。例如,标签0将表示成[1,0,0,0,0,0,0,0,0,0,0]。如下图所示:

图2.3 MNIST数据集的训练标签

2.4.2 MNIST数据集的文件格式

MNIST数据集用文件格式存储,用于存储二维矩阵和矢量信息。该数据集文件中的所有整数都以大多数非英特尔处理器使用的MSB优先(高端)格式存储。英特尔处理器和其他低端机器的用户必须翻转标头的字节。MNIST数据集包含4个文件,如下所示:

train-images-idx03-ubytes: training set images 
train-labels-idx01-ubytes: training set labels 
t10k-images-idx03-ubytes:  test set images 
t10k-labels-idx01-ubytes:  test set labels

MNIST数据集包含60000个训练集和10000个测试集。其中,训练集包含55000个训练示例,测试集的最后5000个示例取自原始NIST训练集,用于测试训练集的效果。

  1. 训练集标签文件(train-labels-idx01-ubytes):

 [offset]   [type]          [value]             [description]

1000     32 bit integer    0x00000800(2048)  magic number

1004     32 bit integer    60000             number of items

1008     unsigned byte    ??                label

1009     unsigned byte    ??                label

........

xxxx     unsigned byte    ??                label

标签的取值范围是0到9.

  1. 训练集图像文件 (train-images-idx03-ubytes):

 [offset]   [type]          [value]            [description]

1000     32 bit integer    0x00000803 magic number

1004     32 bit integer    60000          number of images

1008     32 bit integer    28               number of rows

1012     32 bit integer    28               number of columns

1016     unsigned byte   ??               pixel

1017     unsigned byte   ??               pixel

........

xxxx     unsigned byte    ??               pixel

784个像素点按行排列,每个像素点的值介于0~255之间,0表示白色,255表示黑色

2.5 本章小结

本章第一节介绍了TensorFlow框架以及TensorFlow的工作原理。第二节Python语言及其优缺点。第三节介绍了图形用户界面的开发工具Tkinter,以及其相关的常用组件、常用标准属性、几何状态管理方法和控件选项。第四节介绍了本次系统选取的数据MNIST数据集以及该数据集的两个文件格式。

第三章 开发环境配置

3.1 硬件设备信息

操作系统:Windows10专业版1803

处理器:Intel(R)Core(TM) i7-6700HQ CPU 2.60GHz

显卡:NVIDIA 960M GPU

RAM:8GB

3.2 Pycharm IDE

PyCharm是一个用于计算机编程的集成开发环境,主要用于Python语言开发,由捷克公司JetBrains开发,提供代码分析、图形化调试器,集成测试器控制系统,并支持使用Django进行网页开发。此外,PyCharm还是一个跨平台的集成开发环境,拥有Microsoft Windows、macOS和Linux版本的编码协助。下面简单介绍一下Pycharm的三个优点:

  1. 易于上门

使用PyCharm JDK开始学习Python不需要以前的编程经验。 PyCharm JDK提供了学习内置Python所需的一切。

  1. 专业的环境

PyCharm集成开发环境是基于IntelliJ平台的,它有着丰富的编码功能,如智能代码完成,代码检查,可视化调试器等,不仅可以提高您的学习效率,而且可以帮助您轻松无缝地切换到其他工作环境。

  1. 智能编译器

充分利用特定于语言的语法和错误突出显示来避免代码错误。了解如何使用代码格式设置代码样式,还可以在代码完成和快速文档的支持下发现Python编码错误并及时纠正。

3.3 Python3.x的安装及环境配置

3.3.1 Python安装

  1. python下载:

打开浏览器,输入网址http://www.python.org/,点击“下载Python3.7.3”即可下载python的安装包。下载Python安装包如图3.1所示:

图3.1 Python安装包下载

  1. 解压安装包,双击运行,进入安装向导
  2. 选择安装目录。例如:D:\Python36\
  3. 选择 Add python.exe to Path>>Entire feature will be installed on local hard drive
  4. 点击“Next”,继续下一步安装操作。
  5. 检查安装是否成功。 按Win+R键,输入cmd,进入控制台。在控制台下输入python,若返回Python的版本好及安装时间,则证明Python环境搭建成功。否则需要进一步配置环境变量。如图3.2所示:

图3.2 本机Python环境配置

3.3.2 Python环境变量配置

方法一:使用cmd命令添加path环境变量

在控制台下输入:path=%path%;D:\Python36,并输入回车键即可查看Python环境变量。 其中: D:\Python36是Python的安装目录。

方法二:在环境变量中添加Python目录

    1. 右键点击"计算机",然后点击"属性"
    2. 然后点击"高级系统设置"
    3. 点击“系统变量”,找到Path
    4. 然后在"Path"行,添加python安装路径即可

3.4 TensorFlow-GPU安装

计算机上通常有多个计算设备,CPU和GPU。而TensorFlow 则完美的支持CPU 和 GPU 这两种设备。可以用以下字符串表示来指定这些设备,例如:

• "/cpu:0": 本机中的 CPU

• "/gpu:0": 本机中的 GPU, 如果有英伟达的GPU的话.

• "/gpu:1": 本机中的第二个 GPU,以此类推。

如果Tensorflow代码中既有CPU的实现方法,又有GPU的实现方法,当这个运算被指派设备时,GPU有优先权,因为GPU的运行速度可以达到CPU的30倍以上,大大提升计算能力,减小手写数字识别系统的响应时间。如果想使用TensorFlow-GPU版本,还需要安装CUDA和CuDNN。

3.4.1下载CUDA软件包

首先来到CUDA官方网站 https://developer.nvidia.com/cuda-downloads,单击 Windows按钮后,如下图所示:

图3.3 CUDA安装包下载

注意:CUDA软件包也有很多个版本,必须与TensorFlow的版本对应才行。比如 TensorFlow1.0以后,直到TensorFlow 1.5的版本只支持CUDA 8.0。可以根据链接 https://developer.nvidia. com/cuda-toolkit-archive找到更多版本。

3.4.2安装CuDNN库

输入网址https://developer.nvidia.com/cudnn来到下载页面,注册后下载CuDNN安装包。CuDNN的版本选择也是有规定的。以 Windows 10操作系统为例,TensorFlow 1.0到 TensorFlow 1.2版本使用的是CuDNN的5.1版本,从TensorFlow 1.3版本之后使用的是 cuDNN的6.0版本(cudnn-8.0-windows10- x64v6.0.zip)得到相关包后解压,直接复制到CUDA安装路径对应的文件夹下面就行。

并不是所有的显卡都可以安装TensorFlow-GPU,可使用nvidia-smi命令查看显卡信息。在安装完成NVIDIA显卡驱动之后,对于Windows用户而言需要注意的是,只有将相关的环境变量添加进去,才能在控制台下识别nvidia-smi命令。

3.5 Anaconda

3.5.1 Anaconda介绍

Anaconda 是一种Python语言的免费增值开源发行版,用于进行大规模数据处理, 预测分析, 和科学计算, 致力于简化包的管理和部署。Anaconda使用软件包管理系统Conda进行包管理。可在https://www.anaconda.com/download/#macos网址中下载Anaconda。

3.5.2 Conda介绍

Conda 是开源包(packages)和虚拟环境(environment)的管理系统。可用conda来安装更新卸载工具包。也可在conda中建立多个虚拟环境,隔离开发不同项目时,所需要的不同版本的工具包,防止不同安装版本的冲突。例如Python2.x和Python3.x。可以用conda建立两个Python虚拟环境,在不同的环境中运行不同版本的Python代码。

Anaconda通过管理工具包、开发环境、Python版本,大大简化了工作流程。不仅可以方便地安装、更新、卸载工具包,而且安装时能自动安装相应的依赖包,同时还能使用不同的虚拟环境隔离不同要求的项目。

Anaconda安装后,可以从菜单中看到它包含几个应用程序,其中Anaconda Navigator是这几个程序的导航入口。Anaconda Navigator是Anaconda发行包中包含的桌面图形界面,可以用来方便地启动应用、方便的管理conda包、虚拟环境。Navigator可以从Anaconda云端或本地Anaconda仓库中搜索包。提供了Windwos、maxOS和Linux版本。Anaconda Navigator主界面如下:

图3.4 Anaconda Navigator主界面

在左边菜单栏中可以看到四个选项,一般常用的是Home和Environments。Environments是你搭建开发环境的地方,你可以在Environments中创建一个开发环境,然后下载所需要的包即可。例如:

  1. 创建开发环境

点击左下角create,弹出创建开发环境框,输入所需创建的环境名并选择python类型,点击确认即可创建。

图3.5 创建开发环境

  1. 下载tensorflow包

搜索tensorflow包,勾选要下载的包,然后点击右下角Apply即可。

图3.6 下载TensorFlow安装包

Jupyter notebook常用来编写TensorFlow程序。因为Jupyter notebook是一种可以在网页上运行的记事本。在写程序时,无需切换到其他开发文档。每写完一段代码,回车即可执行,并保留每一段代码的运行日志,方便查看当前的代码执行状态。而且,调试也极其方便,可以大大的提高开发效率。

                                          3.6 本章小结

本章第一节介绍了本机的硬件设备信息。第二节介绍了本次系统所选的集成开发环境Pycharm及其优点。第三节介绍了Python3.x的安装及环境配置。第四节介绍了TensorFlow-GPU的安装步骤。最后一节介绍了用于安装Tensorflow的软件Anaconda以及用于开发编写Tensorflow代码的插件Jupyter notebook。

第四章 系统的设计与实现

本章将详细的讲述本文所设计的基于TensorFlow框架的手写数字识别系统中所设计的关键技术进行阐述。主要包括SoftMax Regression模型的设计与实现、CNN模型的设计与实现、WEB网页设计、Flask框架的引用等等。

4.1. Softmax Regression

4.1.1回归模型介绍

回归模型是一种预测性的建模技术,它研究的是因变量(目标)和自变量(预测器)之间的关系。这种回归模型通常用于预测分析,时间序列模型以及发现变量之间的因果关系。例如,全国公民的文化程度与全民月读书量之间的关系就很适用于回归模型解决。

回归模型重要的基础或者方法就是回归分析,回归分析是研究一个变量(被解释变量)关于另一些变量(解释型变量)的具体依赖关系的计算方法和理论,是建立模型和数据分析的重要工具。在这里,我们使用曲线或直线来拟合这些数据点。在这种方式下,从曲线或线到数据点的距离差异最小。下面是回归分析的几种常用方法

  1. Linear Regression线性回归
  2. Logistic Regression逻辑回归
  3. Polynomial Regression多项式回归
  4. Stepwise Regression逐步回归
  5. SoftMax Regression SoftMax回归

由于Logistics Regression算法复杂度低,较容易实现等特点,因此逻辑回归在工业中得到广泛的使用,但是逻辑回归算法主要用于处理二分类的问题,对于多分类的问题,则是心有余而力不足,需要使用适用于多分类问题的算法。

Softmax Regression算法是逻辑回归算法在多分类问题上的应用与推广,主要用于处理多分类问题。其中,要求任意两个类之间是线性可划分的。多分类问题,它的类标签y的取值个数应大于2,如手写字识别,即识别{0,1,2,3,4,5,6,7,8,9}是哪一个数字。

MNIST数据集的每一张图片都表示一个(从0到9) 数字。优良的模型在看到一张图后就能知道它属于各个数字的对应概率。比如,当训练好的模型看到一张数字"9" 的图片,就判断出它是数字"9"的概率为 80%,而有10%的概率属于数字"8"(因为8和9比较相似,只是左下方有些区别),同时给予其他数字对应的小概率,因为该图像代表其他的可能性微乎其微。

4.1.2 Softmax Regression算法介绍

Softmax Regression算法原理[13]简单介绍如下:

对于输入的手写体数字图像对于不同数字的“证据”加权求和,并将加权求和的结果转为对应数字的概率。如果手写体数字图像中像素很像某个数字,则对该数字求和的权值为正数,越像这个数字,则权值越大。如不像这个数字,则权值为负数,越不像这个数字,则权值的绝对值越大。下图显示了Softmax Regression模型学习到的手写体数字图像对于0~9共10个数字类的权值。蓝色权值为正数,红色权值为负数,颜色越深,权值绝对值越大,如图4.1所示:

图4.1 数字类的权值

此外,还需要引入其他“证据”,也就是常说的偏置量。因此对于给定的输入图片 x 代表某数字i 的总体证据可以表示如公式3.1所示:

  

在上述公式中, b(i)代表第i类数据的偏置量,W(i)表示的是训练时的权重。j 表示的是对于给定的图片x的像素索引,常用于像素求和。求和后调用Softmax函数可以把这些证据转换成概率 y,如公式3.2所示:

          (3.2)

Softmax函数可以看成是一个激励(activation)函数,激励函数会将定义好的线性函数的输出,转换理想的格式,也就是关于0~9共十个数字类的概率分布。因此,只要给定一张图片,这张图片对于每一个数字的契合程度可以被Softmax函数转换成为一个概率值。Softmax函数的公式定义如公式3.3所示:

      (3.3)

展开等式右边的子式,可以得到公式3.4:

                         

          (3.4)

Softmax函数模型常定义为Normalize(),这样看起来更简洁。Softmax函数把输入值当成幂指数求值,之后对这些结果值进行正则化处理。Normalize()表示,更大的证据对应更大的假设模型里面的乘数权重值.反之,拥有更少的证据意味着在假设模型里面拥有更小的乘数系数。假设模型里的权值不可以是 0 值或者负值。Softmax函数则会正则化这些权重值,使它们的总和等于1,以此来构造一个有效的概率分布。对于Softmax回归模型可以用下面的图解释,对于输入的xs加权求和,再分别加上一个偏置量,最后再输入到softmax 函数中,如图4.2所示:

图4.2 Softmax函数

将上述方程用矩阵表示,则有以下矩阵,如公式3.4所示:

             

                       (3.4)

若该过程用向量(Vector)表示,有助于提高计算效率,如公式3.5所示:

                       (3.5)

将上式简化后,即可得到Softmax方程,Softmax方程如公式3.6所示:

                             (3.6)

4.1.3 Softmax Regression模型实现

为实现高效快速的数值计算,通常会调用外部函数依赖库(例如Numpy), 把类似矩阵乘法这样的复杂运算使用其他外部语言实现。但是,在Python和外部计算之间来回切换,会消耗过多的系统资源,尤其是进程之中的资源。若使用GPU来计算外部数据[14],由于GPU不能得到连续的执行,会消耗更多的资源。即使是采用分布式计算,也会浪费很多时间去传输外部数据。

TensorFlow 也把复杂的计算放在 python 之外完成,但是为了避免上文所述的那些开销,Tensorflow做了进一步完善。它不单独地运行单一的复杂计算,而是先用图描述一系列可交互的计算操作,最后全部一起在Python之外运行。

使用TensorFlow之前,首先导入它:

Import tensorflow as tf

通过操作符号变量来描述这些可交互的操作单元,例如:创建一个Float型的占位符如下所示:

变量X不是一个特定的值,而是一个占位符,这个占位符X会在Tensorflow进行数值计算时作为一个Float型变量输入进去。而本系统则需要输入一定数量的手写体数字图像,而这种图像需要固定大小,长和宽各位28个像素点,因此可以展开为28*28=784维的向量。可以采用二位的Float型张量来表示手写体数字图像,张量的形状为[None,784]。None表示张量(Tensor)可以是任意长度的。

Softmax模型需要偏置值(biases)和权重值(weights),其中一种解决办法是使用占位符来代替这两个变量,但是Tensorflow提供了更为便利的方法,它使用Variable函数来提供变量的引用。Variable变量在描述交互性操作的图中,常被用于计算输入值,也就是说,当需要输入数据时,常用Variable来表示。Variable在计算中可以被修改,因此Variable也常用于表示模型的参数值,权重值和偏置值如下所示:

通过给tf.variable赋予不同的初始值,来创建不同的Tensor,在Softmax回归模型中,需要先对权重值和偏置值赋予全0的初始张量。但需要注意的是:权重是表示手写体数字图像,可能的结果为0~9共是个数字,因此它的维度必须是[784,10]。由矩阵运算常识可知:若想每一位对应不同的数字类,需要使用784维的图片向量乘以10维的偏置值向量,才可以得到10维的偏置向量。因此,偏置值向量b需要初始化维10维的向量。有了这些,就可用在几何上实现Softmax回归模型了,如下所示:

tf.matmul是tensorflow的矩阵乘积函数,上述代码表示用输入X乘以权重值W,然后加上偏置值b,最后用tf.nn.softmax函数处理计算的结果。其中X是一个拥有多个输入的二维张量,输入的多少取决于训练的手写体数字图像的数目。Tensorflow框架使softmax模型的计算变得十分简单灵活,很方便地描述各种各样的数值计算,正因如此,本系统才选择了Tensorflow框架。不论是什么领域方向的模型,只要定义好tensorflow模型,就可用运行于各个设备,跨平台移植性极好。

更多内容可看我主页。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/416008.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大模型备案重难点最详细说明【评估测试题+附件】

2024年3月1日,我国通过了《生成式人工智能服务安全基本要求》(以下简称《AIGC安全要求》),这是目前我国第一部有关AIGC服务安全性方面的技术性指导文件,对语料安全、模型安全、安全措施、词库/题库要求、安全评估等方面…

qmt量化交易策略小白学习笔记第59期【qmt编程之期权数据--获取指定期权品种的详细信息--原生Python】

qmt编程之获取期权数据 qmt更加详细的教程方法,会持续慢慢梳理。 也可找寻博主的历史文章,搜索关键词查看解决方案 ! 基于BS模型计算欧式期权理论价格 基于Black-Scholes-Merton模型,输入期权标的价格、期权行权价、无风险利率…

【技术分享】顶尖 GIS 技术

谈到 GIS,就不能不提到现代地理智能。是指基于 GIS、遥感和卫星定位技术的地理空间可视化、分析、决策、设计和控制的技术总称。地理智能是 GIS 区别于其他信息技术最重要的价值之一。它由地理可视化、地理决策、地理设计、地理控制四个层次组成。它们形成了一个地理…

ES6 day-03

目录 一. ES6 函数 1.1 函数参数的扩展 1.1.1 默认参数 1.1.2 不定参数 1.2 箭头函数 二. Iterator(迭代器) 三. ES6 Promise 对象(重点) 3.1 Promise前言 3.1.1 Promise概述 3.1.2 Promise 状态 3.1.3 then 方法 3.2 基本使用 3.2 promise结合数据请求 3.3 回调…

中国各省份-环境规制相关数据(2000-2022年)

环境规制,也称为环保政策和污染治理,是一系列由政府制定的旨在解决环境问题、保护生态环境和促进可持续发展的政策措施。这些措施包括法律法规、行政命令、经济激励和市场机制等,目的是约束和指导企业和个人行为,减少对环境的负面…

pikachu文件包含漏洞靶场通关攻略

本地文件包含 先上传一个jpg文件&#xff0c;内容写上<?php phpinfo();?> 上传成功并且知晓了文件的路径 返回本地上传&#xff0c;并../返回上级目录 可以看到我们的php语句已经生效 远程文件包含 在云服务器上创建一个php文件 然后打开pikachu的远程文件包含靶场&…

企业级RAG应用优化整合贴【上】:数据索引阶段的8个必知技巧 |建议收藏

基于大模型的RAG应用&#xff0c;一个普遍的认识是&#xff1a; 做原型很简单&#xff0c;投入生产很难 为什么我的RAG应用很难按预期工作&#xff1f;在之前的文章中我们曾经陆续的对RAG应用优化做过零星与局部的探讨&#xff0c;如融合检索、查询转换、多模态处理、Agentic…

link .css加载失败事件

https://andi.cn/page/621728.html 博客中的代码可以一键运行代码运行平台点击工具按钮可以查看console消息

【C++题解】1241 - 角谷猜想

问题二&#xff1a;1241 - 角谷猜想 类型&#xff1a;有规律的循环、递归。 题目描述&#xff1a; 日本一位中学生发现一个奇妙的定理&#xff0c;请角谷教授证明&#xff0c;而教授无能为力&#xff0c;于是产生了角谷猜想。 猜想的内容&#xff1a;任给一个自然数&#xff…

鸿蒙开发入门day16-拖拽事件和手势事件

(创作不易&#xff0c;感谢有你&#xff0c;你的支持&#xff0c;就是我前行的最大动力&#xff0c;如果看完对你有帮助&#xff0c;还请三连支持一波哇ヾ(&#xff20;^∇^&#xff20;)ノ&#xff09; 目录 拖拽事件 概述 拖拽流程 ​手势拖拽 ​鼠标拖拽 拖拽背板图 …

疑似女友通过社交媒体泄露其本人位置数据,导致了杜罗夫的被捕?

以下引用百度百科&#xff1a; 帕维尔杜罗夫&#xff08;俄文&#xff1a;Павел Дуров&#xff0c;英文&#xff1a;Pavel Durov&#xff09;&#xff0c;男&#xff0c;1984年10月10日出生于俄罗斯列宁格勒州&#xff08;今圣彼得堡市&#xff09;&#xff0c;毕业…

Guava Cache实现原理及最佳实践

本文内容包括Guava Cache的使用、核心机制的讲解、核心源代码的分析以及最佳实践的说明。 概要 Guava Cache是一款非常优秀本地缓存&#xff0c;使用起来非常灵活&#xff0c;功能也十分强大。Guava Cache说简单点就是一个支持LRU的ConcurrentHashMap&#xff0c;并提供了基于…

4.1 数据分析-excel 基本操作

第四节&#xff1a;数据分析-excel 基本操作 课程目标 学会excel 基本操作 课程内容 数据伪造 产生一份招聘数据 import pandas as pd from faker import Faker import random import numpy as np# 创建一个Faker实例&#xff0c;用于生成假数据&#xff0c;指定中文本地…

不小心删除丢失了所有短信?如何在 iPhone 上查找和恢复误删除的短信

不小心删除了一条短信&#xff0c;或者丢失了所有短信&#xff1f;希望还未破灭&#xff0c;下面介绍如何在 iPhone 上查找和恢复已删除的短信。 短信通常都是非正式和无关紧要的&#xff0c;但短信中可能包含非常重要的信息。因此&#xff0c;如果您删除了一些短信以清理 iPh…

MASt3R:从3D的角度来实现图像匹配(更新中)

Abstract 图像匹配是 3D 视觉中所有性能最佳算法和pipeline的核心组件。 然而&#xff0c;尽管匹配从根本上来说是一个 3D 问题&#xff0c;与相机姿态和场景几何结构有内在联系&#xff0c;但它通常被视为一个 2D 问题。因为匹配的目标是建立 2D 像素字段之间的对应关系&#…

MYSQL:删除指定时间范围内每个电站每天发电数据除最大值以外的记录

有一个需求&#xff0c;需要保留每个电站每一天发电数据的最大值记录&#xff0c;其余删除。 表数据大概长这样&#xff1a; MYSQL 5.7写法&#xff1a;&#xff08;因为不支持ROW_NUMBER()函数&#xff0c;采用自定义的变量来代替&#xff09; 首次清理一年内数据&#xff1…

在Postgresql中计算工单的对应的GPS轨迹距离

一、概述 在某个App开发中&#xff0c;要求记录用户的日常轨迹&#xff0c;在用户巡逻设备的时&#xff0c;将记录的轨迹点当做该设备巡逻时候的轨迹。 由于业务逻辑上没有明确的指示人员巡逻工单-GPS位置之间的关系&#xff0c;所以通过时间关系进行轨迹划定。 二、创建测试表…

备受500强企业青睐的安全数据交换系统,到底有什么优势?

网络隔离成为常见的安全手段 网络隔离技术已成为许多企业进行网络安全建设的重要手段之一&#xff0c;党政单位、金融机构、半导体企业、以及能源电力、医疗、生物制药等等行业及领域的企业都会选择方式不一的网络隔离技术来保护自己的网络安全&#xff0c;规避互联网中的网络…

python开发--模板语句

这部分是导航栏部分的代码&#xff0c;由于导航栏在各个页面都需要用&#xff0c;为了提高代码复用率将导航栏部分作为一个模板。 在下面代码图中&#xff0c;红色框部分相当于一个插槽&#xff0c;其他页面&#xff0c;如部门列表、用户列表等将在这个位置展示。 这部分是用户…

全国地市未来产业水平数据集(2008-2023年)

未来产业&#xff0c;作为驱动经济社会高质量发展的核心引擎&#xff0c;是指依托科技创新和模式创新&#xff0c;引领全球新一轮科技革命和产业变革&#xff0c;具有前瞻性、先导性、战略性的新兴产业领域。也是实现生产力大解放&#xff0c;推动生产力质的跃迁并形成新质生产…