【深度学习实验】前馈神经网络(四):自定义逻辑回归模型:前向传播、反向传播算法

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. 逻辑回归Logistic类

a. 构造函数__init__

b. __call__(self, x)方法

c. 前向传播forward

d. 反向传播backward

2. 模型训练

3. 代码整合


一、实验介绍

  • 实现逻辑回归模型(Logistic类)
    • 实现前向传播forward
    • 实现反向传播backward

 二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        前馈神经网络(Feedforward Neural Network)是一种常见的人工神经网络模型,也被称为多层感知器(Multilayer Perceptron,MLP)。它是一种基于前向传播的模型,主要用于解决分类和回归问题。

        前馈神经网络由多个层组成,包括输入层、隐藏层和输出层。它的名称"前馈"源于信号在网络中只能向前流动,即从输入层经过隐藏层最终到达输出层,没有反馈连接。

以下是前馈神经网络的一般工作原理:

  1. 输入层:接收原始数据或特征向量作为网络的输入,每个输入被表示为网络的一个神经元。每个神经元将输入加权并通过激活函数进行转换,产生一个输出信号。

  2. 隐藏层:前馈神经网络可以包含一个或多个隐藏层,每个隐藏层由多个神经元组成。隐藏层的神经元接收来自上一层的输入,并将加权和经过激活函数转换后的信号传递给下一层。

  3. 输出层:最后一个隐藏层的输出被传递到输出层,输出层通常由一个或多个神经元组成。输出层的神经元根据要解决的问题类型(分类或回归)使用适当的激活函数(如Sigmoid、Softmax等)将最终结果输出。

  4. 前向传播:信号从输入层通过隐藏层传递到输出层的过程称为前向传播。在前向传播过程中,每个神经元将前一层的输出乘以相应的权重,并将结果传递给下一层。这样的计算通过网络中的每一层逐层进行,直到产生最终的输出。

  5. 损失函数和训练:前馈神经网络的训练过程通常涉及定义一个损失函数,用于衡量模型预测输出与真实标签之间的差异。常见的损失函数包括均方误差(Mean Squared Error)和交叉熵(Cross-Entropy)。通过使用反向传播算法(Backpropagation)和优化算法(如梯度下降),网络根据损失函数的梯度进行参数调整,以最小化损失函数的值。

        前馈神经网络的优点包括能够处理复杂的非线性关系,适用于各种问题类型,并且能够通过训练来自动学习特征表示。然而,它也存在一些挑战,如容易过拟合、对大规模数据和高维数据的处理较困难等。为了应对这些挑战,一些改进的网络结构和训练技术被提出,如卷积神经网络(Convolutional Neural Networks)和循环神经网络(Recurrent Neural Networks)等。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

977468b5ae9843c6a88005e792817cb1.png

0. 导入必要的工具包

import torch

1. 逻辑回归Logistic

a. 构造函数__init__

 def __init__(self):self.inputs = Noneself.outputs = Noneself.params = None

         初始化了类的成员变量self.inputsself.outputsself.params,它们分别用于保存输入、输出和参数。

b. __call__(self, x)方法

    __call__(self, x)方法使得该类的实例可以像函数一样被调用。它调用了forward(x)方法,将输入的x传递给前向传播方法。

 def __call__(self, x):return self.forward(x)

c. 前向传播forward

  def forward(self, inputs):outputs = 1.0 / (1.0 + torch.exp(-inputs))self.outputs = outputsreturn outputs

    forward(self, inputs)方法执行逻辑回归的前向传播。它接受输入inputs作为参数,并通过逻辑回归的公式计算输出值outputs。最后,将计算得到的输出保存在self.outputs中,并返回输出值。

d. 反向传播backward

    def backward(self, outputs_grads=None):if outputs_grads is None:outputs_grads = torch.ones(self.outputs.shape)outputs_grad_inputs = torch.multiply(self.outputs, (1.0 - self.outputs))return torch.multiply(outputs_grads, outputs_grad_inputs)

    backward(self, outputs_grads=None)方法执行逻辑回归的反向传播。

  • 接受一个可选的参数outputs_grads,用于传递输出的梯度。
  • 如果没有提供outputs_grads,则默认为全1的张量,表示对输出的梯度都为1。
  • 根据逻辑回归的导数公式,可以将输出值与(1-输出值)相乘,然后再乘以传入的梯度值,得到输入的梯度。
  • 返回计算得到的输入梯度。

2. 模型训练

act = Logistic()
x = torch.tensor([3,3,4,2])
y = act(x)z = act.backward()
print(z)
  • 创建一个Logistic的实例act;
  • 传入张量x进行前向传播,得到输出张量y;
  • 调用act.backward()进行反向传播,得到输入x的梯度;
  • 将结果打印输出。
tensor([0.0452, 0.0452, 0.0177, 0.1050])

3. 代码整合

# 导入必要的工具包
import torchclass Logistic():def __init__(self):self.inputs = Noneself.outputs = Noneself.params = Nonedef __call__(self, x):return self.forward(x)def forward(self, inputs):outputs = 1.0 / (1.0 + torch.exp(-inputs))self.outputs = outputsreturn outputsdef backward(self, outputs_grads=None):if outputs_grads is None:outputs_grads = torch.ones(self.outputs.shape)outputs_grad_inputs = torch.multiply(self.outputs, (1.0 - self.outputs))return torch.multiply(outputs_grads, outputs_grad_inputs)act = Logistic()
x = torch.tensor([3,3,4,2])
y = act(x)z = act.backward()
print(z)

注意:

        本实验仅实现了逻辑回归的前向传播和反向传播部分,缺少了模型的参数更新和训练部分。完整的逻辑回归,需要进一步编写训练循环、损失函数和优化器等部分,欲知后事如何,请听下回分解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/139040.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YOLOv5、YOLOv8改进:Decoupled Head解耦头

目录 1.Decoupled Head介绍 2.Yolov5加入Decoupled_Detect 2.1 DecoupledHead加入common.py中: 2.2 Decoupled_Detect加入yolo.py中: 2.3修改yolov5s_decoupled.yaml 1.Decoupled Head介绍 Decoupled Head是一种图像分割任务中常用的网络结构&#…

MySQL进阶 —— 超详细操作演示!!!(中)

MySQL进阶 —— 超详细操作演示!!!(中) 三、SQL 优化3.1 插入数据3.2 主键优化3.3 order by 优化3.4 group by 优化3.5 limit 优化3.6 count 优化3.7 update 优化 四、视图/存储过程/触发器4.1 视图4.2 存储过程4.3 存…

阿里云大数据实战记录10:Hive 兼容模式的坑

文章目录 1、前言2、什么是 Hive 兼容模式?3、为什么要开启 Hive 模式?4、有什么副作用?5、如何开启 Hive 兼容模式?6、该场景下,能不能不开启 Hive 兼容模式?7、为什么不是DATE_FORMAT(datetime, string)&…

【Qt-17】Qt调用matlab生成的dll库

matlab生成dll库 1、matlab示例代码 function BDCube(x,y)[x,y,z] cylinder(x,y);t1 hgtransform;s1 surf(3*x,3*y,4*z,Parent,t1);grid onview(3)shading interp end 2、matlab环境配置 首先检查自己的mcc编译器是否可用,输出以下命令: &#x…

如何在没有第三方.NET库源码的情况,调试第三库代码?

大家好,我是沙漠尽头的狼。 本方首发于Dotnet9,介绍使用dnSpy调试第三方.NET库源码,行文目录: 安装dnSpy编写示例程序调试示例程序调试.NET库原生方法总结 1. 安装dnSpy dnSpy是一款功能强大的.NET程序反编译工具,…

【Java 基础篇】Java线程安全与并发问题详解

多线程编程在Java中是一个常见的需求,它可以提高程序的性能和响应能力。然而,多线程编程也带来了一系列的线程安全与并发问题。在本文中,我们将深入探讨这些问题,以及如何解决它们,适用于Java初学者和基础用户。 什么…

【AI视野·今日NLP 自然语言处理论文速览 第三十六期】Wed, 20 Sep 2023

AI视野今日CS.NLP 自然语言处理论文速览 Wed, 20 Sep 2023 Totally 64 papers 👉上期速览✈更多精彩请移步主页 Daily Computation and Language Papers SlimPajama-DC: Understanding Data Combinations for LLM Training Authors Zhiqiang Shen, Tianhua Tao, Li…

原生js值之数据类型详解

js的数据类型 数据类型分类基本数据类型boolean:布尔类undefined:未定义的值null类型数值转换 NumberparseInt 转换整数 parseFloat转换浮点数 String类型特点如何转换成字符串模板字面量字符串插值模板字面量标签函数 symbol类型特性使用 BigInt类型复杂数据类型Object类属性与…

[杂谈]-八进制数

八进制数 文章目录 八进制数1、概述2、八进制数的表示2.1 八进制数2.2 以八进制计数2.3 二进制数补零 3、八进制到十进制转换4、十进制到八进制转换5、二进制到八进制转换示例6、八进制到二进制和十进制转换示例7、总结 1、概述 八进制编号系统是另一种使用基数为8计数系统&am…

【Stm32】【Lin通信协议】Lin通信点亮灯实验

Lin通信点亮灯实验 通过STM32的串口发送数据,然后通过串口转换模块将数据转换成LIN(Local Interconnect Network)协议,最终控制点亮灯。需要工程和入门资料的可以私信我,看到了马上回。 入门书本推荐: 一…

【C++面向对象侯捷下】2.转换函数 | 3.non-explicit-one-argument ctor

文章目录 operator double() const {} 歧义了 标准库的转换函数

exe文件运行后无输出直接闪退如何找解决办法

一.搜索栏搜事件查看器 二.点开windows日志下的应用程序 三.找到错误处 四.搜索异常代码 点开有错误的详细信息,直接用搜索引擎搜索这个异常代码能大致判断是什么问题,给了一个解决思路,不至于不知道到底哪里出了问题

AUTOSAR词典:CAN驱动Mailbox配置技术要点全解析

AUTOSAR词典:CAN驱动Mailbox配置技术要点全解析 前言 首先,请问大家几个小小问题,你清楚: AUTOSAR框架下的CAN驱动关键词定义吗?是不是有些总是傻傻分不清楚呢?CAN驱动Mailbox配置过程中有哪些关键配置参…

Angular变更检测机制

前段时间遇到这样一个 bug,通过一个 click 事件跳转到一个新页面,新页面迟迟不加载; 经过多次测试发现,将鼠标移入某个 tab ,页面就加载出来了。 举个例子,页面内容无法加载,但是将鼠标移入下图…

[面试] k8s面试题 2

文章目录 核心组件1.什么是 Kubernetes 中的控制器(Controller)?请提供一些常见的控制器类型。2.请解释一下 Kubernetes 中的 Ingress 是什么,以及它的作用。3.如何通过命令行在 Kubernetes 中创建一个 Pod?4.Stateful…

Pdf文件签名检查

如何检查pdf的签名 首先这里有一个已经签名的pdf文件&#xff0c;通过pdf软件可以看到文件的数字签名。 图1为签名后的文件&#xff0c;图2为签名后文件被篡改。 下面就是如何代码检查这里pdf文件的签名 1.引入依赖 <dependency><groupId>org.projectlombok<…

数据结构——单链表

目录 一.前言 二.链表表示和实现&#xff08;单链表&#xff09; 1.1 顺序表的优缺点 1.2 链表的概念及结构 1.3 打印函数 1.4 空间函数 1.5 尾插函数&#xff08;最最最麻烦的&#xff09; 1.5.1 尾插最关键部分&#xff01; 1.6 头插函数 1.7 尾删函数…

云流化:XR扩展现实应用发展的一个新方向!

扩展现实的发展已经改变了我们工作、生活和娱乐的方式&#xff0c;而且这才刚刚开始。扩展现实 (Extended reality, XR) 涵盖了沉浸式技术&#xff0c;包括虚拟现实、增强现实和混合现实。从游戏到虚拟制作再到产品设计&#xff0c;XR 使人们能够以前所未有的方式在计算机生成的…

#循循渐进学51单片机#指针基础与1602液晶的初步认识#not.11

1、把本节课的指针相关内容&#xff0c;反复学习3到5遍&#xff0c;彻底弄懂指针是怎么回事&#xff0c;即使是死记硬背也要记住&#xff0c;等到后边用的时候可以实现顿悟。学会指针&#xff0c;就是突破了C语言的一道壁垒。 2&#xff0c;1602所有的指令功能都应用一遍&#…

vue3——pixi初学,编写一个简单的小游戏,复制粘贴可用学习

pixi官网 小游戏效果 两个文件夹 一个index.html 一个data.js //data.js import { reactive } from "vue"; import { Sprite, utils, Rectangle, Application, Text, Graphics } from "pixi.js";//首先 先创建一个舞台 export const app new Applicat…