线性回归 梯度下降原理与基于Python的底层代码实现

线性回归基础知识可查看该专栏中其他文章。

文章目录

  • 1 梯度下降算法原理
  • 2 一元函数梯度下降示例代码
  • 3 多元函数梯度下降示例代码

1 梯度下降算法原理

梯度下降是一种常用的优化算法,可以用来求解许包括线性回归在内的许多机器学习中的问题。前面讲解了直接使用公式求解 θ \theta θ (最小二乘法的求解推导与基于Python的底层代码实现),但是对于复杂的函数来说,可能较难求出对应的公式,因此需要使用梯度下降。

假设我们要求解的线性回归公式是:

y = β 0 + β 1 x 1 + β 2 x 2 + . . . + β n x n + ϵ y = \beta_0 + \beta_1x_1 + \beta_2x_2 + ... + \beta_nx_n + \epsilon y=β0+β1x1+β2x2+...+βnxn+ϵ

其中 y y y 是因变量, β i \beta_i βi 是回归系数, x i x_i xi 是自变量, ϵ \epsilon ϵ 是误差项。我们的目标是找到一组回归系数 β i \beta_i βi,使得模型能够最小化误差。

使用梯度下降算法求解线性回归可以分为以下步骤:

  1. 随机初始化回归系数 β i \beta_i βi

  2. 计算模型的预测值 y ^ \hat{y} y^

y ^ = β 0 + β 1 x 1 + β 2 x 2 + . . . + β n x n \hat{y} = \beta_0 + \beta_1x_1 + \beta_2x_2 + ... + \beta_nx_n y^=β0+β1x1+β2x2+...+βnxn

  1. 计算误差(或损失函数):

J ( β 0 , β 1 , . . . , β n ) = 1 2 m ∑ i = 1 m ( y i − y i ^ ) 2 J(\beta_0, \beta_1, ..., \beta_n) = \frac{1}{2m}\sum_{i=1}^{m}(y_i - \hat{y_i})^2 J(β0,β1,...,βn)=2m1i=1m(yiyi^)2

其中 m m m 是样本数量, y i y_i yi 是第 i i i 个样本的真实值, y i ^ \hat{y_i} yi^ 是对应的预测值。

  1. 计算误差对于每个回归系数的偏导数:

∂ J ∂ β j = 1 m ∑ i = 1 m ( y i ^ − y i ) x i j \frac{\partial J}{\partial \beta_j} = \frac{1}{m}\sum_{i=1}^{m}(\hat{y_i} - y_i)x_{ij} βjJ=m1i=1m(yi^yi)xij

其中 x i j x_{ij} xij 是第 i i i 个样本的第 j j j 个特征值。

  1. 使用梯度下降更新回归系数:

β j = β j − α ∂ J ∂ β j \beta_j = \beta_j - \alpha\frac{\partial J}{\partial \beta_j} βj=βjαβjJ

其中 α \alpha α 是学习率,用来控制更新的步长。

  1. 重复步骤 2-5多次,直到误差达到某个预定的阈值或者达到预设的迭代次数。

梯度下降算法会不断迭代,直到误差最小化。通过不断更新回归系数,模型逐渐拟合数据,从而得到最终的结果。
在这里插入图片描述

(非常经典的图,已经要盘包浆了)

2 一元函数梯度下降示例代码

  1. 导入此次代码所需的包,设置绘图时正常处理中文字符。
import numpy as np  
import matplotlib as mpl  
import matplotlib.pyplot as plt  mpl.rcParams['font.sans-serif'] = [u'SimHei']  
mpl.rcParams['axes.unicode_minus'] = False
  1. 定义本次要模拟的函数。为了方便起见,这里直接对函数的导数进行了定义。也可根据需要调包求梯度或者自己写一个求偏导的类。
# 一维原始图像  
def f1(x):  return 0.5 * (x - 2) ** 2  
# 导函数  
def h1(x):  return 0.5 * 2 * (x - 2)
  1. 初始化梯度下降中的参数
GD_X = []  
GD_Y = []  
x = 4  
alpha = 0.1  
f_change = 1  
f_current = f1(x)  
GD_X.append(x)  
GD_Y.append(f_current)  
iter_num = 0

此处GD_X与GD_Y两个列表分别用于存储梯度下降的每一步取值,用于后面的画图。x是梯度下降的起点,可设置为随机数。f_change用于存储执行每次循环之后,y的变化值。此处赋值的意义仅在于确保能进入下面的循环而不会报错。iter_num用于记录循环执行的次数。alpha学习率,取值过大容易难以收敛,取值过小容易增加计算量。
4. 梯度下降步骤的循环

while f_change > 1e-10 and iter_num < 1000:iter_num += 1  x = x - alpha * h1(x)  tmp = f1(x)  f_change = np.abs(f_current - tmp)  f_current  = tmp  GD_X.append(x)  GD_Y.append(f_current)

循环结束的标准为:两次循环的y值变化(即f_change)小于1e-10或循环次数大于100。
每次循环,x的变化量都是学习率乘以这一点的梯度。之后计算变化后x对应的y和变化前x对应的外,获得两次y的差值。并将每次运行的结果使用append保存到列表之中。
5. 结果输出

print(u"最终结果为:(%.5f, %.5f)" % (x, f_current))  
print(u"迭代次数:%d" % iter_num)

image.png
大概100次后,我们找到了损失函数最小值所对应的x。
6. 结果绘图

X = np.arange(-2, 6, 0.05)  
Y = np.array(list(map(lambda t: f1(t), X)))  plt.figure(facecolor='w')  
plt.plot(X, Y, 'r-', linewidth=2)  
plt.plot(GD_X, GD_Y, 'bo--', linewidth=2)  
plt.title(f'函数$y=0.5 * (θ - 2)^2$ \n学习率:{alpha:.3f}  最终解:x={x:.3f} y={f_current:.3f}  迭代次数:{iter_num}')  
plt.show()

可以自行尝试不同的起点,不同的学习速率对结果的影响。
在这里插入图片描述
在这里插入图片描述

3 多元函数梯度下降示例代码

当变量数为2时,梯度下降可以使用3维绘图展示。当变量书超过2时,损失函数变为超平面难以展示,因此此处以二元函数为例。

  1. 定义本次要模拟的函数。
# 二元函数定义  
def f2(x, y):  return (x - 2) ** 2 + 2* (y + 1) ** 2  
# 偏导数  
def hx2(x, y):  return 2*(x - 2)  
def hy2(x, y):  return 4*(y + 1)

与一元函数相同,我们对函数的偏导数直接定义,减少非本博客相关的代码。
2. 初始化梯度下降中的参数

GD_X1 = []  
GD_X2 = []  
GD_Y = []  
x1 = 4  
x2 = 4  
alpha = 0.01  
f_change = 1  
f_current = f2(x1, x2)  
GD_X1.append(x1)  
GD_X2.append(x2)  
GD_Y.append(f_current)  
iter_num = 0

这里与一元函数的参数基本相同,只是多了一个用于存储额外维度的listGD_X2。
3. 梯度下降步骤的循环

while f_change > 1e-10 and iter_num < 1000:  iter_num += 1  prex1 = x1  prex2 = x2  x1 = x1 - alpha * hx2(prex1, prex2)  x2 = x2 - alpha * hy2(prex1, prex2)  tmp = f2(x1, x2)  f_change = np.abs(f_current - tmp)  f_current = tmp  GD_X1.append(x1)  GD_X2.append(x2)  GD_Y.append(f_current)  print(u"最终结果为:(%.3f, %.3f, %.3f)" % (x1, x2, f_current))  
print(u"迭代次数:%d" % iter_num)

此处的逻辑与一元函数基本相同。对于每一个x,都使用对应的偏导数乘以学习速率,从而获得新的x值。如果是二元以上的多元函数同理。
运行结果为:
image.png

  1. 绘图
X1 = np.arange(-5, 5, 0.2)  
X2 = np.arange(-5, 5, 0.2)  
X1, X2 = np.meshgrid(X1, X2)  
Y = np.array(list(map(lambda t: f2(t[0], t[1]), zip(X1.flatten(), X2.flatten()))))  
Y.shape = X1.shapefig = plt.figure(facecolor='w')  
ax = Axes3D(fig)  
ax.plot_surface(X1, X2, Y, rstride=1, cstride=1, cmap=plt.cm.jet, alpha=0.8)  
ax.plot(GD_X1, GD_X2, GD_Y, 'ko-')  
ax.set_xlabel('x')  
ax.set_ylabel('y')  
ax.set_zlabel('z')  
plt.show()

对于三维数据,我们使用meshgrid构建了绘图网格,用于绘制函数图像。在绘制完函数图像的基础上,绘制梯度下降每一步的图像。绘制折线图时,ko-代表黑色、圆点、虚线。
(3D图像建议设置为单独显示,方便拖动视角查看)
在这里插入图片描述

实际上,梯度下降的种类也有很多,比如随机梯度下降、批量梯度下降,小批量梯度下降。这些内容将会在下一篇博客中进行讲解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/18849.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

面渣逆袭:Java集合连环三十问

大家好&#xff0c;我是老三。上期发布了一篇&#xff1a;面渣逆袭&#xff1a;HashMap追魂二十三问&#xff0c;反响很好&#xff01; 围观群众纷纷表示&#x1f447; 不写&#xff0c;是不可能不写的&#xff0c;只有卷才能维持了生活这样子。 当然&#xff0c;我写的这一系…

Android-Activity生命周期

文章参考&#xff1a;文章参考1 文章参考&#xff1a;文章参考2 五大状态 StartingRunningStoppedPausedDestroyed 借用一张已经包浆的图 PS&#xff1a;Running和Paused是可视阶段&#xff0c;其余都是不可视 几大函数 onCreate&#xff1a;通过setContentLayout初始化布局…

Java 八股文-集合框架篇

Java 集合框架 一、常见集合 1.说说有哪些常见集合&#xff1f; 集合相关类和接口都在java.util中&#xff0c;主要分为3种&#xff1a;List&#xff08;列表&#xff09;、Map&#xff08;映射&#xff09;、Set(集)。 其中Collection是集合List、Set的父接口&#xff0c…

python爬虫入门篇

接下来的一些时间会分享一些爬虫相关的代码和知识 有人会问爬虫怎么舔女神&#xff1f; 我只能说浅了 看完伟大的Technical Licking Dog 的文章你将会对舔狗的认知得到一个升华&#xff01; 目录 接下来的一些时间会分享一些爬虫相关的代码和知识 正文 爬虫的运行原理&…

程序人生 - 为什么表情包越转发越模糊,还会变绿?

当代人聊天离不开什么&#xff1f; 表情包&#xff01;&#xff01;&#xff01; 没有表情包&#xff0c;怎么表达我的感情&#xff1f;&#xff08;当然&#xff0c;我对你基本没什么感情~只是想秀一下沙雕表情包&#xff01;&#xff09;在过去的日子里&#xff0c;江湖上流传…

⚡【C语言趣味教程】(1) 深入浅出 HelloWorld | 通过 HelloWorld 展开教学 | 头文件详解 | main 函数详解

&#x1f517; 《C语言趣味教程》&#x1f448; 猛戳订阅&#xff01;&#xff01;&#xff01; ​—— 热门专栏《维生素C语言》的重制版 —— &#x1f4ad; 写在前面&#xff1a;这是一套 C 语言趣味教学专栏&#xff0c;目前正在火热连载中&#xff0c;欢迎猛戳订阅&#x…

正确保护Macbook

MacBook该如何正确保护呢&#xff1f;不是各种键盘膜、保护壳通通用上就是最好的&#xff0c;那么该如何正确做呢&#xff1f;下面是macw小编带来的详细指导&#xff0c;快来学习&#xff01; 在接下来的文章中&#xff0c;笔者将展示哪些配件是可取的&#xff0c;哪些配件是坚…

从做产品的角度分析吕布为什么非死不可?

这是一篇小品文&#xff0c;作者是“产品家实战营3期”学员…… 马中赤兔&#xff0c;人中吕布&#xff0c;本意虽褒&#xff0c;但个人觉得将人与牲口类比&#xff0c;其段位貌似也没高到哪里&#xff0c;&#xff1a;&#xff09; 不过说起三国里的武将武力排名&#xff0c;吕…

中国撸串指北:13万家烧烤店的吃货最爱

戳蓝字“CSDN云计算”关注我们哦&#xff01; 数据分析&#xff1a;还是更爱火锅的朱小五 内容撰写&#xff1a;最爱干豆腐卷的王小九 来源|凹凸数读 对美食最大的肯定无疑就是那操着不同口音说出的“好吃&#xff01;”二字。 ——《人生一串》豆瓣短评 以美食慰藉夜归人&…

Github上这几个沙雕项目,够我玩三天!

点击上方“码农突围”&#xff0c;马上关注 这里是码农充电第一站&#xff0c;回复“666”&#xff0c;获取一份专属大礼包 真爱&#xff0c;请设置“星标”或点个“在看” 开源最前线&#xff08;ID&#xff1a;OpenSourceTop&#xff09; 猿妹综合整理 今天&#xff0c;猿妹再…

几个有趣的Github项目,够你玩一阵了...

点击上方“后端技术精选”&#xff0c;选择“置顶公众号” 技术文章第一时间送达&#xff01; 来源&#xff1a;开源最前线 今天&#xff0c;给大家整理一份有意思的沙雕项目&#xff0c;顺带分享了我的试用成果&#xff0c;说实话&#xff0c;这些项目够你玩三天了。 亲戚关系…

包浆网图分分钟变高清,伪影去除、细节恢复更胜前辈AI,下载可玩|腾讯ARC实验室出品...

丰色 发自 凹非寺量子位 报道 | 公众号 QbitAI 下面来欣赏一些高糊图片“整个世界都清晰了”的魔法时刻&#xff1a; 无论是动漫还是真实图像&#xff0c;是不是都清晰还原了&#xff1f; 以上就是由腾讯ARC实验室最新发表的图像超分辨率模型完成的。 与前人工作相比&#xff0…

爬虫入门实践 | 利用python爬取彩票中奖信息

系统环境&#xff1a;mac python版本&#xff1a;3.6.2(anaconda) 库&#xff1a;requests、BeautifulSoup 爬取一些简单的静态网站&#xff0c;一般采取的策略为&#xff1a;选中目标&#xff0c;也就是需要爬取的网站url&#xff1b;观察结构&#xff0c;查看网页结构&…

全网超详细的下载与安装VMware虚拟机以及为什么要安装VMware虚拟机

文章目录 1. 文章引言2. 下载VMware3. 安装VMware 1. 文章引言 我们使用最多的系统是windows系统&#xff0c;因为&#xff0c;国内电脑厂商的操作系统(os)基本是windows系统&#xff0c;比如华为、联想、华硕等电脑。 但线上的服务器大多是Linux系统&#xff0c;而我们经常使…

图灵奖得主LeCun:ChatGPT局限巨大,自回归模型寿命不超5年

作者 | 新智元 编辑 | 新智元 点击下方卡片&#xff0c;关注“自动驾驶之心”公众号 ADAS巨卷干货&#xff0c;即可获取 【导读】图灵奖得主Yann LeCun畅谈AI&#xff1a;未来是开源。 今年上半年&#xff0c;可谓是AI届最波澜壮阔的半年。 在急速发展的各类GPT甚至AGI的雏形背…

LeCun畅谈:ChatGPT局限巨大,自回归模型寿命不超5年

点击下方卡片&#xff0c;关注“CVer”公众号 AI/CV重磅干货&#xff0c;第一时间送达 点击进入—>【计算机视觉】微信技术交流群 转载自&#xff1a;新智元 | 编辑&#xff1a;拉燕 【导读】图灵奖得主Yann LeCun畅谈AI&#xff1a;未来是开源。 今年上半年&#xff0c;可谓…

ChatGPT正在改变一切但仍然有其局限性

人工智能聊天机器人已经被证明非常有能力完成技术任务&#xff0c;例如编写和编码。但它还不能做所有的事情。 自11月下旬发布以来&#xff0c;ChatGPT已经席卷全球。这款聊天机器人的高级人工智能能力允许它完全独立完成任务&#xff0c;如撰写论文、电子邮件和诗歌、编写和调…

从集异壁理解ChatGPT的成功与局限

终其一生&#xff0c;人类都在探寻认知这个世界的方式。 音乐、绘画和人工智能是三个看似无关的领域&#xff0c;但是它们都是人类这次伟大尝试的绚烂明珠。在这三个领域&#xff0c;追根溯源&#xff0c;底层的结构&#xff0c;都简洁且美丽。 图片由Midjourney生成&#xff0…

GPT虚拟直播Demo系列(二)|无人直播间实现虚拟人回复粉丝

摘要 虚拟人和数字人是人工智能技术在现实生活中的具体应用&#xff0c;它们可以为人们的生活和工作带来便利和创新。在直播间场景里&#xff0c;虚拟人和数字人可用于直播主播、智能客服、营销推广等。接入GPT的虚拟人像是加了超强buff&#xff0c;具备更强大的自然语言处理能…

今晚 12:30 RLHF: From Zero to ChatGPT 直播活动

本次演讲&#xff0c;我们将介绍一种称之为从人类反馈中强化学习 (RLHF, Reinforcement Learning from Human Feedback) 的基础知识&#xff0c;以及如何使用 RLHF 驱动实现 ChatGPT 这样的工具。我们将为大家介绍相关联的机器学习模型&#xff0c;涵盖自然语言处理 (NLP) 和强…