【深度学习基础(3)】初识神经网络之深度学习hello world

文章目录

  • 一. 训练Keras中的MNIST数据集
  • 二. 工作流程
    • 1. 构建神经网络
    • 2. 准备图像数据
    • 3. 训练模型
    • 4. 利用模型进行预测
    • 5. (新数据上)评估模型精度

本节将首先给出一个神经网络示例,引出如下概念。了解完本节后,可以对神经网络在代码上的实现有一个整体的了解。

本节相关概念:

  • 样本
  • 标签
  • 层(layer)
  • 数据蒸馏
  • 密集连接
  • 10路softmax分类层
  • 编译(compilation)步骤的3个参数
  • 损失值、精度
  • 过拟合

我们来看一个神经网络的具体实例:使用Python的Keras库来学习手写数字分类。

在这个例子中,我们要解决的问题是,将手写数字的灰度图像(28像素×28像素)划分到10个类别中(从0到9)。我们将使用MNIST数据集。你可以将“解决”MNIST问题看作深度学习的“Hello World”,用来验证你的算法正在按预期运行。下图给出了MNIST数据集的一些样本。

在这里插入图片描述

说明

在机器学习中,分类问题中的某个类别叫作类(class),数据点叫作样本(sample)与某个样本对应的类叫作标签(label)(即描述:样本属于哪个类别)。

 

你不需要现在就尝试在计算机上运行这个例子。之后的文章会具体分析。

 

一. 训练Keras中的MNIST数据集

from tensorflow.keras.datasets import mnist (train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images和train_labels组成了训练集,模型将从这些数据中进行学习。然后,我们在测试集(包括test_images和test_labels)上对模型进行测试。

图像被编码为NumPy数组,而标签是一个数字数组,取值范围是0~9。图像和标签一一对应。

 
看一下训练数据:

>>> train_images.shape (60000, 28, 28) 
>>> len(train_labels) 
60000 
>>> train_labels 
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

 

再来看一下测试数据:

>>> test_images.shape
(10000, 28, 28) 
>>> len(test_labels) 
10000 
>>> test_labels 
array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)

 

二. 工作流程

我们的工作流程如下:
首先,将训练数据(train_images和train_labels)输入神经网络;
然后,神经网络学习将图像和标签关联在一起;
最后,神经网络对test_images进行预测,我们来验证这些预测与test_labels中的标签是否匹配。
 

1. 构建神经网络

下面我们来构建神经网络,如下:

from tensorflow import keras 
from tensorflow.keras import layers model = keras.Sequential([ layers.Dense(512, activation="relu"), layers.Dense(10, activation="softmax") ])

 

神经网络的核心组件是层(layer)

具体来说,层从输入数据中提取表示。大多数深度学习工作涉及将简单的层链接起来,从而实现渐进式的数据蒸馏(data distillation)。深度学习模型就像是处理数据的筛子,包含一系列越来越精细的数据过滤器(也就是层)。
 

本例中的模型包含2个Dense层,它们都是密集连接(也叫全连接)的神经层。

第2层是一个10路softmax分类层,它将返回一个由10个概率值(总和为1)组成的数组。每个概率值表示当前数字图像属于10个数字类别中某一个的概率。
 

在训练模型之前,我们还需要指定编译(compilation)步骤的3个参数

  • 优化器(optimizer):模型基于训练数据来自我更新的机制,其目的是提高模型性能。
  • 损失函数(loss function):模型如何衡量在训练数据上的性能,从而引导自己朝着正确的方向前进。
  • 在训练和测试过程中需要监控的指标(metric):本例只关心精度(accuracy),即正确分类的图像所占比例。后面两章会详细介绍损失函数和优化器的确切用途。

如下代码展示了编译步骤。


model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

 

2. 准备图像数据

在开始训练之前,我们先对数据进行预处理,将其变换为模型要求的形状,并缩放到所有值都在[0, 1]区间。前面提到过,训练图像保存在一个uint8类型的数组中,其形状为(60000, 28, 28),取值区间为[0,255]。我们将把它变换为一个float32数组,其形状为(60000, 28 *28),取值范围是[0, 1]。

下面准备图像数据,如代码所示。

train_images = train_images.reshape((60000, 28 * 28)) 
train_images = train_images.astype("float32") / 255 test_images = test_images.reshape((10000, 28 * 28)) 
test_images = test_images.astype("float32") / 255

 

3. 训练模型

在Keras中,通过调用模型的fit方法调用数据,训练模型。

>>> model.fit(train_images, train_labels, epochs=5, batch_size=128) 
Epoch 1/5 
60000/60000 [===========================] - 5s - loss: 0.2524 - acc: 0.9273 Epoch 2/5 
51328/60000 [=====================>.....] - ETA: 1s - loss: 0.1035 - acc: 0.9692

训练过程中显示了两个数字:一个是模型在训练数据上的损失值(loss),另一个是模型在训练数据上的精度(acc)。我们很快就在训练数据上达到了0.989(98.9%)的精度。

现在我们得到了一个训练好的模型,可以利用它来预测新数字图像的类别概率(如下代码)。这些新数字图像不属于训练数据,比如可以是测试集中的数据。

 

4. 利用模型进行预测

>>> test_digits = test_images[0:10] 
>>> predictions = model.predict(test_digits) 
>>> predictions[0] 
array([1.0726176e-10, 1.6918376e-10, 6.1314843e-08, 8.4106023e-06, 2.9967067e-11, 3.0331331e-09, 8.3651971e-14, 9.9999106e-01, 2.6657624e-08, 3.8127661e-07], dtype=float32)

如上代码我们对11个test_images图片进行预测,是什么数字,我们拿到第一个图片预测的概率数组,其中索引为7时,概率最大(0.99999106,几乎等于1),所以根据我们的模型,这个数字一定是7。

>>> predictions[0].argmax() 
7 
>>> predictions[0][7] 
0.99999106

这里我们检查测试标签是否与之一致:

>>> test_labels[0] 7

平均而言,我们的模型对这种前所未见的数字图像进行分类的效果如何?我们来计算在整个测试集上的平均精度,如下代码所示。

 

5. (新数据上)评估模型精度

>>> test_loss, test_acc = model.evaluate(test_images, test_labels) 
>>> print(f"test_acc: {test_acc}") 
test_acc: 0.9785

测试精度约为97.8%,比训练精度(98.9%)低不少。训练精度和测试精度之间的这种差距是过拟合(overfit)造成的。

过拟合是指机器学习模型在新数据上的性能往往比在训练数据上要差。

第一个例子到这里就结束了。你刚刚看到了如何用不到15行Python代码构建和训练一个神经网络,对手写数字进行分类。

 

之后的文章我们将详细描述每一个步骤的原理,并且将学到张量(输入模型的数据存储对象)、张量运算(层的组成要素)与梯度下降(可以让模型从训练示例中进行学习)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/319277.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一键自动化博客发布工具,chrome和firfox详细配置

blog-auto-publishing-tools博客自动发布工具现在已经可以同时支持chrome和firefox了。 很多小伙伴可能对于如何进行配置和启动不是很了解,今天带给大家一个详细的保姆教程,只需要跟着我的步骤一步来就可以无障碍启动了。 前提条件 前提条件当然是先下…

C#中.net8WebApi加密解密

尤其在公网之中,数据的安全及其的重要,除过我们使用jwt之外,还可以对传送的数据进行加密,就算别人使用抓包工具,抓到数据,一时半会儿也解密不了数据,当然,加密也影响了效率&#xff…

VISO流程图之子流程的使用

子流程的作用 整个流程图的框图多而且大,进行分块;让流程图简洁对于重复使用的流程,可以归结为一个子流程图,方便使用,避免大量的重复性工作; 新建子流程 方法1: 随便布局 框选3 和4 &#…

【Android学习】日期和时间选择对话框

实现功能 实现日期和时间选择的对话框&#xff0c;具体效果可看下图(以日期为例) 具体代码 1 日期对话框 1.1 xml <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas.android.com/apk/res/android&quo…

✔ ★Java项目——设计一个消息队列(五)【虚拟主机设计】

虚拟主机设计 创建 VirtualHost实现构造⽅法和 getter创建交换机删除交换机创建队列删除队列创建绑定删除绑定发布消息 ★路由规则1) 实现 route ⽅法2) 实现 checkRoutingKeyValid3) 实现 checkBindingKeyValid4) 实现 routeTopic5) 匹配规则测试⽤例6) 测试 Router 订阅消息1…

STM32微秒级别延时--F407--TIM1

基本配置&#xff1a; TIM1挂载在APB2总线上&#xff0c;150MHz经过15分频&#xff0c;得到10MHz计数频率&#xff0c;由于disable了自动重装载&#xff0c;所以只需要看下一次计数值是多少即可。 void TIM1_Delay_us(uint16_t us) //使用阻塞方式进行延时&#xff0c;ARR值不…

终端安全管理软件哪个好?

终端安全管理软件是保障企业信息安全的重要工具。 它们能够有效地防范恶意软件、黑客攻击和其他安全威胁&#xff0c;并提供多方面的终端设备安全保护措施。 终端安全软件的功能和保护机制各不相同&#xff0c;这就需要企业根据自身的需求和情况来进行评估和选择。 下面总结了…

3.2Java全栈开发前端+后端(全栈工程师进阶之路)-前端框架VUE3框架-企业级应用- Vuex

Vuex简介 Vuex概述 Vuex是一个专门为Vue.js应用程序开发的状态管理模式, 它采用集中式存储管理所有组件的公共状态, 并以相应的规 则保证状态以一种可预测的方式发生变化. 试想这样的场景, 比如一个Vue的根实例下面有一个根组件名为App.vue, 它下面有两个子组件A.vue和B.vu…

【Linux】awk命令学习

最近用的比较多&#xff0c;学习总结一下。 文档地址&#xff1a;https://www.gnu.org/software/gawk/manual/gawk.html 一、awk介绍二、语句结构1.条件控制语句1&#xff09;if2&#xff09;for3&#xff09;while4&#xff09;break&continue&next&exit 2.比较运…

数据结构——循环结构:for循环

今天是星期五&#xff0c;明天休息&#xff0c;后天补课&#xff0c;然后就是运动会&#xff0c;接着是放假。&#xff08;但这些都和我没关系啊&#xff0c;哭死&#xff01;&#xff09;今天脑袋难得清醒一会儿&#xff0c;主要是醒的比较早吧&#xff0c;早起学了一会&#…

3GPP官网下载协议步骤

1.打开官网 https://www.3gpp.org/ 2.点击 3.在界面选择要找的series&#xff0c;跳转到查找界面 以V2X通信协议为例&#xff0c;论文中通常会看到许多应用&#xff1a; [7] “Study on evaluation methodology of new Vehicle-to-Everything (V2X) use cases for LTE and NR…

【Python】机器学习之Sklearn基础教程大纲

机器学习之Sklearn基础教程大纲 1. 引言 机器学习简介Scikit-learn&#xff08;Sklearn&#xff09;库介绍安装和配置Sklearn 2. 数据预处理 2.1 数据加载与查看 - 加载CSV、Excel等格式的数据- 查看数据的基本信息&#xff08;如形状、数据类型等&#xff09;2.2 数据清洗…

大语言模型中的第一性原理:Scaling laws

大语言模型的尺度定律在大语言模型的训练过程中起到了非常重要的作用。即使读者不参与大语言模型的训练过程&#xff0c;但了解大语言模型的尺度定律仍然是很重要的&#xff0c;因为它能帮助我们更好的理解未来大语言模型的发展路径。 1. 什么是尺度定律 尺度定律&#xff08…

anaconda、cuda、tensorflow、pycharm环境安装

anaconda、cuda、tensorflow、pycharm环境安装 anaconda安装 anaconda官方下载地址 本文使用的是基于python3.9的anaconda 接下来跟着步骤安装&#xff1a; 检验conda是否成功安装 安装CUDA和cuDNN 提醒&#xff0c;CUDA和cuDNN两者必须版本对应&#xff0c;否者将会出错…

AI家居设备的未来:智能家庭的下一个大步

&#x1f512;目录 ☂️智能家居设备的发展和AI技术的作用 ❤️AI技术实现智能家居设备的自动化控制和智能化交互的依赖 AI家居设备的未来应用场景 &#x1f4a3;智能家庭在未来的发展和应用前景 &#x1f4a5;智能家居设备的发展和AI技术的作用 智能家居设备的发展和AI技术的…

【skill】onedrive的烦人问题

Onedrive的迷惑行为 安装Onedrive&#xff0c;如果勾选了同步&#xff0c;会默认把当前用户的数个文件夹&#xff08;桌面、文档、图片、下载 等等&#xff09;移动到安装时提示的那个文件夹 查看其中的一个文件的路径&#xff1a; 这样一整&#xff0c;原来的文件收到严重影…

使用Python实现二维码生成工具

二维码的本质是什么&#xff1f; 二维码本质上&#xff0c;就是一段字符串。 我们可以把任意的字符串&#xff0c;制作成一个二维码图片。 生活中使用的二维码&#xff0c;更多的是一个 URL 网址。 需要用到的模块 先看一下Python标准库&#xff0c;貌似没有实现这个功能的…

将要上市的自动驾驶新书《自动驾驶系统开发》中摘录各章片段 1

以下摘录一些章节片段&#xff1a; 1. 概论 自动驾驶系统的认知中有一些模糊的地方&#xff0c;比如自动驾驶系统如何定义的问题&#xff0c;自动驾驶的研发为什么会有那么多的子模块&#xff0c;怎么才算自动驾驶落地等等。本章想先给读者一个概括介绍&#xff0c;了解自动驾…

IoTDB 入门教程 基础篇⑨——TsFile导入导出工具

文章目录 一、前文二、准备2.1 准备导出服务器2.2 准备导入服务器 三、导出3.1 导出命令3.2 执行命令3.3 tsfile文件 四、导入4.1 上传tsfile文件4.2 导入命令4.3 执行命令 五、查询六、参考 一、前文 IoTDB入门教程——导读 数据库备份与迁移是数据库运维中的核心任务&#xf…

Dockerfile实战(SSH、Systemctl、Nginx、Tomcat)

目录 一、构建SSH镜像 1.1 dockerfile文件内容 1.2 生成镜像 1.3 启动容器并修改root密码 二、构建Systemctl镜像 2.1 编辑dockerfile文件 ​编辑2.2 生成镜像 2.3 启动容器&#xff0c;并挂载宿主机目录挂载到容器中&#xff0c;然后进行初始化 2.4 进入容器验证 三、…