机器学习-梯度下降实验一

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection 
import train_test_split, KFold
from sklearn.metrics import mean_squared_error, r2_score
from mpl_toolkits.mplot3d import Axes3D  # 用于3D图plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号# 1. 读取数据并进行处理
data = pd.read_csv('data.csv')# 提取输入 (X) 和输出 (Y)
X = data['X'].values.reshape(-1, 1)
Y = data['Y'].values# 划分训练集和测试集,70% 训练,30% 测试
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=42)# 为输入 X 添加一列 1 以考虑截距项 (bias)
X_train_b = np.c_[np.ones((X_train.shape[0], 1)), X_train]  # 添加截距项
X_test_b = np.c_[np.ones((X_test.shape[0], 1)), X_test]# 初始化参数 (theta)
theta = np.zeros(2)# 定义超参数
learning_rate = 0.01
n_iterations = 1000# 计算代价函数 (均方误差)def compute_cost(X, Y, theta):m = len(Y)predictions = X.dot(theta)cost = (1 / (2 * m)) * np.sum((predictions - Y) ** 2)return cost# 梯度下降算法def gradient_descent(X, Y, theta, learning_rate, n_iterations):m = len(Y)cost_history = np.zeros(n_iterations)for iteration in range(n_iterations):gradients = (1 / m) * X.T.dot(X.dot(theta) - Y)theta = theta - learning_rate * gradientscost_history[iteration] = compute_cost(X, Y, theta)return theta, cost_history# 交叉验证函数def cross_validation(X, Y, learning_rate, n_iterations, k_folds=5):kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)cv_mse = []for train_index, val_index in kfold.split(X):X_train_fold, X_val_fold = X[train_index], X[val_index]Y_train_fold, Y_val_fold = Y[train_index], Y[val_index]# 为每个 fold 的训练数据添加 biasX_train_fold_b = np.c_[np.ones((X_train_fold.shape[0], 1)), X_train_fold]X_val_fold_b = np.c_[np.ones((X_val_fold.shape[0], 1)), X_val_fold]# 初始化 thetatheta = np.zeros(X_train_fold_b.shape[1])# 使用梯度下降训练模型theta_final, _ = gradient_descent(X_train_fold_b, Y_train_fold, theta, learning_rate, n_iterations)# 对验证集进行预测Y_val_pred = predict(X_val_fold, theta_final)# 计算均方误差mse = mean_squared_error(Y_val_fold, Y_val_pred)cv_mse.append(mse)# 返回交叉验证的平均MSEreturn np.mean(cv_mse)# 预测函数def predict(X, theta):X_b = np.c_[np.ones((X.shape[0], 1)), X]  # 添加截距项return X_b.dot(theta)# 自动调优学习率和迭代次数,并加入交叉验证
best_theta = None
best_mse = float('inf')
best_learning_rate = None
best_iterations = Nonelearning_rates = [0.001, 0.01, 0.02]
iteration_steps = [400, 500, 1000, 2000, 4000]
mse_results = np.zeros((len(learning_rates), len(iteration_steps)))for i, lr in enumerate(learning_rates):for j, iterations in enumerate(iteration_steps):cv_mse = cross_validation(X_train, Y_train, lr, iterations)mse_results[i, j] = cv_mse  # 记录每次的MSEif cv_mse < best_mse:best_mse = cv_msebest_learning_rate = lrbest_iterations = iterationsprint(f"Best MSE after cross-validation: {best_mse}, Best Learning Rate: {best_learning_rate}, Best Iterations: {best_iterations}")# 使用最优学习率和迭代次数重新训练模型
theta_final, cost_history = gradient_descent(X_train_b, Y_train, np.zeros(2), best_learning_rate, best_iterations)# 计算训练集和测试集的拟合程度
Y_train_pred = predict(X_train, theta_final)
Y_test_pred = predict(X_test, theta_final)# 计算均方误差和R2
train_mse = mean_squared_error(Y_train, Y_train_pred)
test_mse = mean_squared_error(Y_test, Y_test_pred)
train_r2 = r2_score(Y_train, Y_train_pred)
test_r2 = r2_score(Y_test, Y_test_pred)print(f"Train MSE: {train_mse}, Train R2: {train_r2}")print(f"Test MSE: {test_mse}, Test R2: {test_r2}")# 1. 可视化训练集和测试集的散点图与拟合直线
plt.figure(figsize=(10, 6))
plt.scatter(X_train, Y_train, color='blue', label='Train Data')
plt.scatter(X_test, Y_test, color='orange', label='Test Data')# 画拟合直线
X_range = np.linspace(min(X), max(X), 100)
Y_pred_line = predict(X_range, theta_final)
plt.plot(X_range, Y_pred_line, color='red', label='Fitted Line')# 画新样本的预测结果# 定义多个新输入数据
X_new_sample = np.array([7.0, 8.5, 6.0, 9.0, 5.5])  # 示例多个新输入# 对新输入进行预测
Y_new_pred = predict(X_new_sample, theta_final)print(Y_new_pred)
plt.scatter(X_new_sample, Y_new_pred, color='green', marker='x', s=100, label='Prediction for X=7.0')plt.title('训练集、测试集与预测结果的拟合曲线')
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.grid(True)
plt.show()# 2. 可视化损失函数变化
plt.figure(figsize=(10, 6))
plt.plot(range(len(cost_history)), cost_history, color='green', label='Cost Function')
plt.title('损失的变化图')
plt.xlabel('Number of Iterations')
plt.ylabel('Cost (MSE)')
plt.grid(True)
plt.legend()
plt.show()# 3. 可视化最佳参数选择(学习率和迭代次数的搜索过程)X_lr, Y_iter = np.meshgrid(iteration_steps, learning_rates)fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111, projection='3d')ax.plot_surface(X_lr, Y_iter, mse_results, cmap='viridis')
ax.set_xlabel('Iterations')
ax.set_ylabel('Learning Rate')
ax.set_zlabel('MSE')
ax.set_title('Learning Rate and Iterations vs. MSE')plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/425698.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C++】——继承详解

目录 1、继承的概念与意义 2、继承的使用 2.1继承的定义及语法 2.2基类与派生类间的转换 2.3继承中的作用域 2.4派生类的默认成员函数 <1>构造函数 <2>拷贝构造函数 <3>赋值重载函数 <4析构函数 <5>总结 3、继承与友元 4、继承与静态变…

使用ESP8266和OLED屏幕实现一个小型电脑性能监控

前言 最近大扫除&#xff0c;发现自己还有几个ESP8266MCU和一个0.96寸的oled小屏幕。又想起最近一直想要买一个屏幕作为性能监控&#xff0c;随机开始自己diy。 硬件&#xff1a; ESP8266 MUColed小屏幕杜邦线可以传输数据的数据线 环境 Windows系统Qt6Arduino Arduino 库…

【蔡英丽医生】小细节大影响:解读血栓来临前的身体语言!

血栓&#xff0c;这一隐形的健康杀手&#xff0c;常常在不经意间悄然降临&#xff0c;给人们的健康带来严重威胁。了解血栓来临前的身体语言&#xff0c;对于及早预防和治疗至关重要。今天&#xff0c;我们特别邀请到北京中医药大学东方医院脑病科的副主任医师——蔡英丽医生&a…

极速上云2.0范式:一键智连阿里云

在传统上云的现状与挑战&#xff1a; 专线上云太重&#xff0c;VPN上云不稳&#xff0c;云上VPC&#xff0c;云下物理网络&#xff0c;多段最后一公里...... 层层对接&#xff0c;跳跳延迟&#xff0c;好生复杂! 当你试图理解SD-WAN供应商和阿里云的文档&#xff0c;以协调路由…

介绍一些免费 的 html 5模版网站 和配色 网站

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、H5 网站介绍网站 二、配色网站个人推荐 前言 提示&#xff1a;以下是本篇文章正文内容&#xff0c;下面案例可供参考 一、H5 网站介绍 以下是一些提供免费…

构建响应式 Web 应用:Vue.js 基础指南

构建响应式 Web 应用&#xff1a;Vue.js 基础指南 一 . Vue 的介绍1.1 介绍1.2 好处1.3 特点 二 . Vue 的快速入门2.1 案例 1 : 快速搭建 Vue 的运行环境 , 在 div 视图中获取 Vue 中的数据2.2 案例 2 : 点击按钮执行 vue 中的函数输出 vue 中 data 的数据2.3 小结 三 . Vue 常…

PHP安全

PHP伪协议&#xff1a; 一.【file://协议】 PHP.ini&#xff1a; file:// 协议在双off的情况下也可以正常使用&#xff1b; allow_url_fopen &#xff1a;off/on allow_url_include&#xff1a;off/on file:// 用于访问本地文件系统&#xff0c;在CTF中通常用来读取本地文…

如何用安卓玩Java版Minecraft,安卓手机安装我的世界Java版游戏的教程

安卓手机使用FCL启动器安装我的世界Java版游戏的教程。如何用安卓玩Java版Minecraft 视频教程&#xff1a;https://www.bilibili.com/video/BV1CctYebEzR/ 前言 目前&#xff0c;安卓设备上可以用来运行Java版Minecraft的启动器主要有以下几款&#xff1a; PojavLauncher&a…

dedecms(四种webshell姿势)、aspcms webshell漏洞复现

一、aspcms webshell 1、登陆后台&#xff0c;在扩展功能的幻灯片设置模块&#xff0c;点击保存进行抓包查看 2、在slideTextStatus写入asp一句话木马 1%25><%25Eval(Request(chr(65)))%25><%25 密码是a&#xff0c;放行&#xff0c;修改成功 3、使用菜刀工具连…

第十一章 【后端】商品分类管理微服务(11.3)——商品管理模块 yumi-etms-goods

11.3 商品管理模块 yumi-etms-goods 新建 yumi-etms-goods 模块 添加依赖 pom.xml <?xml version="1.0" encoding="UTF-8"?> <project xmlns&#

【数字集成电路与系统设计】Chisel/Scala简介与Verilog介绍

目录 一、芯片前端设计开发背景知识 二、Verilog介绍 2.1 硬件设计一些重要概念 2.2 功能性仿真 2.3 简单的Verilog代码例子&#xff08;4-bit的加法器&#xff09; 三、Chisel简介 3.1 Chisel基本概念 3.2 Chisel代码展示 3.3 Chisel转成Verilog代码 四、Scala入…

Notepad++插件:TextFX 去除重复行

目录 一、下载插件 TextFX Characters 二、去重实操 2.1 选中需要去重的文本 2.2 操作插件 2.3 结果展示 2.3.1 点击 Sort lines case sensitive (at column) 2.3.2 点击 Sort lines case insensitive (at column) 一、下载插件 TextFX Characters 点【插件】-【插件管理…

数学学习记录

目录 学习资源&#xff1a; 9月14日 1.映射&#xff1a;​编辑 2.函数: 9月15日 3.反函数&#xff1a; 4.收敛数列的性质 5.反三角函数&#xff1a; 9月16日 6.函数的极限&#xff1a; 7.无穷小和无穷大 极限运算法则&#xff1a; 学习资源&#xff1a; 3Blue1…

【Elasticsearch系列九】控制台实战

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

【matlab】生成 GIF 的函数(已封装可直接调用)

文章目录 前言一、函数输入与输出二、函数代码三、例程&#xff08;可直接运行&#xff09;参考文献 前言 生成 gif 图片时遇到的问题&#xff0c;为了后续调用方便&#xff0c;封装为函数 一、函数输入与输出 输入&#xff1a; cell_figure: cell 数组&#xff0c;数组元素是…

热成像目标检测数据集

热成像目标检测数据集 V2 版本 项目背景 热成像技术因其在安防监控、夜间巡逻、消防救援等领域的独特优势而受到重视。本数据集旨在提供高质量的热成像图像及其对应的可见光图像&#xff0c;支持热成像目标检测的研究与应用。 数据集概述 名称&#xff1a;热成像目标检测数据…

CSS框架 Tailwind CSS

文章目录 前言一、Tailwind CSS是什么&#xff1f;二、项目中如何使用1.安装Tailwind CSS2.初始化Tailwind CSS该处使用的url网络请求的数据。3.引入Tailwind CSS样式4.进行配置&#xff08;tailwind.config.js&#xff09;5.全局引入注册6.使用Tailwind CSS 总结 前言 Tailwi…

IP-adapter masking

https://github.com/huggingface/diffusers/issues/6802https://github.com/huggingface/diffusers/issues/6802

2024/9/16 dataloader、tensorboard、transform

一、pytorch两大法宝元素 假设有一个名为pytorch的包 dir()&#xff1a;用于打开包&#xff0c;看里面的内容 help():用于查看具体的内容的用处 二、python文件&#xff0c;python控制台和jupyter的使用对比 三、pytorch读取数据 pytorch读取数据主要涉及到两个类&#xff1…

开源 AI 智能名片链动 2+1 模式 S2B2C 商城小程序与社交电商的崛起

摘要&#xff1a;本文深入探讨了社交电商迅速发展壮大的原因&#xff0c;并分析了开源 AI 智能名片链动 21 模式 S2B2C 商城小程序在社交电商中的重要作用。通过对传统电商与社交电商的对比&#xff0c;以及对各发展因素的剖析&#xff0c;阐述了该小程序如何为社交电商提供新的…