机器学习第5天:多项式回归与学习曲线

文章目录

多项式回归介绍

方法与代码

方法描述

分离多项式

学习曲线的作用

场景

学习曲线介绍

欠拟合曲线

示例

结论

过拟合曲线

示例

​结论 


多项式回归介绍

当数据不是线性时我们该如何处理呢,考虑如下数据

import matplotlib.pyplot as plt
import numpy as npnp.random.seed(42)x = 8 * np.random.rand(100, 1) - 4
y = 2*x**2+3*x+np.random.randn(100, 1)plt.scatter(x, y)
plt.show()


方法与代码

方法描述

先讲思路,以这个二元函数为例

y=3*x^{2}+2*x+c

将多项式化为多个单项的,也就是将x的平方和x两个项分离开,然后单独给线性模型处理,求出参数,最后再组合在一起,很好理解,让我们来看一下代码


分离多项式

我们使用机器学习库的PolynomialFeatures来分离多项式

from sklearn.preprocessing import PolynomialFeaturespoly_features = PolynomialFeatures(degree=2, include_bias=False)
x_poly = poly_features.fit_transform(x)
print(x[0])
print(x_poly[0])

运行结果

可以看到,4, 5行代码将原始x和x平方挑选了出来,这时我们再把这个数据进行线性回归

model = LinearRegression()
model.fit(x_poly, y)
print(model.coef_)

 这段代码使用处理后的x拟合y,再打印模型拟合的参数,可以看到模型的两个参数分别是2.9和2左右,而我们的方程的一次参数和二次参数分别是3和2,可见效果还是很好的

把预测的结果绘制出来

model = LinearRegression()
model.fit(x_poly, y)
pre_y = model.predict(x_poly)# 这里是为了让x升序的排序算法, 可以尝试不加这段代码图会变成什么样
sorted_indices = sorted(range(len(x)), key=lambda k: x[k])
x_sorted = [x[i] for i in sorted_indices]
y_sorted = [pre_y[i] for i in sorted_indices]plt.plot(x_sorted, y_sorted, "r-")
plt.scatter(x, y)
plt.show()


学习曲线的作用

场景

设想一下,当你需要预测房价,你也有多组数据,包括离学校距离,交通状况等,但是问题来了,你只知道这些特征可能与房价有关,但并不知道这些特征与房价之间的方程关系,这时我们进行回归任务时,就可能导致欠拟合或者过拟合,幸运的是,我们可以通过学习曲线来判断


学习曲线介绍

学习曲线图就是以损失函数为纵坐标,数据集大小为横坐标,然后在图上画出训练集和验证集两条曲线的图,训练集就是我们用来训练模型的数据,验证集就是我们用来验证模型性能的数据集,我们往往将数据集分成训练集与验证集

我们先定义一个学习曲线绘制函数

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegressiondef plot_learning_curves(model, x, y):x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)train_errors, val_errors = [], []for m in range(1, len(x_train)):model.fit(x_train[:m], y_train[:m])y_train_predict = model.predict(x_train[:m])y_val_predict = model.predict(x_val)train_errors.append(mean_squared_error(y_train[:m], y_train_predict))val_errors.append(mean_squared_error(y_val, y_val_predict))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")plt.legend()plt.show()

 简单介绍一下,这个函数接收模型参数,x,y参数,然后在for循环中,取不同数据集大小来计算RMSE损失(就是\sqrt{MSE}),然后把曲线绘制出来


欠拟合曲线

我们知道欠拟合就是模拟效果不好的情况,可以想象的到,无论在训练集还是验证集上,他的损失都会比较高

示例

我们将线性模型的学习曲线绘制出来

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegressiondef plot_learning_curves(model, x, y):x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)train_errors, val_errors = [], []for m in range(1, len(x_train)):model.fit(x_train[:m], y_train[:m])y_train_predict = model.predict(x_train[:m])y_val_predict = model.predict(x_val)train_errors.append(mean_squared_error(y_train[:m], y_train_predict))val_errors.append(mean_squared_error(y_val, y_val_predict))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")plt.legend()plt.show()x = np.random.rand(100, 1)
y = 2 * x + np.random.rand(100, 1)model = LinearRegression()
plot_learning_curves(model, x, y)

 

结论

可以看到,在只有一点数据时,模型在训练集上效果很好(因为就是开始这一些数据训练出来的),而在验证集上效果不好,但随着训练集增加(模型学习到的越多),验证集上的误差逐渐减小,训练集上的误差增加(因为是学到了一个趋势,不会完全和训练集一样了)

这个图的特征是两条曲线非常接近,且误差都较大(差不多在0.3) ,这是欠拟合的表现(模型效果不好)


过拟合曲线

过拟合就是完全以数据集来模拟曲线,泛化能力很差

示例

我们来试试将一次函数模拟成三次函数,再来看看学习曲线(毫无疑问过拟合了)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipelinedef plot_learning_curves(model, x, y):x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)train_errors, val_errors = [], []for m in range(1, len(x_train)):model.fit(x_train[:m], y_train[:m])y_train_predict = model.predict(x_train[:m])y_val_predict = model.predict(x_val)train_errors.append(mean_squared_error(y_train[:m], y_train_predict))val_errors.append(mean_squared_error(y_val, y_val_predict))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")plt.legend()plt.show()np.random.seed(10)
x = np.random.rand(200, 1)
y = 2 * x + np.random.rand(200, 1)poly_regression = Pipeline([("Poly", PolynomialFeatures(degree=3, include_bias=False)),("Line", LinearRegression())
])plot_learning_curves(poly_regression, x, y)

结论 

这条曲线的特征是训练集的效果比验证集好(两条线之间有一定间距),这往往是过拟合的表现(在训练集上效果好,验证集差,表面泛化能力差) 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/196467.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

安装最新版IntelliJ IDEA来开发Java应用程序

安装最新版IntelliJ IDEA来开发Java应用程序 Install the Latest Version of IntelliJ IDEA to Develop Java Applications 本文简要介绍如何安装配置JetBrains IntelliJ IDEA集成开发环境,从而开发Java应用程序;文中侧重实际操作和编程步骤&#xff0…

学习指南:如何快速上手媒体生态一致体验开发

过去开发者们在使用多媒体能力时,往往会遇到这样的问题,比如:为什么我开发的相机不如系统相机的效果好?为什么我的应用和其他的音乐一起发声了,我要怎么处理?以及我应该怎么做才能在系统的播控中心里可以看…

从C语言的面向过程编程过渡理解面向对象编程风格中的封装

在C语言中,我们解决一个问题通常是采用在了解了问题如何解决后,设置一个一个的函数,依次调用实现不同的功能的函数从而解决问题,这种编程风格就叫做面向过程。  除此之外,还有一种叫做面向对象的编程风格被广泛的使用…

【算法每日一练]-图论(保姆级教程 篇5(LCA,最短路,分层图)) #LCA #最短路计数 #社交网络 #飞行路线 # 第二短路

今天讲最短路统计和分层图 目录 题目:LCA 思路: 题目:最短路计数 思路: 题目:社交网络 思路: 题目:飞行路线 思路: 题目:第二短路 思路: 题目&a…

electron使用better-sqlite3打包失败(electron打包有进程没有界面)

remove *\chrome_100_percent.pak: Access is denied. 解决: 管理员权限执行:taskkill /IM 你的进程名.exe /F,再次执行build electron使用better-sqlite3打包后有进程没有界面 原因是代码及依赖包安装有误,模块丢失。主要分享的…

Python开源项目GPEN——人脸重建(Face Restoration),模糊清晰、划痕修复及黑白上色的实践

无论是自己、家人或是朋友、客户的照片,免不了有些是黑白的、被污损的、模糊的,总想着修复一下。作为一个程序员 或者 程序员的家属,当然都有责任满足他们的需求、实现他们的想法。除了这个,学习了本文的成果,或许你还…

腾讯云新用户优惠活动有哪些可以参加?腾讯云新人服务器优惠活动

腾讯云作为国内领先的云服务提供商,不仅为用户提供稳定可靠的云服务器,还为新用户带来了一系列的优惠活动和代金券,以降低购买成本,提高业务效益。在这里,我们将为您详细介绍腾讯云服务器的新人优惠活动及代金券&#…

Git企业开发级讲解(四)

📘北尘_:个人主页 🌎个人专栏:《Linux操作系统》《经典算法试题 》《C》 《数据结构与算法》 ☀️走在路上,不忘来时的初心 文章目录 一、理解分⽀二、创建分支三、切换分⽀四、合并分⽀五、删除分⽀六、合并冲突七、分⽀管理策略…

pythom导出mysql指定binlog文件

要求 要求本地有py环境和全局环境变量 先测试直接执行binlog命令执行命令 Windows 本地直接执行命令 # E:\output>E:\phpstudy_pro\Extensions\MySQL5.7.26\bin\mysqlbinlog binglog文件地址 # --no-defaults 不限制编码 # -h mysql链接地址 # -u mysql 链接名称 # -p m…

word文档转换为ppt文件,怎么做?

大家是否会遇到需要将word文档转换为ppt文件的情况?除了反反复复粘贴复制以外,还有其他方法可以转换文件格式,今天给大家分享word转换ppt方法。 首先我们先将word文件打开大纲模式 然后我们将文中的大标题设置为1级标题,副标题设…

flutter仿支付宝余额宝年化收益折线图

绘制: 1.在pubspec.yaml中引入:fl_chart: 0.55.2 2.绘制: import package:jade/utils/JadeColors.dart; import package:util/easy_loading_util.dart; import package:fl_chart/fl_chart.dart; import package:flutter/material.dart; impo…

Leetcode—142.环形链表II【中等】

2023每日刷题(三十三) Leetcode—142.环形链表II 实现代码 /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/ struct ListNode *detectCycle(struct ListNode *head) {struct ListNode* …

Stable Diffusion - StableDiffusion WebUI 软件升级与扩展兼容

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/134463035 目前,StableDiffusion WebUI 的版本是 1.6.0,同步更新 controlnet、tagcomplete、roop、easy-prompt-selector等…

vmware安装MacOS以及flutter遇到的问题

安装过程:参考下面的文章 链接: 虚拟机VMware安装苹果系统macOS,超级详细教程,附文件下载,真教程!! 无限重启情况: (二) 配置虚拟机找到你的虚拟机安装文件…

Vite -静态资源处理 - SVG格式的图片

特点 Vite 对静态资源是开箱即用的。 无需做特殊的配置。项目案例 项目结构 study-vite| -- src| -- assets| -- bbb.svg # 静态的svg图片资源| -- index.html # 主页面| -- main.js # 引入静态资源| -- package.json # 脚本配置| -- vite.co…

景联文科技入选量子位智库《中国AIGC数据标注产业全景报告》数据标注行业代表机构

量子位智库《中国AIGC数据标注产业全景报告》中指出,数据标注处于重新洗牌时期,更高质量、专业化的数据标注成为刚需。未来五年,国内AI基础数据服务将达到百亿规模,年复合增长率在27%左右。 基于数据基础设施建设、大模型/AI技术理…

安装应用与免安装应用差异对比

差异 安装的程序和免安装的应用程序之间有以下几个方面的差别: 安装过程:安装的程序需要通过一个安装程序或安装脚本进行安装。这个过程通常会将应用程序的文件和依赖项复制到指定的目录,并进行一些配置和注册操作。免安装的应用程序则不需要…

Apache Airflow (八) :DAG任务依赖设置

🏡 个人主页:IT贫道_大数据OLAP体系技术栈,Apache Doris,Clickhouse 技术-CSDN博客 🚩 私聊博主:加入大数据技术讨论群聊,获取更多大数据资料。 🔔 博主个人B栈地址:豹哥教你大数据的个人空间-豹…

nginx学习(1)

一、下载安装NGINX: 先安装gcc-c编译器 yum install gcc-c yum install -y openssl openssl-devel(1)下载pcre-8.3.7.tar.gz 直接访问:http://downloads.sourceforge.net/project/pcre/pcre/8.37/pcre-8.37.tar.gz,就…