机器学习:梯度下降法(Python)

LinearRegression_GD.py

import numpy as np
import matplotlib.pyplot as pltclass LinearRegression_GradDesc:"""线性回归,梯度下降法求解模型系数1、数据的预处理:是否训练偏置项fit_intercept(默认True),是否标准化normalized(默认True)2、模型的训练:闭式解公式,fit(self, x_train, y_train)3、模型的预测,predict(self, x_test)4、均方误差,判决系数5、模型预测可视化"""def __init__(self, fit_intercept=True, normalize=True, alpha=0.05, max_epochs=300, batch_size=20):""":param fit_intercept: 是否训练偏置项:param normalize: 是否标准化:param alpha: 学习率:param max_epochs: 最大迭代次数:param batch_size: 批量大小,若为1,则为随机梯度,若为训练集样本量,则为批量梯度,否则为小批量梯度"""self.fit_intercept = fit_intercept  # 线性模型的常数项。也即偏置bias,模型中的theta0self.normalize = normalize  # 是否标准化数据self.alpha = alpha  # 学习率self.max_epochs = max_epochsself.batch_size = batch_sizeself.theta = None  # 训练权重系数if normalize:self.feature_mean, self.feature_std = None, None  # 特征的均值,标准方差self.mse = np.infty  # 训练样本的均方误差self.r2, self.r2_adj = 0.0, 0.0  # 判定系数和修正判定系数self.n_samples, self.n_features = 0, 0  # 样本量和特征数self.train_loss, self.test_loss = [], []  # 存储训练过程中的训练损失和测试损失def init_params(self, n_features):"""初始化参数如果训练偏置项,也包含了bias的初始化:return:"""self.theta = np.random.randn(n_features, 1) * 0.1def fit(self, x_train, y_train, x_test=None, y_test=None):"""模型训练,根据是否标准化与是否拟合偏置项分类讨论:param x_train: 训练样本集:param y_train: 训练目标集:param x_test: 测试样本集:param y_test: 测试目标集:return:"""if self.normalize:self.feature_mean = np.mean(x_train, axis=0)  # 样本均值self.feature_std = np.std(x_train, axis=0) + 1e-8  # 样本方差x_train = (x_train - self.feature_mean) / self.feature_std  # 标准化if x_test is not None:x_test = (x_test - self.feature_mean) / self.feature_std  # 标准化if self.fit_intercept:x_train = np.c_[x_train, np.ones_like(y_train)]  # 添加一列1,即偏置项样本if x_test is not None and y_test is not None:x_test = np.c_[x_test, np.ones_like(y_test)]  # 添加一列1,即偏置项样本self.init_params(x_train.shape[1])  # 初始化参数self._fit_gradient_desc(x_train, y_train, x_test, y_test)  # 梯度下降法训练模型def _fit_gradient_desc(self, x_train, y_train, x_test=None, y_test=None):"""三种梯度下降求解:(1)如果batch_size为1,则为随机梯度下降法(2)如果batch_size为样本量,则为批量梯度下降法(3)如果batch_size小于样本量,则为小批量梯度下降法:return:"""train_sample = np.c_[x_train, y_train]  # 组合训练集和目标集,以便随机打乱样本# np.c_水平方向连接数组,np.r_竖直方向连接数组# 按batch_size更新theta,三种梯度下降法取决于batch_size的大小best_theta, best_mse = None, np.infty  # 最佳训练权重与验证均方误差for i in range(self.max_epochs):self.alpha *= 0.95np.random.shuffle(train_sample)  # 打乱样本顺序,模拟随机化batch_nums = train_sample.shape[0] // self.batch_size  # 批次for idx in range(batch_nums):# 取小批量样本,可以是随机梯度(1),批量梯度(n)或者是小批量梯度(<n)batch_xy = train_sample[self.batch_size * idx: self.batch_size * (idx + 1)]# 分取训练样本和目标样本,并保持维度batch_x, batch_y = batch_xy[:, :-1], batch_xy[:, -1:]# 计算权重更新增量delta = batch_x.T.dot(batch_x.dot(self.theta) - batch_y) / self.batch_sizeself.theta = self.theta - self.alpha * deltatrain_mse = ((x_train.dot(self.theta) - y_train.reshape(-1, 1)) ** 2).mean()self.train_loss.append(train_mse)if x_test is not None and y_test is not None:test_mse = ((x_test.dot(self.theta) - y_test.reshape(-1, 1)) ** 2).mean()self.test_loss.append(test_mse)def get_params(self):"""返回线性模型训练的系数:return:"""if self.fit_intercept:  # 存在偏置项weight, bias = self.theta[:-1], self.theta[-1]else:weight, bias = self.theta, np.array([0])if self.normalize:  # 标准化后的系数weight = weight / self.feature_std.reshape(-1, 1)  # 还原模型系数bias = bias - weight.T.dot(self.feature_mean)return weight.reshape(-1), biasdef predict(self, x_test):"""测试数据预测:param x_test: 待预测样本集,不包括偏置项:return:"""try:self.n_samples, self.n_features = x_test.shape[0], x_test.shape[1]except IndexError:self.n_samples, self.n_features = x_test.shape[0], 1  # 测试样本数和特征数if self.normalize:x_test = (x_test - self.feature_mean) / self.feature_std  # 测试数据标准化if self.fit_intercept:# 存在偏置项,加一列1x_test = np.c_[x_test, np.ones(shape=x_test.shape[0])]y_pred = x_test.dot(self.theta).reshape(-1, 1)return y_preddef cal_mse_r2(self, y_test, y_pred):"""计算均方误差,计算拟合优度的判定系数R方和修正判定系数:param y_pred: 模型预测目标真值:param y_test: 测试目标真值:return:"""self.mse = ((y_test.reshape(-1, 1) - y_pred.reshape(-1, 1)) ** 2).mean()  # 均方误差# 计算测试样本的判定系数和修正判定系数self.r2 = 1 - ((y_test.reshape(-1, 1) - y_pred.reshape(-1, 1)) ** 2).sum() / \((y_test.reshape(-1, 1) - y_test.mean()) ** 2).sum()self.r2_adj = 1 - (1 - self.r2) * (self.n_samples - 1) / \(self.n_samples - self.n_features - 1)return self.mse, self.r2, self.r2_adjdef plt_predict(self, y_test, y_pred, is_show=True, is_sort=True):"""绘制预测值与真实值对比图:return:"""if self.mse is np.infty:self.cal_mse_r2(y_pred, y_test)if is_show:plt.figure(figsize=(8, 6))if is_sort:idx = np.argsort(y_test)  # 升序排列,获得排序后的索引plt.plot(y_test[idx], "k--", lw=1.5, label="Test True Val")plt.plot(y_pred[idx], "r:", lw=1.8, label="Predictive Val")else:plt.plot(y_test, "ko-", lw=1.5, label="Test True Val")plt.plot(y_pred, "r*-", lw=1.8, label="Predictive Val")plt.xlabel("Test sample observation serial number", fontdict={"fontsize": 12})plt.ylabel("Predicted sample value", fontdict={"fontsize": 12})plt.title("The predictive values of test samples \n MSE = %.5e, R2 = %.5f, R2_adj = %.5f"% (self.mse, self.r2, self.r2_adj), fontdict={"fontsize": 14})plt.legend(frameon=False)plt.grid(ls=":")if is_show:plt.show()def plt_loss_curve(self, is_show=True):"""可视化均方损失下降曲线:param is_show: 是否可视化:return:"""if is_show:plt.figure(figsize=(8, 6))plt.plot(self.train_loss, "k-", lw=1, label="Train Loss")if self.test_loss:plt.plot(self.test_loss, "r--", lw=1.2, label="Test Loss")plt.xlabel("Epochs", fontdict={"fontsize": 12})plt.ylabel("Loss values", fontdict={"fontsize": 12})plt.title("Gradient Descent Method and Test Loss MSE = %.5f"% (self.test_loss[-1]), fontdict={"fontsize": 14})plt.legend(frameon=False)plt.grid(ls=":")# plt.axis([0, 300, 20, 30])if is_show:plt.show()

test_linear_regression_gd.py

import numpy as np
from LinearRegression_GD import LinearRegression_GradDesc
from sklearn.model_selection import train_test_splitnp.random.seed(42)
X = np.random.rand(1000, 6)  # 随机样本值,6个特征
coeff = np.array([4.2, -2.5, 7.8, 3.7, -2.9, 1.87])  # 模型参数
y = coeff.dot(X.T) + 0.5 * np.random.randn(1000)  # 目标函数值X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0, shuffle=True)lr_gd = LinearRegression_GradDesc(alpha=0.1, batch_size=1)
lr_gd.fit(X_train, y_train, X_test, y_test)
theta = lr_gd.get_params()
print(theta)
y_test_pred = lr_gd.predict(X_test)
lr_gd.plt_predict(y_test, y_test_pred)
lr_gd.plt_loss_curve()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/247342.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【面试】测试开发面试题

帝王之气&#xff0c;定是你和万里江山&#xff0c;我都护得周全 文章目录 前言1. 网络原理get与post的区别TCP/IP各层是如何传输数据的IP头部包含哪些内容TCP头部为什么有浮动网络层协议1. 路由协议2. 路由信息3. OSPF与RIP的区别Cookie与Session&#xff0c;Token的区别http与…

解决Linux部署报错No main manifest attribute, in XXX.jar

这是我近期遇到的一个问题&#xff0c;报错原因就是没找到主类&#xff0c;首先你在你本地运行&#xff0c;本地运行ok的话&#xff0c;解压生成的jar包&#xff0c;里面有个META-INF文件&#xff0c;打开MANIFEST.MF文件&#xff0c;该文件是一个清单文件。该文件包含有关JAR文…

11. 双目视觉之立体视觉基础

目录 1. 深度恢复1.1 单目相机缺少深度信息1.2 如何恢复场景深度&#xff1f;1.3 深度恢复的思路 2. 对极几何约束2.1 直观感受2.2 数学上的描述 1. 深度恢复 1.1 单目相机缺少深度信息 之前学习过相机模型&#xff0c;最经典的就是小孔成像模型。我们知道相机通过小孔成像模…

汽车网络安全dos, someip

汽车Cyber Security入门之DoS 攻防 - 知乎 3、SOME/IP-TP 近年来火热地谈论下一代EE架构和SOA的时候&#xff0c;总离不开SOME/IP这个进程间通讯协议。在许多应用场景中&#xff0c;需要通过UDP传输大型的SOME/IP有效载荷。鉴于在以太网上传输数据包的大小限制&#xff0c;SO…

Linux文件管理(下)

上上篇介绍了Linux文件管理的上部分内容&#xff0c;这次继续将 Linux文件管理的剩余部分说完。内容如下。 一、查看文件内容 1、cat 命令 1.1 输出文件内容 基本语法&#xff1a; cat 文件名称主要功能&#xff1a;正序输出文件的内容。 eg&#xff1a;输出 readme.txt文…

剧本杀小程序的诞生:重塑线下娱乐的数字化未来

随着科技的不断发展&#xff0c;人们对于娱乐方式的需求也在不断升级。近年来&#xff0c;剧本杀作为一种新型的线下社交娱乐方式&#xff0c;以其独特的魅力和深度的人际互动性&#xff0c;受到了广大年轻人的喜爱。然而&#xff0c;传统的剧本杀模式存在一些问题&#xff0c;…

中间件安全

中间件安全 vulhub漏洞复现&#xff1a;https://vulhub.org/操作教程&#xff1a;https://www.freebuf.com/sectool/226207.html 一、Apache Apache(音译为阿帕奇)是世界使用排名第一的Web服务器软件。它可以运行在几乎所有广泛使用的计算机平台上&#xff0c;由于其跨平台和…

录屏软件哪个好?为您提供最佳选择(最新)

随着科技的进步&#xff0c;录屏软件已成为我们日常生活和工作中不可或缺的工具。无论是为了制作教程、会议记录还是游戏录像&#xff0c;一款优秀的录屏软件都是必不可少的。可是录屏软件哪个好呢&#xff1f;在本文中&#xff0c;我们将介绍两款常用的录屏软件&#xff0c;并…

uniapp 实现路由拦截,权限或者登录控制

背景&#xff1a; 项目需要判断token&#xff0c;即是否登录&#xff0c;登录之后权限 参考uni-app官方&#xff1a; 为了兼容其他端的跳转权限控制&#xff0c;uni-app并没有用vue router路由&#xff0c;而是内部实现一个类似此功能的钩子&#xff1a;拦截器&#xff0c;由…

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之CheckboxGroup组件

鸿蒙&#xff08;HarmonyOS&#xff09;项目方舟框架&#xff08;ArkUI&#xff09;之CheckboxGroup组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、CheckboxGroup组件 提供多选框组件&#xff0c;通常用于某选项的打开或关…

Apollo Cyber RT:引领实时操作系统在自动驾驶领域的创新

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏:《linux深造日志》《粉丝福利》 ⛺️生活的理想&#xff0c;就是为了理想的生活! ⛳️ 推荐 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下…

【服务器】宝塔面板的使用手册

目录 &#x1f337;概述 &#x1f33c;1. 绑定域名 &#x1f33c;2. 添加端口 &#x1f33c;3. 安装docker配置docker​​​​​​​ &#x1f33c;4. 软件商店 &#x1f33c;5. 首页 &#x1f337;概述 宝塔面板的安装教程&#xff1a;【服务器】安装宝塔面板 &#x1f…

绘制太极图 - 使用 PyQt

大家好&#xff01;今天我们将一起来探讨一下如何使用PyQt&#xff0c;这是一个强大的Python库&#xff0c;来绘制一个传统的太极图。这个图案代表着古老的阴阳哲学&#xff0c;而我们的代码将以大白话的方式向你揭示它的奥秘。 PyQt&#xff1a;是什么鬼&#xff1f; 首先&a…

Modelarts零代码体验,一键实现工地钢筋盘点,建筑提效新思维

前言 最近家附近的好几块地&#xff0c;同时在进行房产开发建设&#xff0c;早晚都能看到建筑师傅们在忙碌。 某天&#xff0c;夜跑中&#xff0c;发现前方的建筑工地&#xff0c;师傅们忙活的热火朝天&#xff0c;塔吊也在吊运钢筋中。 准备绕路的时候&#xff0c;旁边负责…

解锁创意无限:Adobe Photoshop 2023(PS2023)引领设计革命

Adobe Photoshop 2023 (PS2023)&#xff0c;作为图像处理软件的翘楚&#xff0c;以其卓越的性能和无限的可能性&#xff0c;继续引领着数字创意设计的潮流。对于设计师、摄影师、艺术家以及那些对视觉效果有高要求的人们来说&#xff0c;PS2023无疑是他们的必备工具。 在PS202…

web前端项目-实现录音功能【附源码】

录音功能 运行效果&#xff1a;本项目可实现录音软件的录音、存储、播放等功能 HTML源码&#xff1a; &#xff08;1&#xff09;index.html&#xff1a; <!DOCTYPE html> <html><head><meta http-equiv"Content-Type" content"text/h…

算法基础课-数据结构

单链表 题目链接&#xff1a;826. 单链表 - AcWing题库 思路&#xff1a;AcWing 826. 单链表---图解 - AcWing 需要注意的点在于理解ne[idx] head&#xff0c;idx表示当前的点&#xff0c;意思是将当前的点链到头结点的后面&#xff0c;再将头结点链在当前idx的前面。 #inc…

Qt|大小端数据转换

后面打算写Qt关于网络编程的博客&#xff0c;网络编程就绕不开字节流数据传输&#xff0c;字节流数据的传输一般是根据协议来定义对应的报文该如何组包&#xff0c;那这就必然牵扯到了大端字节序和小端字节序的问题了。不清楚的大小端的可以看一下相关资料&#xff1a;大小端模…

看图说话:Git图谱解读

很多新加入公司的同学在使用Git各类客户端管理代码的过程中对于Git图谱解读不太理解&#xff0c;我们常用的Git客户端是SourceTree&#xff0c;配合P4Merge进行冲突解决基本可以满足日常工作大部分需要。不同的Git客户端工具对图谱展示会有些许差异&#xff0c;以下是SourceTre…

[C#]winform部署yolov7+CRNN实现车牌颜色识别车牌号检测识别

【官方框架地址】 https://github.com/WongKinYiu/yolov7.git 【框架介绍】 Yolov7是一种目标检测算法&#xff0c;全称You Only Look Once version 7。它是继Yolov3和Yolov4之后的又一重要成果&#xff0c;是目标检测领域的一个重要里程碑。 Yolov7在算法结构上继承了其前…