【深度学习实验】线性模型(二):使用NumPy实现线性模型:梯度下降法

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入库

1. 初始化参数

2. 线性模型 linear_model

3. 损失函数loss_function

4. 梯度计算函数compute_gradients

5. 梯度下降函数gradient_descent

6. 调用函数


一、实验介绍

        使用NumPy实现线性模型:梯度下降法

二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

         线性模型梯度下降法是一种常用的优化算法,用于求解线性回归模型中的参数。它通过迭代的方式不断更新模型参数,使得模型在训练数据上的损失函数逐渐减小,从而达到优化模型的目的。

        梯度下降法的基本思想是沿着损失函数梯度的反方向更新模型参数。在每次迭代中,根据当前的参数值计算损失函数的梯度,然后乘以一个学习率的因子,得到参数的更新量。学习率决定了参数更新的步长,过大的学习率可能导致错过最优解,而过小的学习率则会导致收敛速度过慢。

具体而言,对于线性回归模型,梯度下降法的步骤如下:

  1. 初始化模型参数,可以随机初始化或者使用一些启发式的方法。

  2. 循环迭代以下步骤,直到满足停止条件(如达到最大迭代次数或损失函数变化小于某个阈值):

    a. 根据当前的参数值计算模型的预测值。

    b. 计算损失函数关于参数的梯度,即对每个参数求偏导数。

    c. 根据梯度和学习率更新参数值。

    d. 计算新的损失函数值,并检查是否满足停止条件。

  3. 返回优化后的模型参数。

       本实验中,gradient_descent函数实现了梯度下降法的具体过程。它通过调用initialize_parameters函数初始化模型参数,然后在每次迭代中计算模型预测值、梯度以及更新参数值。

0. 导入库

import numpy as np

1. 初始化参数

        在梯度下降算法中,需要初始化待优化的参数,即权重 w 和偏置 b。可以使用随机初始化的方式。

def initialize_parameters():w = np.random.randn(5)b = np.random.randn(5)return w, b

2. 线性模型 linear_model

def linear_model(x, w, b):output = np.dot(x, w) + breturn output

3. 损失函数loss_function

         该函数接受目标值y和模型预测值prediction,计算均方误差损失。

def loss_function(y, prediction):loss = (prediction - y) * (prediction - y)return loss

4. 梯度计算函数compute_gradients

        为了使用梯度下降算法,需要计算损失函数关于参数 w 和 b 的梯度。可以使用数值计算的方法来近似计算梯度。

def compute_gradients(x, y, w, b):h = 1e-6  # 微小的数值,用于近似计算梯度grad_w = (loss_function(y, linear_model(x, w + h, b)) - loss_function(y, linear_model(x, w - h, b))) / (2 * h)grad_b = (loss_function(y, linear_model(x, w, b + h)) - loss_function(y, linear_model(x, w, b - h))) / (2 * h)return grad_w, grad_b

5. 梯度下降函数gradient_descent

        根据梯度计算的结果更新参数 w 和 b,从而最小化损失函数。

def gradient_descent(x, y, learning_rate, num_iterations):w, b = initialize_parameters()for i in range(num_iterations):prediction = linear_model(x, w, b)grad_w, grad_b = compute_gradients(x, y, w, b)w -= learning_rate * grad_wb -= learning_rate * grad_bloss = loss_function(y, prediction)print("Iteration", i, "Loss:", loss)return w, b

6. 调用函数

        执行梯度下降优化:调用 gradient_descent 函数并传入数据 x 和 y,设置学习率和迭代次数进行优化。

x = np.random.rand(5)
y = np.array([1, -1, 1, -1, 1]).astype('float')
learning_rate = 0.1
num_iterations = 100
w_optimized, b_optimized = gradient_descent(x, y, learning_rate, num_iterations)

        在上述代码中,每一次迭代都会打印出当前迭代次数和对应的损失值。通过不断更新参数 w 和 b,使得损失函数逐渐减小,达到最小化损失函数的目的。

希望这个详细解析能够帮助你优化代码并使用梯度下降算法最小化损失函数。如果还有其他问题,请随时提问!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/135878.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RocketMQ 发送顺序消息

文章目录 顺序消息应用场景消息组(MessageGroup)顺序性生产的顺序性MQ 存储的顺序性消费的顺序性 rocketmq-client-java 示例(gRPC 协议)1. 创建 FIFO 主题生产者代码消费者代码解决办法解决后执行结果 rocketmq-client 示例&…

【结构型】代理模式(Proxy)

目录 代理模式(Proxy)适用场景代理模式实例代码(Java) 代理模式(Proxy) 为其他对象提供一种代理以控制对这个对象的访问。Proxy 模式适用于在需要比较通用和复杂的对象指针代替简单的指针的时候。 适用场景 远程代理 (Remote Proxy) 为一个对象在不同…

《ADS2011射频电路设计与仿真实例》功率放大器设计的输入输出匹配

徐兴福这本书的6.6 Smith圆图匹配这一节中具体匹配时,直接给出了电容与串联微带的值,没有给出推导过程,我一开始以为是省略了详细推导过程,后来发现好像基本上是可以随便自己设的。以输入匹配(书本6.6.4输入匹配电路的…

Modbus RTU(Remote Terminal Unit)与RS-485协议介绍(主站设备(Master)、从站设备(Slave))

文章目录 Modbus RTU与RS-485协议介绍一、引言二、Modbus RTU 协议介绍2.1 Modbus RTU 协议简介2.2 Modbus RTU 协议帧结构主站设备、从站设备与从站设备地址2.3 Modbus RTU 协议举例 三、RS-485 协议介绍3.1 RS-485 协议简介3.2 RS-485 物理连接方式3.3 RS-485 与 Modbus RTU …

LeetCode-热题100-笔记-day31

105. 从前序与中序遍历序列构造二叉树https://leetcode.cn/problems/construct-binary-tree-from-preorder-and-inorder-traversal/ 给定两个整数数组 preorder 和 inorder ,其中 preorder 是二叉树的先序遍历, inorder 是同一棵树的中序遍历&#xff0c…

全国职业技能大赛云计算--高职组赛题卷④(容器云)

全国职业技能大赛云计算--高职组赛题卷④(容器云) 第二场次题目:容器云平台部署与运维任务1 Docker CE及私有仓库安装任务(5分)任务2 基于容器的web应用系统部署任务(15分)任务3 基于容器的持续…

企业架构LNMP学习笔记61

Nginx作为tomcat的前段反向代理: 在实际业务环境中,用户是直接通过域名访问,基于协议一般是http、https等。默认tomcat运行在8080端口。一般会通过前端服务器反向代理到后端的tomcat的方式,来实现用户可以通过域名访问tomcat的we…

2023全新TwoNav开源网址导航系统源码 | 去授权版

2023全新TwoNav开源网址导航系统源码 已过授权 所有功能可用 测试环境:NginxPHP7.4MySQL5.6 一款开源的书签导航管理程序,界面简洁,安装简单,使用方便,基础功能免费。 TwoNav可帮助你将浏览器书签集中式管理&#…

Java面试八股文宝典:初识数据结构-数组的应用扩展之HashMap

前言 除了基本的数组,还有其他高级的数据结构,用于更复杂的数据存储和检索需求。其中,HashMap 是 Java 集合框架中的一部分,用于存储键值对(key-value pairs)。HashMap 允许我们通过键来快速查找和检索值&…

【数据结构】树的存储结构;树的遍历;哈夫曼树;并查集

欢~迎~光~临~^_^ 目录 1、树的存储结构 1.1双亲表示法 1.2孩子表示法 1.3孩子兄弟表示法 2、树与二叉树的转换 3、树和森林的遍历 3.1树的遍历 3.1.1先根遍历 3.1.2后根遍历 3.2森林的遍历 3.2.1先序遍历森林 3.2.2中序遍历森林 4、树与二叉树的应用 4.1哈夫曼树…

【Linux网络编程】Socket-TCP实例

该代码利用socket套接字建立Tcp连接,包含服务器和客户端。当服务器和客户端启动时需要把端口号或ip地址以命令行参数的形式传入。服务器启动如果接受到客户端发来的请求连接,accept函数会返回一个打开的socket文件描述符,区别于监听连接的lis…

【校招VIP】前端网络之路由选择协议

考点介绍 当两台非直接连接的计算机需要经过几个网络通信时,通常就需要路由器。路由器提供一种方法来开辟通过一个网状联结的路径。在图R-9中标示了几条存在于洛杉矶和纽约办公室的路径。这种网状网络提供了冗余路径以调整通信负载或倒行链路,通常有一条…

灰狼算法优化ICEEMDAN参数,四种适应度函数任意切换,最小包络熵、样本熵、信息熵、排列熵...

今天给大家带来一期由灰狼算法优化ICEEMDAN参数的MATLAB代码。 优化ICEEMDAN参数的思想可以参考该文献: [1]陈爱午,王红卫.基于HBA-ICEEMDAN和HWPE的行星齿轮箱故障诊断[J].机电工程,2023,40(08):1157-1166. 文献原文提到:由于 ICEEMDAN 方法的分解效果取…

【数据结构】队列知识点总结--定义;基本操作;队列的顺序实现;链式存储;双端队列;循环队列

欢迎各位看官^_^ 目录 1.队列的定义 2.队列的基本操作 2.1初始化队列 2.2判断队列是否为空 2.3判断队列是否已满 2.4入队 2.5出队 2.6完整代码 3.队列的顺序实现 4.队列的链式存储 5.双端队列 6.循环队列 1.队列的定义 队列(Queue)是一种先…

Vue3记录

Vue3快速上手 1.Vue3简介 2020年9月18日,Vue.js发布3.0版本,代号:One Piece(海贼王)耗时2年多、2600次提交、30个RFC、600次PR、99位贡献者github上的tags地址:https://github.com/vuejs/vue-next/releas…

软件需求怎么写?

前言:一般来说,软件产品的需求人员的主要输出物就是软件需求,如果这个软件产品就XX系统,人们口中的“系统需求”和“软件需求”就没有什么区别了。在车企行业,推行这ASPICE体系,在这个体系中明确申请了系统…

DMNet复现(一)之数据准备篇:Density map guided object detection in aerial image

一、生成密度图 密度图标签生成 采用以下代码,生成训练集密度图gt: import cv2 import glob import h5py import scipy import pickle import numpy as np from PIL import Image from itertools import islice from tqdm import tqdm from matplotli…

哈希及哈希表的实现

目录 一、哈希的引入 二、概念 三、哈希冲突 四、哈希函数 常见的哈希函数 1、直接定址法 2、除留余数法 五、哈希冲突的解决 1、闭散列 2、开散列 一、哈希的引入 顺序结构以及平衡树中,元素关键码与其存储位置之间没有对应的关系,因此在查找…

浅析三维模型3DTile格式轻量化处理常见问题与处理措施

浅析三维模型3DTile格式轻量化处理常见问题与处理措施 三维模型3DTile格式的轻量化处理是大规模三维地理空间数据可视化的关键环节,但在实际操作过程中,往往会遇到一些问题。下面我们来看一下这些常见的问题以及对应的处理措施。 变形过大:压…

Vue入门--vue的生命周期

一.什么是Vue 二.Vue的简介 官方网址 特点 三. 前后端的分离 重大问题 优势 4.Vue入门 定义一个管理边界 ​编辑 测试结果 vue的优势 ​编辑 测试结果 5.Vue的生命周期 vue的生命周期图 ​编辑建立一个html 测试结果 一.什么是Vue Vue是一种流行的JavaScript前端框…