文章目录
- 前言
- 1.基础概念
- 2.代价函数
- 3.单变量线性回归
- 3.1加载数据
- 3.2初始化超参数
- 3.3梯度下降算法
- 3.3.1初次梯度下降
- 3.3.2 多次梯度下降
- 3.3.3结果可视化
前言
随着互联网数据不断累积,硬件不断升级迭代,在这个信息爆炸的时代,机器学习已被应用在各行各业中,可谓无处不在。在入门机器学习时,所学习的第一个模型大概率是线性回归模型,本篇博客从底层原理到代码实现一步一步带你了解如何是线性回归。
1.基础概念
线性回归
旨在建立一个线性模型,通过若干带有标注的样本数据构造出一个预测模型f(x),用于描述自变量(特征)与因变量(目标)之间的关系。
最典型的线性回归模型,同时也是举例最多的例子便是预测房价,网上也有很多相关的数据集。其数学模型可以表示为:
拆开形式为:
当然也可以将参数b
融入w向量
中,只需在所有样本x
后面加上新的属性值为1
即可
2.代价函数
对于线性回归问题,一般使用平方误差作为代价函数,所以代价函数可以表示为:
所以训练的目标为让代价函数最小,此时我们可以直接求偏导令其等于零,来确定参数w,b
,也可以使用梯度下降算法
,此处本篇博客使用梯度下降算法,在运用梯度下降算法之前我们需要确定梯度,推导过程如图所示:
后续算法设计时需要用到。
这里对 梯度下降
简要解释一下 :对一个函数,根据某点的梯度(斜率),以一定的步长即学习率来逼近极值点。
3.单变量线性回归
现在大多数都会用sklearn
来实现线性回归模型,而忽略了底层如何实现的,所以本篇博客介绍如何一步一步实现梯度下降,而不使用现有以封装好的库
3.1加载数据
此处我们使用pandas库提供的函数来读取csv格式
的文件
import pandas as pd
df=pd.read_csv('./data/regress_data1.csv')
x=df.iloc[:,:1].values
y=df.iloc[:,1:2].values
运行结果:
3.2初始化超参数
w=0
b=0
predict=w*x+b
此时该模型预测结果为predict
,显然不是一个好的模型,因此需要梯度下降来调整超参数w,b
,此时我们计算一下当前模型的代价:
import numpy as np
m=len(x)
loss=np.sum(np.power(y-predict,2))/(2*m)
3.3梯度下降算法
3.3.1初次梯度下降
在计算梯度下降时,你可能出现这个问题,反正我出现了,但慢慢调过后就没了,原因是它把它当作矩阵处理的所以出现这个问题,所以可以之间将其转化为向量就会避免这个问题出现了
因此正确代码应该为:
w=0
b=0
learning_rate=0.0
temp_w=w-np.dot((predict-y)[:,0],x[:,0])/m*learning_rate
temp_b=b-np.sum(predict-y)/m*learning_rate
w=temp_w
b=temp_b
w,b
此时的代价函数为:
predict=w*x+b
np.sum(np.power(y-predict,2))/(2*m)
此时发现该模型的代价比之前模型的代价更低
3.3.2 多次梯度下降
此处我们迭代了1000次,并输出了损失
learning_rate=0.01
w=0
b=0
for epoch in range(1000):temp_w=w-np.dot((predict-y)[:,0],x[:,0])/m*learning_ratetemp_b=b-np.sum(predict-y)/m*learning_ratew=temp_wb=temp_bpredict=w*x+bloss=np.sum(np.power(y-predict,2))/(2*m)if((epoch+1)%100==0):print("第{}轮梯度下降后的损失为:{}".format(epoch+1,loss))
输出结果:
3.3.3结果可视化
import matplotlib.pyplot as plt
learning_rate=0.01
w=0
b=0
plt.scatter(x,y)for epoch in range(1000):temp_w=w-np.dot((predict-y)[:,0],x[:,0])/m*learning_ratetemp_b=b-np.sum(predict-y)/m*learning_ratew=temp_wb=temp_bpredict=w*x+bloss=np.sum(np.power(y-predict,2))/(2*m)if((epoch)%100==0):print("第{}轮梯度下降后的损失为:{}".format(epoch,loss))plt.plot(x, predict)
plt.show()