深度学习(1)-简单神经网络示例

我们来看一个神经网络的具体实例:使用Python的Keras库来学习手写数字分类。在这个例子中,我们要解决的问题是,将手写数字的灰度图像(28像素×28像素)划分到10个类别中(从0到9)​。我们将使用MNIST数据集,图2-1给出了MNIST数据集的一些样本。
在这里插入图片描述
在机器学习中,分类问题中的某个类别叫作类(class)​,数据点叫作样本(sample)​,与某个样本对应的类叫作标签(label)​。你不需要现在就尝试在计算机上运行这个例子。如果你想这么做,那么首先需要建立深度学习工作区(见第3章)​。MNIST数据集已预先加载在Keras库中,其中包含4个NumPy数组,如代码清单2-1所示。

from tensorflow.keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images和train_labels组成了训练集,模型将从这些数据中进行学习。然后,我们在测试集(包括test_images和test_labels)上对模型进行测试。图像被编码为NumPy数组,而标签是一个数字数组,取值范围是0~9。图像和标签一一对应。我们来看一下训练数据:

>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

再来看一下测试数据:

>>> test_images.shape
(10000, 28, 28)
>>> len(test_labels)
10000
>>> test_labels
array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)

工作流程如下:首先,将训练数据(train_images和train_labels)输入神经网络;然后,神经网络学习将图像和标签关联在一起;最后,神经网络对test_images进行预测,我们来验证这些预测与test_labels中的标签是否匹配。下面我们来构建神经网络,如代码清单2-2所示。

代码清单2-2 神经网络架构

from tensorflow import keras
from tensorflow.keras import layers
model = keras.Sequential([layers.Dense(512, activation="relu"),layers.Dense(10, activation="softmax")
])

**神经网络的核心组件是层(layer)**​。你可以将层看成数据过滤器:进去一些数据,出来的数据变得更加有用。具体来说,层从输入数据中提取表示——我们期望这种表示有助于解决手头的问题。大多数深度学习工作涉及将简单的层链接起来,从而实现渐进式的数据蒸馏(data distillation)​。深度学习模型就像是处理数据的筛子,包含一系列越来越精细的数据过滤器(也就是层)​。**本例中的模型包含2个Dense层,它们都是密集连接(也叫全连接)的神经层。第2层(也是最后一层)是一个10路softmax分类层,它将返回一个由10个概率值(总和为1)组成的数组。**每个概率值表示当前数字图像属于10个数字类别中某一个的概率。在训练模型之前,我们还需要指定编译(compilation)步骤的3个参数。优化器(optimizer)​:模型基于训练数据来自我更新的机制,其目的是提高模型性能。损失函数(loss function)​:模型如何衡量在训练数据上的性能,从而引导自己朝着正确的方向前进。在训练和测试过程中需要监控的指标(metric)​:本例只关心精度(accuracy)​,即正确分类的图像所占比例。

后面两章会详细介绍损失函数和优化器的确切用途。代码清单2-3展示了编译步骤。

代码清单2-3 编译步骤

model.compile(optimizer="rmsprop",loss="sparse_categorical_crossentropy",metrics=["accuracy"])

在开始训练之前,我们先对数据进行预处理,将其变换为模型要求的形状,并缩放到所有值都在[0, 1]区间。前面提到过,训练图像保存在一个uint8类型的数组中,其形状为(60000, 28, 28),取值区间为[0, 255]。我们将把它变换为一个float32数组,其形状为(60000, 28 * 28),取值范围是[0, 1]。下面准备图像数据,如代码清单2-4所示。

代码清单2-4 准备图像数据

train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype("float32") / 255

现在我们准备开始训练模型。在Keras中,这一步是通过调用模型的fit方法来完成的——我们在训练数据上拟合(fit)模型,如代码清单2-5所示。
代码清单2-5 拟合模型

>>> model.fit(train_images, train_labels, epochs=5, batch_size=128)
Epoch 1/5
60000/60000 [===========================] - 5s - loss: 0.2524 - acc: 0.9273
Epoch 2/5
51328/60000 [=====================>.....] - ETA: 1s - loss: 0.1035 - acc: 0.9692

训练过程中显示了两个数字:一个是模型在训练数据上的损失值(loss)​,另一个是模型在训练数据上的精度(acc)​。我们很快就在训练数据上达到了0.989(98.9%)的精度。现在我们得到了一个训练好的模型,可以利用它来预测新数字图像的类别概率(见代码清单2-6)​。这些新数字图像不属于训练数据,比如可以是测试集中的数据。

代码清单2-6 利用模型进行预测

>>> test_digits = test_images[0:10]
>>> predictions = model.predict(test_digits)
>>> predictions[0]
array([1.0726176e-10, 1.6918376e-10, 6.1314843e-08, 8.4106023e-06,2.9967067e-11, 3.0331331e-09, 8.3651971e-14, 9.9999106e-01,2.6657624e-08, 3.8127661e-07], dtype=float32)

这个数组中每个索引为i的数字对应数字图像test_digits[0]属于类别i的概率。第一个测试数字在索引为7时的概率最大(0.99999106,几乎等于1)​,所以根据我们的模型,这个数字一定是7。

>>> predictions[0].argmax()
7
>>> predictions[0][7]
0.99999106

我们可以检查测试标签是否与之一致:

>>> test_labels[0]
7

平均而言,我们的模型对这种前所未见的数字图像进行分类的效果如何?我们来计算在整个测试集上的平均精度,如代码清单2-7所示。

代码清单2-7 在新数据上评估模型

>>> test_loss, test_acc = model.evaluate(test_images, test_labels)
>>> print(f"test_acc: {test_acc}")
test_acc: 0.9785

测试精度约为97.8%,比训练精度(98.9%)低不少。训练精度和测试精度之间的这种差距是过拟合(overfit)造成的。**过拟合是指机器学习模型在新数据上的性能往往比在训练数据上要差,**它是第4章的核心主题。第一个例子到这里就结束了。你刚刚看到了如何用不到15行Python代码构建和训练一个神经网络,对手写数字进行分类。在本章和第3章中,我们会详细了解这个例子中的每一个步骤及其原理。接下来,你将学到张量(输入模型的数据存储对象)​、张量运算(层的组成要素)与梯度下降(可以让模型从训练示例中进行学习)​。

需要记住的名词:
1.类
2.样本
3.标签
4.训练集
5.测试集
6.层(layer)
7.dense
8.softmax
9.损失函数
10.指标
11.过拟合

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/19473.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【AI】Docker中快速部署Ollama并安装DeepSeek-R1模型: 一步步指南

【AI】Docker中快速部署Ollama并安装DeepSeek-R1模型: 一步步指南 一、前言 为了确保在 Docker 环境中顺利安装并高效运行 Ollama 以及 DeepSeek 离线模型,本文将详细介绍整个过程,涵盖从基础安装到优化配置等各个方面。通过对关键参数和配置的深入理解…

将OpenWrt部署在x86服务器上

正文共:1234 字 40 图,预估阅读时间:2 分钟 如果你问ChatGPT有哪些开源的SD-WAN方案,他会这样答复你: 我们看到,OpenWrt也属于比较知名的开源SD-WAN解决方案。当然,在很久之前,我就发…

【区块链】零知识证明基础概念详解

🌈个人主页: 鑫宝Code 🔥热门专栏: 闲话杂谈| 炫酷HTML | JavaScript基础 ​💫个人格言: "如无必要,勿增实体" 文章目录 零知识证明基础概念详解引言1. 零知识证明的定义与特性1.1 基本定义1.2 三个核心…

Elasticsearch:将 Ollama 与推理 API 结合使用

作者:来自 Elastic Jeffrey Rengifo Ollama API 与 OpenAI API 兼容,因此将 Ollama 与 Elasticsearch 集成非常容易。 在本文中,我们将学习如何使用 Ollama 将本地模型连接到 Elasticsearch 推理模型,然后使用 Playground 向文档提…

基于Ubuntu+vLLM+NVIDIA T4高效部署DeepSeek大模型实战指南

一、 前言:拥抱vLLM与T4显卡的强强联合 在探索人工智能的道路上,如何高效地部署和运行大型语言模型(LLMs)一直是一个核心挑战。尤其是当我们面对资源有限的环境时,这个问题变得更加突出。原始的DeepSeek-R1-32B模型虽…

新数据结构(9)——Java异常体系

异常的种类 程序本身通常无法主动捕获并处理错误(Error),因为这些错误通常表示系统级的严重问题,但程序可以捕获并处理异常(Excrption),而Error则被视为一种程序无法或不应尝试恢复的异常类型。…

深度学习笔记——循环神经网络之LSTM

大家好,这里是好评笔记,公主号:Goodnote,专栏文章私信限时Free。本文详细介绍面试过程中可能遇到的循环神经网络LSTM知识点。 文章目录 文本特征提取的方法1. 基础方法1.1 词袋模型(Bag of Words, BOW)工作…

传统混合专家模型MoE架构详解以及python示例(DeepSeek-V3之基础)

我们已经了解到DeepSeek-V3的框架结构基于三大核心技术构建:多头潜在注意力(MLA)、DeepSeekMoE架构和多token预测(MTP)。而DeepSeekMoE架构的底层模型采用了混合专家模型(Mixture of Experts,MoE)架构。所以我们先了解一下传统混合专家模型MoE架构。 一、传统混合专家模…

【深度学习】计算机视觉(CV)-目标检测-Faster R-CNN —— 高精度目标检测算法

1.什么是 Faster R-CNN? Faster R-CNN(Region-based Convolutional Neural Network) 是 目标检测(Object Detection) 领域的一种 双阶段(Two-Stage) 深度学习方法,由 Ross Girshick…

实现pytorch注意力机制-one demo

主要组成部分: 1. 定义注意力层: 定义一个Attention_Layer类,接受两个参数:hidden_dim(隐藏层维度)和is_bi_rnn(是否是双向RNN)。 2. 定义前向传播: 定义了注意力层的…

SAP-ABAP:SAP的Screen Layout Designer屏幕布局设计器详解及示例

在SAP中,Screen Layout Designer(屏幕布局设计器)是用于设计和维护屏幕(Dynpro)布局的工具。通过Screen Layout Designer,您可以创建和修改屏幕元素(如输入字段、按钮、文本、表格控件等&#x…

windows11+ubuntu20.04双系统下卸载ubuntu并重新安装

windows11ubuntu20.04双系统下卸载ubuntu并重新安装 背景:昨晚我电脑ubuntu20.04系统突然崩溃了,无奈只能重装系统了(好在没有什么重要数据)。刚好趁着这次换个ubuntu24.04系统玩一下,学习一下ROS2。 现系统&#xff…

SpringBoot速成(11)更新用户头像,密码P13-P14

更新头像: 1.代码展示: 1.RequestParam 是 Spring MVC 中非常实用的注解,用于从 HTTP 请求中提取参数并绑定到控制器方法的参数上。 2.PatchMapping 是 Spring MVC 中的一个注解,用于处理 HTTP 的 PATCH 请求。PATCH 请求通常用于对资源的部…

DeepSeek R1 与 OpenAI O1:机器学习模型的巅峰对决

我的个人主页 我的专栏:人工智能领域、java-数据结构、Javase、C语言,希望能帮助到大家!!!点赞👍收藏❤ 一、引言 在机器学习的广袤天地中,大型语言模型(LLM)无疑是最…

Datawhale 数学建模导论二 笔记1

第6章 数据处理与拟合模型 本章主要涉及到的知识点有: 数据与大数据Python数据预处理常见的统计分析模型随机过程与随机模拟数据可视化 本章内容涉及到基础的概率论与数理统计理论,如果对这部分内容不熟悉,可以参考相关概率论与数理统计的…

【个人开发】deepspeed+Llama-factory 本地数据多卡Lora微调

文章目录 1.背景2.微调方式2.1 关键环境版本信息2.2 步骤2.2.1 下载llama-factory2.2.2 准备数据集2.2.3 微调模式2.2.3.1 zero-3微调2.2.3.2 zero-2微调2.2.3.3 单卡Lora微调 2.3 踩坑经验2.3.1 问题一:ValueError: Undefined dataset xxxx in dataset_info.json.2…

STM32 如何使用DMA和获取ADC

目录 背景 ‌摇杆的原理 程序 端口配置 ADC 配置 DMA配置 背景 DMA是一种计算机技术,允许某些硬件子系统直接访问系统内存,而不需要中央处理器(CPU)的介入,从而减轻CPU的负担。我们可以通过DMA来从外设&#xf…

Jvascript网页设计案例:通过js实现一款密码强度检测,适用于等保测评整改

本文目录 前言功能预览样式特点总结:1. 整体视觉风格2. 密码输入框设计3. 强度指示条4. 结果文本与原因说明 功能特点总结:1. 密码强度检测2. 实时反馈机制3. 详细原因说明4. 视觉提示5. 交互体验优化 密码强度检测逻辑Html代码Javascript代码 前言 能满…

Mybatis高级(动态SQL)

目录 一、动态SQL 1.1 数据准备&#xff1a; 1.2 <if>标签 1.3<trim> 标签 1.4<where>标签 1.5<set>标签 1.6 <foreach>标签 1.7<include> 标签 一、动态SQL 动态SQL是Mybatis的强⼤特性之⼀&#xff0c;能够完成不同条件下不同…

mac 意外退出移动硬盘后再次插入移动硬盘不显示怎么办

第一步&#xff1a;sudo ps aux | grep fsck 打开mac控制台输入如下指令&#xff0c;我们看到会出现两个进程&#xff0c;看进程是root的这个 sudo ps aux|grep fsck 第二步&#xff1a;杀死进程 在第一步基础上我们知道不显示u盘的进程是&#xff1a;62319&#xff0c;我们…