【深度学习实验】前馈神经网络(八):模型评价(自定义支持分批进行评价的Accuracy类)

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

 0. 导入必要的工具包

1. __init__(构造函数)

2. update函数(更新评价指标)

5. accumulate(计算准确率)

4. reset(重置评价指标)

5. 构造数据进行测试

6. 代码整合


一、实验介绍

       本文将实现一个辅助功能——计算预测的准确率。Accuracy支持对每一个回合中每批数据进行评价,并将结果累积,最终获得整批数据的评价结果。

  • 在训练或验证过程中迭代地调用update方法来更新评价指标;
  • 使用accumulate方法获取累计的准确率;
  • 通过reset方法重置评价指标,以便进行下一轮的计算。

 二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        前馈神经网络(Feedforward Neural Network)是一种常见的人工神经网络模型,也被称为多层感知器(Multilayer Perceptron,MLP)。它是一种基于前向传播的模型,主要用于解决分类和回归问题。

        前馈神经网络由多个层组成,包括输入层、隐藏层和输出层。它的名称"前馈"源于信号在网络中只能向前流动,即从输入层经过隐藏层最终到达输出层,没有反馈连接。

以下是前馈神经网络的一般工作原理:

  1. 输入层:接收原始数据或特征向量作为网络的输入,每个输入被表示为网络的一个神经元。每个神经元将输入加权并通过激活函数进行转换,产生一个输出信号。

  2. 隐藏层:前馈神经网络可以包含一个或多个隐藏层,每个隐藏层由多个神经元组成。隐藏层的神经元接收来自上一层的输入,并将加权和经过激活函数转换后的信号传递给下一层。

  3. 输出层:最后一个隐藏层的输出被传递到输出层,输出层通常由一个或多个神经元组成。输出层的神经元根据要解决的问题类型(分类或回归)使用适当的激活函数(如Sigmoid、Softmax等)将最终结果输出。

  4. 前向传播:信号从输入层通过隐藏层传递到输出层的过程称为前向传播。在前向传播过程中,每个神经元将前一层的输出乘以相应的权重,并将结果传递给下一层。这样的计算通过网络中的每一层逐层进行,直到产生最终的输出。

  5. 损失函数和训练:前馈神经网络的训练过程通常涉及定义一个损失函数,用于衡量模型预测输出与真实标签之间的差异。常见的损失函数包括均方误差(Mean Squared Error)和交叉熵(Cross-Entropy)。通过使用反向传播算法(Backpropagation)和优化算法(如梯度下降),网络根据损失函数的梯度进行参数调整,以最小化损失函数的值。

        前馈神经网络的优点包括能够处理复杂的非线性关系,适用于各种问题类型,并且能够通过训练来自动学习特征表示。然而,它也存在一些挑战,如容易过拟合、对大规模数据和高维数据的处理较困难等。为了应对这些挑战,一些改进的网络结构和训练技术被提出,如卷积神经网络(Convolutional Neural Networks)和循环神经网络(Recurrent Neural Networks)等。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

977468b5ae9843c6a88005e792817cb1.png​​

 0. 导入必要的工具包

import torch
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader
  • DatasetDataLoader类用于处理数据集和数据加载

这段代码定义了一个名为Accuracy的类,用于支持分批进行模型评价,特别是在分类任务中计算准确率。

1. __init__(构造函数)

class Accuracy:def __init__(self, is_logist=True):self.num_correct = 0self.num_count = 0self.is_logist = is_logist
  • 构造函数在创建Accuracy对象时被调用。它接受一个可选的参数is_logist,默认为True,用于指示是否为logist形式的预测值。
  • self.num_correct用于记录正确预测的样本个数。
  • self.num_count用于记录总样本个数。
  • self.is_logist指示是否为logist形式的预测值。

2. update函数(更新评价指标)

def update(self, outputs, labels):if outputs.shape[1] == 1:outputs = outputs.squeeze(-1)if self.is_logist:preds = (outputs >= 0).long()else:preds = (outputs >= 0.5).long()else:preds = torch.argmax(outputs, dim=1).long()labels = labels.squeeze(-1)batch_correct = (preds==labels).float().sum()batch_count = len(labels)self.num_correct += batch_correctself.num_count += batch_count
  • update方法用于更新评价指标。它接受两个参数outputslabels,分别表示模型的预测输出和真实标签。
  • 根据outputs的形状判断任务类型。
    •  如果outputs是二维张量且第二维大小为1,那么表示是二分类任务。
      •   如果is_logist=True,则将outputs通过阈值(0)转换为预测值preds,并将其转换为整数类型。
      •   如果is_logist=False,则将outputs通过阈值(0.5)转换为预测值preds,并将其转换为整数类型。
    •  如果outputs是二维张量且第二维大小大于1,表示是多分类任务。此时,将outputs中概率最大的类别作为预测值preds
  • labels去除多余的维度,并计算本批数据中预测正确的样本个数batch_correct
  • 获取本批数据的样本个数batch_count
  • 更新num_correctnum_count,累积计算正确样本个数和总样本个数。

5. accumulate(计算准确率)

def accumulate(self):if self.num_count == 0:return 0return self.num_correct / self.num_count
  • accumulate方法用于计算准确率。
    •  如果num_count为0,表示没有进行过更新,返回0。
    • 否则,返回正确样本个数除以总样本个数的比例,即准确率

4. reset(重置评价指标)

def reset(self):self.num_correct = 0self.num_count = 0
  • reset方法用于重置评价指标,将num_correctnum_count重置为0,以便进行下一轮评价

5. 构造数据进行测试

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
acc = Accuracy()
acc.update(y_hat, y)
acc.num_correct

6. 代码整合

import torch# 支持分批进行模型评价的 Accuracy 类
class Accuracy:def __init__(self, is_logist=True):# 正确样本个数self.num_correct = 0# 样本总数self.num_count = 0self.is_logist = is_logistdef update(self, outputs, labels):# 判断是否为二分类任务if outputs.shape[1] == 1:outputs = outputs.squeeze(-1)# 判断是否是logit形式的预测值if self.is_logist:preds = (outputs >= 0).long()else:preds = (outputs >= 0.5).long()else:# 多分类任务时,计算最大元素索引作为类别preds = torch.argmax(outputs, dim=1).long()# 获取本批数据中预测正确的样本个数labels = labels.squeeze(-1)batch_correct = (preds == labels).float().sum()batch_count = len(labels)# 更新self.num_correct += batch_correctself.num_count += batch_countdef accumulate(self):# 使用累计的数据,计算总的评价指标if self.num_count == 0:return 0return self.num_correct / self.num_countdef reset(self):self.num_correct = 0self.num_count = 0y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
acc = Accuracy()
acc.update(y_hat, y)
acc.num_correct

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/140483.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

react多条件查询

1、声明一个filter常量 2.filter接受(condition,data)两个参数 3、调用data里面的filter进行筛选 4、任意一个item当筛选条件 5、使用object.key获取对象所有key 6、对每个key使用Array.prototype.every()方法判断是否满足条…

华为数通方向HCIP-DataCom H12-821题库(单选题:361-380)

第361题 如图所示是一台路由器的BGP输出信息。那么以下关于这段信息的描述,错误的是哪一项? <Huawei>display bgp error Error Type: Peer Error Peer Address:10.1.1.2 VRFName:Public Error Info: Router-ID conflictA、该路由器邻居地址是10.1.1.2 B、Error Type显…

电脑计算机xinput1_3.dll丢失的解决方法分享,四种修复手段解决问题

日常生活中可能会遇到的问题——xinput1_3.dll丢失的解决方法。我相信&#xff0c;在座的很多朋友都曾遇到过这个问题&#xff0c;那么接下来&#xff0c;我将分享如何解决这个问题的解决方法。 首先&#xff0c;让我们来了解一下xinput1_3.dll文件。xinput1_3.dll是一个动态链…

【数据结构】什么是数据结构?

数据结构(Data Structure)是计算机存储,组织数据的方式,指相互之间存在一种或多种特定关系的数据元素的集合. 这么讲可能有些抽象,放一张图大家可能好理解一点: 上图依次是数据结构中逻辑结构中的:集合结构,线性结构,树形结构,图形结构. 而: 数据结构是一门研究非数值计算的程…

分库分表MySQL

目录 Mycat入门 分片配置 分片配置(配置Mycat的用户以及用户的权限) 启动服务 登录Mycat Mycat配置 schema.xml 1.schema标签:配置逻辑库,逻辑表的相关信息 1-1.核心属性 1-2.table标签 2.datanode标签:配置数据节点的相关信息 核心属性 3.datahost标签:配置的是节…

【性能测试】JMeter:集合点,同步定时器的应用实例!

一、集合点的定义 在性能测试过程中&#xff0c;为了真实模拟多个用户同时进行操作以度量服务器的处理能力&#xff0c;可以考虑同步虚拟用户以便恰好在同一时刻执行操作或发送请求。 通过插入集合点可以较真实模拟多个用户并发操作。 (注意&#xff1a;虽然通过加入集合点可…

蓝牙核心规范(V5.4)10.2-BLE 入门笔记之CIS篇

LE CIS 同步通信 同步通信提供了一种使用蓝牙LE在设备之间传输有时间限制的数据的方式。它提供了一个机制,允许多个接收器设备在不同的时间从相同的源接收数据,以同步它们对该数据的处理。LE AUDIO使用同步通信。 当使用同步通信时,数据具有有限的时间有效期,在到期时被认…

第1篇 目标检测概述 —(1)目标检测基础知识

前言&#xff1a;Hello大家好&#xff0c;我是小哥谈。目标检测是计算机视觉领域中的一项任务&#xff0c;旨在自动识别和定位图像或视频中的特定目标&#xff0c;目标可以是人、车辆、动物、物体等。目标检测的目标是从输入图像中确定目标的位置&#xff0c;并使用边界框将其标…

Linux指令(ls、pwd、cd、touch、mkdir、rm)

whoami who pwd ls ls -l clearls指令 ls ls -l ls -a :显示当前目录下的隐藏文件&#xff08;隐藏文件以.开头&#xff09;ls -a -l 和 ls -l -a 和 ls -la 和 ls -al &#xff08;等价于ll&#xff09; pwd命令 显示用户当前所在的目录 cd指令 mkdir code &#xff08;创建…

PPT系统化学习 - 第1天

文章目录 更改PPT主题更改最大撤回次数自动保存禁止PPT压缩图片字体嵌入PPTPPT导出为PDFPPT导出为图片PPT导出为图片型幻灯片PPT导出成视频添加参考线设置默认字体设置默认形状建立模板、保存模板、使用模板建立模板保存模板使用模板 更改PPT主题 更改PPT的主题&#xff1a; 夜…

【Vue】轻松理解数据代理

hello&#xff0c;我是小索奇&#xff0c;精心制作的Vue教程持续更新哈&#xff0c;想要学习&巩固&避坑就一起学习叭~ Object定义配置方法 代码 引出数据代理&#xff0c;先上代码&#xff0c;后加解释 <!DOCTYPE html> <html><head><meta cha…

《从零开始的Java世界》01基本程序设计

《从零开始的Java世界》系列主要讲解Javase部分&#xff0c;从最简单的程序设计到面向对象编程&#xff0c;再到异常处理、常用API的使用&#xff0c;最后到注解、反射&#xff0c;涵盖Java基础所需的所有知识点。学习者应该从学会如何使用&#xff0c;到知道其实现原理全方位式…

【20230921】关于sing-box命令行程序开机自启动运行(Windows、Linux)

1 背景 sing-box是一个命令行程序&#xff0c;官网给出的教程是复制链接到Git Bash&#xff08;windows&#xff09;或终端运行&#xff08;Linux&#xff09;。每次开机都进行复制运行是一件繁琐的事情。 复制的内容其实就是下次并运行shell脚本&#xff0c;其实系统只需要运…

2、SpringBoot_依赖介绍

三、SpringBoot介绍 1.parent 前言&#xff1a;之前是使用spring/springmvc 开发&#xff0c;整合不同的组件会有很多依赖&#xff0c;这些依赖会涉及到很多的版本信息&#xff0c;版本信息多了之后可能会导致版本冲突问题概述&#xff1a;把很多组件技术的搭配放到一起&…

linux 约束

linux 约束 1、约束的概念1.1什么是约束1.2约束的优劣势 2、约束的作用3、约束的分类4、约束的应用场景5、约束的管理5.1创建5.2查看5.3插入5.4删除 6、总结 1、约束的概念 1.1什么是约束 在关系型数据库中&#xff0c;约束是用于限制表中数据规则的一种机制。它可以确保表中…

基于springboot+vue的老年一站式服务平台

博主主页&#xff1a;猫头鹰源码 博主简介&#xff1a;Java领域优质创作者、CSDN博客专家、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战 主要内容&#xff1a;毕业设计(Javaweb项目|小程序等)、简历模板、学习资料、面试题库、技术咨询 文末联系获取 项目介绍…

机器学习第十课--提升树

一.Bagging与Boosting的区别 在上一章里我们学习了一个集成模型叫作随机森林&#xff0c;而且也了解到随机森林属于Bagging的成员。本节我们重点来学习一下另外一种集成模型叫作Boosting。首先回顾一下什么叫Bagging? 比如在随机森林里&#xff0c;针对于样本数据&#xff0c;…

PostgreSQL 技术内幕(十)WAL log 模块基本原理

事务日志是数据库的重要组成部分&#xff0c;记录了数据库系统中所有更改和操作的历史信息。 WAL log(Write Ahead Logging)也被称为xlog&#xff0c;是事务日志的一种&#xff0c;也是关系数据库系统中用于保证数据一致性和事务完整性的一系列技术&#xff0c;在数据库恢复、高…

php函数usort使用方法

在 PHP 中&#xff0c;usort() 函数用于对数组进行排序&#xff0c;它允许你使用自定义的比较函数来确定元素的顺序。以下是 usort() 函数的使用方法&#xff1a; usort(array &$array, callable $cmp_function): bool参数说明&#xff1a; $array&#xff1a;要排序的数…

钉钉stream机器人-实操详细教程

支持事件订阅、机器人收消息、卡片回调等功能 优点&#xff1a; 配置简单&#xff0c;不依赖也不需要暴露公网IP&#xff0c;无需向公网开放端口 github官方链接&#xff1a;GitHub - open-dingtalk/dingtalk-stream-sdk-python: Python SDK for DingTalk Stream Mode API, Co…