机器学习与数据挖掘_使用梯度下降法训练线性回归模型

目录

实验内容

实验步骤

1. 导入必要的库

2. 加载数据并绘制散点图

3. 设置模型的超参数

4. 实现梯度下降算法

5. 打印训练后的参数和损失值

6. 绘制损失函数随迭代次数的变化图

7. 绘制线性回归拟合曲线

8. 基于训练好的模型进行新样本预测

实验代码

实验结果

实验总结


实验内容

(1)编写代码实现基于梯度下降的单变量线性回归算法,包括梯度的计算与验证;

(2)绘制数据散点图,以及得到的直线;

(3)绘制梯度下降过程中损失的变化图;

(4)基于训练得到的参数,输入新的样本数据,输出预测值。


实验步骤

1. 导入必要的库

使用 `numpy` 进行科学计算,并使用 `matplotlib` 来生成图形。为了保证图形中的中文正常显示,设置 `matplotlib` 的字体为黑体,并解决负号显示问题。

2. 加载数据并绘制散点图

使用 `numpy` 的 `genfromtxt` 函数从文件中加载数据,数据以逗号作为分隔符。分别提取第一列数据为 `x` 值,第二列数据为 `y` 值,展示数据点的分布情况。使用 `scatter` 函数绘制散点图,并使用 `show` 函数显示图形。

3. 设置模型的超参数

初始化线性回归模型的参数:学习率 `alpha` 设置为 `0.0001`。权重 `w` 和偏置 `b` 初始化为 `0`。设置梯度下降的迭代次数为 `1000`。获取数据样本数量 `m`。

4. 实现梯度下降算法

定义一个列表 `MSE` 用来存储每次迭代的均方误差。在每次迭代中,分别计算损失函数和模型参数的梯度:对每一个样本点,计算当前的预测值和真实值的误差,进而计算平方误差并累积。计算梯度,分别对权重 `w` 和偏置 `b` 进行更新。更新后的参数 `w` 和 `b` 基于学习率和当前梯度值来进行调整。

5. 打印训练后的参数和损失值

在训练结束后,打印出模型的最终参数 `w` 和 `b`。使用最后一次迭代的均方误差来表示最终的损失函数值。

6. 绘制损失函数随迭代次数的变化图

使用 `plot` 函数绘制损失函数随迭代次数变化的曲线,`x` 轴为迭代次数,`y` 轴为损失值。图形展示了梯度下降过程中损失函数值的变化趋势,验证模型的收敛情况。

7. 绘制线性回归拟合曲线

再次绘制原始数据的散点图,并基于训练得到的参数计算每个数据点的预测值。使用 `plot` 函数绘制线性回归拟合的曲线,并用红色标出拟合的直线。

8. 基于训练好的模型进行新样本预测

输入新的样本数据 `new_sample`,基于训练得到的参数 `w` 和 `b` 计算新的 `y` 值。打印出新样本数据及其对应的预测值。


实验代码

# 导入必要的库
import numpy as np  # 导入科学计算库
import matplotlib.pyplot as plt  # 导入绘图库
from matplotlib import rcParams  # 导入设置绘图样式的参数# 设置字体,防止中文乱码
rcParams['font.sans-serif'] = ['SimHei']  # 使用黑体
rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 1. 加载数据并画出散点图
points = np.genfromtxt('data1.txt', delimiter=',')  # 从文件中加载数据,数据以逗号分隔
x = points[:, 0]  # 获取第一列数据作为x
y = points[:, 1]  # 获取第二列数据作为y
plt.scatter(x, y)  # 绘制散点图,展示数据的分布情况
plt.show()# 2. 设置模型的超参数
alpha = 0.0001  # 学习率
w = 0  # 初始化权重w
b = 0  # 初始化偏置b
num_iter = 1000  # 梯度下降的迭代次数
m = len(points)  # 样本数量# 3. 梯度下降算法
MSE = []  # 用于保存每次迭代的均方误差
for iteration in range(num_iter):# 初始化梯度的和sum_grad_w = 0  # 用于累加w的梯度sum_grad_b = 0  # 用于累加b的梯度total_cost = 0  # 每次迭代的总损失初始化为0# 遍历所有数据点,计算偏导数并更新梯度for i in range(m):x_i = points[i, 0]  # 当前数据点的x值y_i = points[i, 1]  # 当前数据点的y值# 计算当前点的预测值pred_y_i = w * x_i + b# 计算损失函数(平方误差)total_cost += (y_i - pred_y_i) ** 2# 计算梯度sum_grad_w += (pred_y_i - y_i) * x_i  # 对w的偏导数sum_grad_b += (pred_y_i - y_i)  # 对b的偏导数# 计算当前迭代的均方误差total_cost /= mMSE.append(total_cost)  # 保存每次迭代的损失值# 计算偏导数的平均值grad_w = 2 / m * sum_grad_wgrad_b = 2 / m * sum_grad_b# 更新w和b,基于学习率和梯度w -= alpha * grad_wb -= alpha * grad_b# 4. 打印训练后的参数和损失值
print("参数w = ", w)
print("参数b = ", b)
# 使用 MSE[-1] 来表示最后一次迭代的损失函数值
print("最后的损失函数 = ", MSE[-1])# 5. 绘制损失函数随迭代次数的变化图
plt.plot(MSE)
plt.xlabel('迭代次数')
plt.ylabel('损失值')
plt.title('梯度下降过程中的损失函数变化')
plt.show()# 6. 画出拟合曲线
plt.scatter(x, y)  # 原始数据的散点图
pred_y = w * x + b  # 基于最终的w和b计算所有数据点的预测值
plt.plot(x, pred_y, color='red')  # 绘制线性回归拟合的直线,颜色为红色
plt.title('线性回归拟合曲线')
plt.show()# 7. 基于训练得到的参数进行新样本预测
new_sample = np.array([5, 10, 15])  # 新的输入数据
predicted_y = w * new_sample + b  # 计算新样本的预测值
print("输入的新样本数据: ", new_sample)
print("预测的y值: ", predicted_y)

实验结果

1. 数据散点图及其线性回归拟合曲线

数据散点图及其线性回归拟合曲线

2. 梯度下降过程中损失函数变化图

梯度下降过程中损失函数变化图

3. 相关参数展示及新样本数据和其预测值

相关参数展示及新样本数据和其预测值


实验总结

本次实验通过使用梯度下降法训练线性回归模型,实现了单变量线性回归的训练与预测。实验中,我们成功编写了基于梯度下降算法的代码,并通过图形展示了数据的分布情况及模型的拟合效果。

在实验过程中,模型的权重参数和偏置参数通过多次迭代逐步更新,梯度下降法有效地减少了损失函数值。最终,模型收敛到了一个较好的参数组合,使得拟合曲线能够较好地反映数据的趋势。此外,通过绘制损失函数的变化图,我们直观地看到了随着迭代次数的增加,损失值不断下降的过程,验证了梯度下降算法的收敛性。

实验结果表明,使用梯度下降法能够有效训练线性回归模型,并且在小数据集上可以获得较为理想的拟合效果。同时,通过该实验,进一步加深了对线性回归和梯度下降算法的理解和掌握。

总体而言,实验达到了预期的目标,完成了线性回归模型的训练、损失函数的可视化及新样本的预测任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/466451.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习与AI|如何利用数据科学优化库存周转率?

对于所有零售商来说,良好的库存管理都是非常重要的。众所周知,商品如果不放在货架上就无法出售,而如果库存过多则意味着严重的财务负担。 但是做好库存管理绝非易事,它依赖于对未来需求的准确预测和确保始终有合适库存的敏捷供应链…

Proteus中数码管动态扫描显示不全(已解决)

文章目录 前言解决方法后记 前言 我是直接把以前写的 51 数码管程序复制过来的,当时看的郭天祥的视频,先送段选,消隐后送位选,最后来个 1ms 的延时。 代码在 Proteus 中数码管静态是可以的,动态显示出了问题——显示…

如何快速搭建一个spring boot项目

一、准备工作 1.1 安装JDK:确保计算机上已安装Java Development Kit (JDK) 8或更高版本、并配置了环境变量 1.2 安装Maven:下载并安装Maven构建工具,这是Spring Boot官方推荐的构建工具。 1.3 安装代码编辑器:这里推荐使用Inte…

基于ViT的无监督工业异常检测模型汇总

基于ViT的无监督工业异常检测模型汇总 论文1:RealNet: A Feature Selection Network with Realistic Synthetic Anomaly for Anomaly Detection(2024)1.1 主要思想1.2 系统框架 论文2:Inpainting Transformer for Anomaly Detecti…

数据结构C语言描述2(图文结合)--有头单链表,无头单链表(两种方法),链表反转、有序链表构建、排序等操作,考研可看

前言 这个专栏将会用纯C实现常用的数据结构和简单的算法;用C基础即可跟着学习,代码均可运行;准备考研的也可跟着写,个人感觉,如果时间充裕,手写一遍比看书、刷题管用很多,这也是本人采用纯C语言…

Python | Leetcode Python题解之第542题01矩阵

题目: 题解: class Solution:def updateMatrix(self, matrix: List[List[int]]) -> List[List[int]]:m, n len(matrix), len(matrix[0])# 初始化动态规划的数组,所有的距离值都设置为一个很大的数dist [[10**9] * n for _ in range(m)]…

ENSP作业——园区网

题目 根据上图,可得需求为: 1.配置交换机上的VLAN及IP地址。 2.设置SW1为VLAN 2/3的主根桥,设置SW2为VLAN 20/30的主根桥,且两台交换机互为主备。 3.可以使用super vlan。 4.上层通过静态路由协议完成数据通信过程。 5.AR1作为企…

【1个月速成Java】基于Android平台开发个人记账app学习日记——第7天,申请阿里云SMS短信服务SDK

系列专栏链接如下,方便跟进: https://blog.csdn.net/weixin_62588253/category_12821860.html?fromshareblogcolumn&sharetypeblogcolumn&sharerId12821860&sharereferPC&sharesourceweixin_62588253&sharefromfrom_link 同时篇幅…

让Apache正确处理不同编码的文件避免中文乱码

安装了apache2.4.39以后&#xff0c;默认编码是UTF-8&#xff0c;不管你文件是什么编码&#xff0c;统统按这个来解析&#xff0c;因此 GB2312编码文件内的中文将显示为乱码。 <!doctype html> <html> <head><meta http-equiv"Content-Type" c…

『Django』初识前后端分离

点赞 + 关注 + 收藏 = 学会了 本文简介 在前面的「Django」系列的文章 中使用的是“前后端不分离”的方式去学习 Django,但现在企业比较流行的开发方式是前后端分离。 简单来说,前后端分离就是把前端和后端的工作分配给2个人做,前端主要负责用户界面的开发,后端主要负责…

探索开放资源上指令微调语言模型的现状

人工智能咨询培训老师叶梓 转载标明出处 开放模型在经过适当的指令调整后&#xff0c;性能可以与最先进的专有模型相媲美。但目前缺乏全面的评估&#xff0c;使得跨模型比较变得困难。来自Allen Institute for AI和华盛顿大学的研究人员们进行了一项全面的研究&#xff0c;探索…

搜维尔科技:【应用】Xsens在荷兰车辆管理局人体工程学评估中的应用

荷兰车辆管理局&#xff08;RDW&#xff09;通过数据驱动的人体工程学评估&#xff0c;将职业健康和安全放在首位。 关键信息 01 改进人体工程学评估&#xff1a;RDW使用Xsens动作捕捉和Scalefit Industrial Athlete进行精确、实时的人体工程学评估&#xff0c;识别并降低与…

文件系统和日志管理 附实验:远程访问第一台虚拟机日志

文件系统和日志管理 文件系统&#xff1a;文件系统提供了一个接口&#xff0c;用户用来访问硬件设备&#xff08;硬盘&#xff09;。 硬件设备上对文件的管理 文件存储在硬盘上&#xff0c;硬盘最小的存储单位是512字节&#xff0c;扇区。 文件在硬盘上的最小存储单位&…

大众汽车合肥社招入职笔试测评SHL题库:综合能力、性格问卷、英语口语真题考什么?

大众汽车合肥社招入职笔试测评包括综合能力测试、性格问卷和英语口语测试。以下是各部分的具体内容&#xff1a; 1. **综合能力测试**&#xff1a; - 这部分测试需要46分钟完成&#xff0c;建议准备计算器和纸笔。 - 测试内容涉及问题解决能力、数值计算能力和逻辑推理能力。 -…

Python进阶之IO操作

文章目录 一、文件的读取二、文件内容的写入三、之操作文件夹四、StringIO与BytesIO 一、文件的读取 在python里面&#xff0c;可以使用open函数来打开文件&#xff0c;具体语法如下&#xff1a; open(filename, mode)filename&#xff1a;文件名&#xff0c;一般包括该文件所…

UE5.4 PCG 自定义PCG蓝图节点

ExecuteWithContext&#xff1a; PointLoopBody&#xff1a; 效果&#xff1a;点密度值与缩放成正比

Transformer和BERT的区别

Transformer和BERT的区别比较表&#xff1a; 两者的位置编码&#xff1a; 为什么要对位置进行编码&#xff1f; Attention提取特征的时候&#xff0c;可以获取全局每个词对之间的关系&#xff0c;但是并没有显式保留时序信息&#xff0c;或者说位置信息。就算打乱序列中token…

Apache Commons Collections 反序列化漏洞

文章目录 前言一、漏洞爆出二、复现环境java集合框架问题JVM反射 三、Apache Commons Collections漏洞原理≤3.2.1CC关键类调用链路POC构造思路POC 前言 Apache Commons Collections是一个扩展了Java标准库里的Collection结构的第三方基础库&#xff0c;它提供了很多强大的数据…

正则表达式1 re.match惰性匹配详解案例

点个关注 re.match() re.match() 函数尝试从字符串的开头开始匹配一个模式&#xff0c;如果匹配成功&#xff0c;返回一个匹配成功的对象&#xff0c;否则返回None。大小写区分&#xff0c;内容匹配不到后面的,只能匹配一个&#xff0c;不能有空格&#xff08;开头匹配&#…

gov企业征信系统瑞数6vmp算法还原

URL aHR0cHM6Ly9zZC5nc3h0Lmdvdi5jbi8今天再来逆向下国家企业征信系统&#xff0c;这个站很卡&#xff0c;兄弟们你们轻点爬&#xff0c;我刷以下页面就转好久的圈圈&#xff0c;这个站两层防护&#xff0c;一层加速乐&#xff0c;一层瑞数&#xff0c;貌似还有极验验证码防护…