【神经网络】python实现神经网络(一)——数据集获取

一.概述

        在文章【机器学习】一个例子带你了解神经网络是什么中,我们大致了解神经网络的正向信息传导、反向传导以及学习过程的大致流程,现在我们正式开始进行代码的实现,首先我们来实现第一步的运算过程模拟讲解:正向传导。本次代码实现将以“手写数字识别”为例子。

二.测试训练数据集的获取

        首先我们需要通过官网获取到手写数字识别数据集,数据集一共分为四个部分,分别是训练集的图片(六万张)、训练集的标签、测试集的图片(一万张)以及测试集的标签。所以我们在代码中可以使用键值表示对应的key-value:

url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {'train_img':'train-images-idx3-ubyte.gz','train_label':'train-labels-idx1-ubyte.gz','test_img':'t10k-images-idx3-ubyte.gz','test_label':'t10k-labels-idx1-ubyte.gz'
}

        同时,我们需要将下载的文件保存到与代码同一级目录下:

dataset_dir = os.path.dirname(os.path.abspath(__file__))

        下载部分十分简单么,就不在此赘述,需要注意的是代码使用了python的urlretrieve函数,该函数需要使用头文件urllib.request,需要自行下载:

def download_mnist():for filename in key_file.values():file_path = dataset_dir + "/" + filenameif os.path.exists(file_path):returnprint("Downloading " + filename + " ... ")urllib.request.urlretrieve(url_base + filename, file_path)print("Done")

三.测试训练数据集的加载

        下载完数据集后,我们需要将其加载到我们的程序中以供后续的使用,首先是判断一下我们是否已经下载过数据集,如果没有下载,则先进行下载操作,再执行其他步骤:

    if not os.path.exists(save_file) :download_mnist()dataset = _convert_numpy()print("Creating pickle file ...")with open(save_file, 'wb') as f:pickle.dump(dataset, f, -1)print("Done!")

        以上代码有个需要注意的地方,因为下载完数据集之后无法直接给到python使用,所以还需要对数据进行格式处理,处理成python可以识别的格式,这一步交由函数_convert_numpy实现:

def _convert_numpy():    dataset = {}dataset['train_img'] = _load_img(key_file['train_img'])dataset['train_label'] = _load_label(key_file['train_label'])dataset['test_img'] = _load_img(key_file['test_img'])dataset['test_label'] = _load_label(key_file['test_label'])return dataset

       其中,_load_img函数负责处理图片数据:

def _load_img(file_name):file_path = dataset_dir + "\\MNIST\\" + file_nameprint("Converting " + file_name + " to NumPy Array ...")with gzip.open(file_path, 'rb') as f:data = np.frombuffer(f.read(), np.uint8, offset=16)data = data.reshape(-1, img_size)print("Done")return data

        其中,_load_label函数负责处理标签数据:

def _load_label(file_name):file_path = dataset_dir + "\\MNIST\\" + file_nameprint("Converting " + file_name + " to NumPy Array ...")with gzip.open(file_path, 'rb') as f:labels = np.frombuffer(f.read(), np.uint8, offset=8)print("Done")return labels

        函数中使用到的都是一些python常用的函数,所以具体作用不在赘述,可自行查询。介绍完_convert_numpy函数,我们继续回到数据集加载函数本身,为了方便后续数据集的批量调用等操作,我们需要在加载数据后对其进行进一步的数据清洗整理等预处理,分别为数据归一化(normalize)、图像展开(flatten)以及图像标签对应(one_hot_label),先将三个功能代码贴上,然后我们再详细讲解各个功能的具体作用:

    with open(save_file,'rb') as f:dataset = pickle.load(f)if normalize:for key in ['train_img','test_img']:dataset[key] = dataset[key].astype(np.float32)if not flatten:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].reshape(-1, 1, 28, 28)if one_hot_label:dataset['train_label'] = _change_one_hot_label(dataset['train_label'])dataset['test_label'] = _change_one_hot_label(dataset['test_label'])

3.1.数据归一化(normalize)

        数据归一化normalize如果设置为True,可以将输入图像归一化为0.0~1.0 的值。如果将该参数设置为False,则输入图像的像素会保持原来的0~255。函数实现是使用了python函数中的astype功能将数据,用于将数据集指定字段的数据转换为 float32 类型,常见于深度学习模型输入前的数据预处理。

dataset[key] = dataset[key].astype(np.float32)

3.2.图像展开(flatten)

        图像展开flatten用于设置是否展开输入图像使其变成一维数组。如果将该参数设置为False,则输入图像为1 × 28 × 28 的三维数组;若设置为True,则输入图像会保存为由784 个元素构成的一维数组。函数实现也只是使用到深度学习中常用的reshape函数:

 dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

3.3.图像标签对应(one_hot_label)

        图像标签对应one_hot_label用于设置是否将标签保存为onehot表示(one-hot representation)。one-hot 表示是仅正确解标签为1,其余皆为0 的数组,就像[0,0,1,0,0,0,0,0,0,0]这样。当one_hot_label为False时,就是像7、2这样简单保存正确解标签,函数_change_one_hot_label的实现如下:

def _change_one_hot_label(X):T = np.zeros((X.size, 10))for idx, row in enumerate(T):row[X[idx]] = 1return T

        以上即为测试训练数据集加载函数的全部内容,我们将在下面正式调用一下看看是否能够正常工作,在此贴上函数全文:

ef load_mnist(normalize=True, flatten=True, one_hot_label=False):if not os.path.exists(save_file) :download_mnist()dataset = _convert_numpy()print("Creating pickle file ...")with open(save_file, 'wb') as f:pickle.dump(dataset, f, -1)print("Done!")with open(save_file,'rb') as f:dataset = pickle.load(f)if normalize:for key in ['train_img','test_img']:dataset[key] = dataset[key].astype(np.float32)if not flatten:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].reshape(-1, 1, 28, 28)if one_hot_label:dataset['train_label'] = _change_one_hot_label(dataset['train_label'])dataset['test_label'] = _change_one_hot_label(dataset['test_label'])return (dataset['train_img'],dataset['train_label']),(dataset['test_img'],dataset['test_label'])

四.测试训练数据集的使用测试

        我们可以加载数据集并且查看到各个数据集的形状:

(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True,normalize=False)
# 输出各个数据的形状
print(x_train.shape) # (60000, 784)
print(t_train.shape) # (60000,)
print(x_test.shape) # (10000, 784)
print(t_test.shape) # (10000,)

        根据输出我们可以看到,训练集图片有六万张,每张图片有784各像素(28*28),训练集标签和照片数量一样(那是肯定的),测试集图片和标签数量比训练集的少,主要用来验证模型学习后的正确性。

        我们甚至还能随机从数据集中抽取一张照片查看一下实际样子,具体实现如下:

def img_show(img):
pil_img = Image.fromarray(np.uint8(img))
pil_img.show()
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True,normalize=False)
img = x_train[0]
label = t_train[0]
print(label) # 5
print(img.shape) # (784,)
img = img.reshape(28, 28) # 把图像的形状变成原来的尺寸
print(img.shape) # (28, 28)
img_show(img)

        输出的图片如图下所示:

        在后面的文章中,我们将开始正式步入主题,讲解神经网络如何学习,各层次之间如何传递数值,如何反向传导,计算损失,又在重新学习,最终实现传入一张手写数字就能自动识别出具体的数字的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/30223.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

黑金风格人像静物户外旅拍Lr调色教程,手机滤镜PS+Lightroom预设下载!

调色教程 针对人像、静物以及户外旅拍照片,运用 Lightroom 软件进行风格化调色工作。旨在通过软件中的多种工具,如基本参数调整、HSL(色相、饱和度、明亮度)调整、曲线工具等改变照片原本的色彩、明度、对比度等属性,将…

Kubernetes中的 iptables 规则介绍

#作者:邓伟 文章目录 一、Kubernetes 网络模型概述二、iptables 基础知识三、Kubernetes 中的 iptables 应用四、查看和调试 iptables 规则五、总结 在 Kubernetes 集群中,iptables 是一个核心组件, 用于实现服务发现和网络策略。iptables 通…

C语言_数据结构总结5:顺序栈

纯C语言代码,不涉及C 想了解链式栈的实现,欢迎查看这篇文章:C语言_数据结构总结6:链式栈-CSDN博客 这里分享插入一下个人觉得很有用的习惯: 1. 就是遇到代码哪里不理解的,你就问豆包,C知道&a…

STM32之ADC

逐次逼近式ADC: 左边是8路输入通道,左下是地址锁存和译码,可将通道的地址锁存进ADDA,ADDB,ADDC类似38译码器的结构,ALE为锁存控制键,通道选择开关可控制选择单路或者多路通道,DAC为…

Magento2根据图片文件包导入产品图片

图片包给的图片文件是子产品的图片&#xff0c;如下图&#xff1a;A104255是主产品的sku <?php/*** 根据图片包导入产品图片&#xff0c;包含子产品和主产品* 子产品是作为主图&#xff0c;主产品是作为附加图片*/use Magento\Framework\App\Bootstrap;include(../app/boot…

初学STM32之简单认识IO口配置(学习笔记)

在使用51单片机的时候基本上不需要额外的配置IO&#xff0c;不过在使用特定的IO的时候需要额外的设计外围电路&#xff0c;比如PO口它是没有内置上拉电阻的。因此若想P0输出高电平&#xff0c;它就需要外接上拉电平。&#xff08;当然这不是说它输入不需要上拉电阻&#xff0c;…

图像生成-ICCV2019-SinGAN: Learning a Generative Model from a Single Natural Image

图像生成-ICCV2019-SinGAN: Learning a Generative Model from a Single Natural Image 文章目录 图像生成-ICCV2019-SinGAN: Learning a Generative Model from a Single Natural Image主要创新点模型架构图生成器生成器源码 判别器判别器源码 损失函数需要源码讲解的私信我 S…

STM32之I2C硬件外设

注意&#xff1a;硬件I2C的引脚是固定的 SDA和SCL都是复用到外部引脚。 SDA发送时数据寄存器的数据在数据移位寄存器空闲的状态下进入数据移位寄存器&#xff0c;此时会置状态寄存器的TXE为1&#xff0c;表示发送寄存器为空&#xff0c;然后往数据控制寄存器中一位一位的移送数…

Git - 补充工作中常用的一些命令

Git - 补充工作中常用的一些命令 1 一些场景1.1 场景11.2 场景21.3 场景31.4 场景41.5 场景51.6 场景61.7 场景71.8 场景81.9 场景91.10 场景101.11 场景111.12 场景121.13 场景131.14 场景141.15 场景15 2 git cherry-pick \<commit-hash\> 和 git checkout branch \-\-…

AI 驱动的软件测试革命:从自动化到智能化的进阶之路

&#x1f680;引言&#xff1a;软件测试的智能化转型浪潮 在数字化转型加速的今天&#xff0c;软件产品的迭代速度与复杂度呈指数级增长。传统软件测试依赖人工编写用例、执行测试的模式&#xff0c;已难以应对快速交付与高质量要求的双重挑战。人工智能技术的突破为测试领域注…

Unity--Cubism Live2D模型使用

了解LIVE2D在unity的使用--前提记录 了解各个组件的作用 Live2D Manuals & Tutorials 这些文件都是重要的控制动画参数的 Cubism Editor是编辑Live2D的工具&#xff0c;而导出的数据的类型&#xff0c;需要满足以上的条件 SDK中包含的Cubism的Importer会自动生成一个Pref…

Windows 系统 Docker Desktop 入门教程:从零开始掌握容器化技术

文章目录 前言一、Docker 简介二、Docker Desktop 安装2.1 系统要求2.2 安装步骤 三、Docker 基本概念四、Docker 常用命令五、实战&#xff1a;运行你的第一个容器5.1 拉取并运行 Nginx 容器5.2 查看容器日志5.3 停止并删除容器 六、总结 前言 随着云计算和微服务架构的普及&…

Lab17_ Blind SQL injection with out-of-band data exfiltration

文章目录 前言&#xff1a;进入实验室构造 payload 前言&#xff1a; 实验室标题为&#xff1a; 带外数据泄露的 SQL 盲注 简介&#xff1a; 本实验包含一个SQL盲目注入漏洞。应用程序使用跟踪Cookie进行分析&#xff0c;并执行包含提交的Cookie值的SQL查询。 SQL查询是异…

Vue 框架深度解析:源码分析与实现原理详解

文章目录 一、Vue 核心架构设计1.1 整体架构流程图1.2 模块职责划分 二、响应式系统源码解析2.1 核心类关系图2.2 核心源码分析2.2.1 数据劫持实现2.2.2 依赖收集过程 三、虚拟DOM与Diff算法实现3.1 Diff算法流程图3.2 核心Diff源码 四、模板编译全流程剖析4.1 编译流程图4.2 编…

Linux基本指令

一&#xff1a;Xshell相关快捷键 1.AltEnter进入Xshell全屏模式&#xff0c;再按一次AltEnter退出Xshell全屏模式 2.Ctrl Insert复制 3.Shift Insert粘粘 二&#xff1a;Linux基本指令 1.clear&#xff1a; 清屏&#xff1a;即将屏幕框上的所有内容删除 2.pwd&#xf…

Python基于Django的医用耗材网上申领系统【附源码、文档说明】

博主介绍&#xff1a;✌Java老徐、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&…

计算机视觉|3D卷积网络VoxelNet:点云检测的革新力量

一、引言 在科技快速发展的背景下&#xff0c;3D 目标检测技术在自动驾驶和机器人领域中具有重要作用。 在自动驾驶领域&#xff0c;车辆需实时、准确感知周围环境中的目标物体&#xff0c;如行人、车辆、交通标志和障碍物等。只有精确检测这些目标的位置、姿态和类别&#x…

【AD】5-13 特殊粘贴使用

同等间距复制很多过孔 复制之后进行特殊性粘贴&#xff0c;选择阵列粘贴 将元件带位号、带网络从PCB复制粘贴到另一个PCB 全选PCB并复制&#xff0c;来到另一个PCB&#xff0c;点击特殊性粘贴

Unity自定义区域UI滑动事件

自定义区域UI滑动事件 介绍制作1.创建一个Image2.创建脚本 总结 介绍 一提到滑动事件联想到有太多的插件了比如EastTouchBundle&#xff0c;今天想单纯通过UI去做一个滑动事件而不是基于Box2d或者Box去做滑动事件。 制作 1.创建一个Image 2.创建脚本 using UnityEngine; us…

报表DSL优化,享元模式优化过程,优化效果怎么样?

报表DSL优化与享元模式应用详解 一、报表DSL优化 1. 问题背景 报表系统通常使用领域特定语言&#xff08;DSL&#xff09;定义模板结构、数据绑定规则及样式配置。随着复杂度提升&#xff0c;DSL可能面临以下问题&#xff1a; 冗余配置&#xff1a;重复定义样式、布局或数据源…