实现一个简单的线性回归和多项式回归(2)

对于多项式回归,可以同样使用前面线性回归中定义的LinearRegression算子、训练函数train、均方误差函数mean_squared_error,生成数据集create_toy_data,这里就不多做赘述咯~

        拟合的函数为

def sin(x):y = torch.sin(2 * math.pi * x)return y

1.数据集的建立

func = sin
interval = (0,1)
train_num = 15
test_num = 10
noise = 0.5
X_train, y_train = create_toy_data(func=func, interval=interval, sample_num=train_num, noise = noise)
X_test, y_test = create_toy_data(func=func, interval=interval, sample_num=test_num, noise = noise)X_underlying = torch.linspace(interval[0], interval[1], 100)
y_underlying = sin(X_underlying)# 绘制图像
plt.rcParams['figure.figsize'] = (8.0, 6.0)
plt.scatter(X_train, y_train, facecolor="none", edgecolor='#e4007f', s=50, label="train data")
#plt.scatter(X_test, y_test, facecolor="none", edgecolor="r", s=50, label="test data")
plt.plot(X_underlying, y_underlying, c='#000000', label=r"$\sin(2\pi x)$")
plt.legend(fontsize='x-large')
plt.savefig('ml-vis2.pdf')
plt.show()

        生成结果: 

2.模型构建

# 多项式转换
def polynomial_basis_function(x, degree=2):"""输入:- x: tensor, 输入的数据,shape=[N,1]- degree: int, 多项式的阶数example Input: [[2], [3], [4]], degree=2example Output: [[2^1, 2^2], [3^1, 3^2], [4^1, 4^2]]注意:本案例中,在degree>=1时不生成全为1的一列数据;degree为0时生成形状与输入相同,全1的Tensor输出:- x_result: tensor"""if degree == 0:return torch.ones(x.size(), dtype=torch.float32)x_tmp = xx_result = x_tmpfor i in range(2, degree + 1):x_tmp = torch.multiply(x_tmp, x)  # 逐元素相乘x_result = torch.concat((x_result, x_tmp), dim=-1)return x_result

        我的理解多项式回归的模型的建立更像是将原来的一个x变成x x^2 x^3 ... x^degree这个过程

3.模型训练

plt.rcParams['figure.figsize'] = (12.0, 8.0)for i, degree in enumerate([0, 1, 3, 8]):  # []中为多项式的阶数model = Linear(degree)X_train_transformed = polynomial_basis_function(X_train.reshape([-1, 1]), degree)X_underlying_transformed = polynomial_basis_function(X_underlying.reshape([-1, 1]), degree)model = optimizer_lsm(model, X_train_transformed, y_train.reshape([-1, 1]))  # 拟合得到参数y_underlying_pred = model(X_underlying_transformed).squeeze()print(model.params)# 绘制图像plt.subplot(2, 2, i + 1)plt.scatter(X_train, y_train, facecolor="none", edgecolor='#e4007f', s=50, label="train data")plt.plot(X_underlying, y_underlying, c='#000000', label=r"$\sin(2\pi x)$")plt.plot(X_underlying, y_underlying_pred, c='#f19ec2', label="predicted function")plt.ylim(-2, 1.5)plt.annotate("M={}".format(degree), xy=(0.95, -1.4))# plt.legend(bbox_to_anchor=(1.05, 0.64), loc=2, borderaxespad=0.)
plt.legend(loc='lower left', fontsize='x-large')
plt.savefig('ml-vis3.pdf')
plt.show()

         训练结果:

观察可视化结果,红色的曲线表示不同阶多项式分布拟合数据的结果:

* 当 $M=0$$M=1$时,拟合曲线较简单,模型欠拟合;

* 当 $M=8$ 时,拟合曲线较复杂,模型过拟合;

* 当 $M=3$ 时,模型拟合最为合理。

4.模型评估

        下面通过均方误差来衡量训练误差、测试误差以及在没有噪音的加入下sin函数值与多项式回归值之间的误差,更加真实地反映拟合结果。多项式分布阶数从0到8进行遍历。

# 训练误差和测试误差
training_errors = []
test_errors = []
distribution_errors = []# 遍历多项式阶数
for i in range(9):model = Linear(i)X_train_transformed = polynomial_basis_function(X_train.reshape([-1, 1]), i)X_test_transformed = polynomial_basis_function(X_test.reshape([-1, 1]), i)X_underlying_transformed = polynomial_basis_function(X_underlying.reshape([-1, 1]), i)optimizer_lsm(model, X_train_transformed, y_train.reshape([-1, 1]))y_train_pred = model(X_train_transformed).squeeze()y_test_pred = model(X_test_transformed).squeeze()y_underlying_pred = model(X_underlying_transformed).squeeze()train_mse = mean_squared_error(y_true=y_train, y_pred=y_train_pred).item()training_errors.append(train_mse)test_mse = mean_squared_error(y_true=y_test, y_pred=y_test_pred).item()test_errors.append(test_mse)# distribution_mse = mean_squared_error(y_true=y_underlying, y_pred=y_underlying_pred).item()# distribution_errors.append(distribution_mse)print("train errors: \n", training_errors)
print("test errors: \n", test_errors)
# print ("distribution errors: \n", distribution_errors)# 绘制图片
plt.rcParams['figure.figsize'] = (8.0, 6.0)
plt.plot(training_errors, '-.', mfc="none", mec='#e4007f', ms=10, c='#e4007f', label="Training")
plt.plot(test_errors, '--', mfc="none", mec='#f19ec2', ms=10, c='#f19ec2', label="Test")
# plt.plot(distribution_errors, '-', mfc="none", mec="#3D3D3F", ms=10, c="#3D3D3F", label="Distribution")
plt.legend(fontsize='x-large')
plt.xlabel("degree")
plt.ylabel("MSE")
plt.savefig('ml-mse-error.pdf')
plt.show()

        可视化结果如下:

由可视化结果可得:

  1. 当阶数较低的时候,模型的表示能力有限,训练误差和测试误差都很高,代表模型欠拟合;
  2. 当阶数较高的时候,模型表示能力强,但将训练数据中的噪声也作为特征进行学习,一般情况下训练误差继续降低而测试误差显著升高,代表模型过拟合。

        对于模型过拟合的情况,可以引入正则化方法,通过向误差函数中添加一个惩罚项来避免系数倾向于较大的取值。

degree = 8 # 多项式阶数
reg_lambda = 0.0001 # 正则化系数X_train_transformed = polynomial_basis_function(X_train.reshape([-1,1]), degree)
X_test_transformed = polynomial_basis_function(X_test.reshape([-1,1]), degree)
X_underlying_transformed = polynomial_basis_function(X_underlying.reshape([-1,1]), degree)model = Linear(degree) optimizer_lsm(model,X_train_transformed,y_train.reshape([-1,1]))y_test_pred=model(X_test_transformed).squeeze()
y_underlying_pred=model(X_underlying_transformed).squeeze()model_reg = Linear(degree) optimizer_lsm(model_reg,X_train_transformed,y_train.reshape([-1,1]),reg_lambda=reg_lambda)y_test_pred_reg=model_reg(X_test_transformed).squeeze()
y_underlying_pred_reg=model_reg(X_underlying_transformed).squeeze()mse = mean_squared_error(y_true = y_test, y_pred = y_test_pred).item()
print("mse:",mse)
mes_reg = mean_squared_error(y_true = y_test, y_pred = y_test_pred_reg).item()
print("mse_with_l2_reg:",mes_reg)# 绘制图像
plt.scatter(X_train, y_train, facecolor="none", edgecolor="#e4007f", s=50, label="train data")
plt.plot(X_underlying, y_underlying, c='#000000', label=r"$\sin(2\pi x)$")
plt.plot(X_underlying, y_underlying_pred, c='#e4007f', linestyle="--", label="$deg. = 8$")
plt.plot(X_underlying, y_underlying_pred_reg, c='#f19ec2', linestyle="-.", label="$deg. = 8, \ell_2 reg$")
plt.ylim(-1.5, 1.5)
plt.annotate("lambda={}".format(reg_lambda), xy=(0.82, -1.4))
plt.legend(fontsize='large')
plt.savefig('ml-vis4.pdf')
plt.show()

可视化结果为:

         多项式回归的难点就是是否真正理解了线性回归中的算子,均方误差等函数的目的和用法,其他的就是简单的函数调用问题啦~不懂的话还是建议多看看线性函数,先把线性函数的看明白最好~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/153490.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

go的面向对象学习

文章目录 面向对象编程(上)1.问题与解决思路2.结构体1》Golang语言面向对象编程说明2》结构体与结构体变量(实例/对象)的关系的示意图3》入门案例(using struct to solve the problem of cat growing) 3.结构体的具体应用4.创建结构体变量和访问结构体字段5.struct类型的内存分…

axios的get请求时数组参数没有下标

开发新项目过程中 发现get请求时 数组参数没有下标 这样肯定是不行的 后端接口需要数组[0]: 7 数组[1]:4这样的数据 原因是因为在请求拦截器没有处理需要的参数 解决方法 在请求拦截器 处理一下参数 import axios, { AxiosError, AxiosInstance, AxiosRequestHeaders } fro…

解决yolo无法指定显卡的问题,实测v5、v7、v8有效

方法1 基本上这个就能解决了!!! 在train.py的最上方加上下面这两行,注意是最上面,其次指定的就是你要使用的显卡 import os os.environ[CUDA_VISIBLE_DEVICES]6方法2: **问题:**命令行参数指…

HTML5+CSSday4综合案例二——banner效果

bannerCSS展示图&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"wi…

Redis安装教程

官网地址 地址链接&#xff1a;传送门 安装步骤 这里有更多版本的选择 进去根据自己的需要选择版本&#xff0c;我这里用的7系列的稳定版。

AI视频监控平台EasyCVR接入海康SDK出现异常,该如何解决?

安防监控系统/视频集中存储/云存储/视频监控管理平台EasyCVR能在复杂的网络环境中&#xff0c;将分散的各类视频资源进行统一汇聚、整合、集中管理&#xff0c;实现视频资源的鉴权管理、按需调阅、全网分发、智能分析等。 有用户反馈&#xff0c;在使用视频监控系统EasyCVR接入…

汽车烟雾测漏仪(EP120)

【汽车烟雾测漏仪&#xff08;EP120&#xff09;】 此烟雾测漏仪专为车辆管道&#xff08;油道、气道、冷却管道&#xff09; 的泄露检测而设计。适用于所有轻型 汽车、摩托车、轻卡、游艇等。 【特点】 具有空气模式和烟雾模式。空气模式&#xff0c;无需烟雾&#xff0c;检测…

打造虚拟企业展厅,开启商务活动新时代

引言: 虚拟企业展厅是一种基于数字技术的全新商务模式&#xff0c;正在改变传统商务活动的方式&#xff0c;它比传统的企业展厅更便利&#xff0c;也更能凸显企业优势&#xff0c;展示企业风貌。 一&#xff0e;虚拟企业展厅的好处 1.打破地域限制 传统的商务活动通常需要参…

12. Java异常及异常处理处理

Java —— 异常及处理 1. 异常2. 异常体系3. 常见Exception4. 异常处理4.1 try finally catch关键字4.2 throws和throw 自定义异常4.3 finally&#xff0c;final&#xff0c;finalize三者的区别 1. 异常 异常&#xff1a;在程序执行过程中发生的意外状况&#xff0c;可能导致程…

JVM命令行监控工具

JVM命令行监控工具 概述 性能诊断是软件工程师在日常工作中需要经常面对和解决的问题&#xff0c;在用户体验至上的今天&#xff0c;解决好应用的性能问题能带来非常大的收益。 Java作为最流行的编程语言之一&#xff0c;其应用性能诊断一直受到业界广泛关注&#xff0c;可能…

vue3 集成 tailwindcss

tailwindcss 介绍 Tailwind CSS 是一个流行的前端框架&#xff0c;用于构建现代、响应式的网页和 Web 应用程序。它的设计理念是提供一组可复用的简单、低级别的 CSS 类&#xff0c;这些类可以直接应用到 HTML 元素上&#xff0c;从而加速开发过程并提高样式一致性。 主要特点…

超市微信小程序是怎么做的

市微信小程序是利用微信小程序平台为超市或零售商提供线上销售服务的一种应用。通过小程序&#xff0c;超市可以向消费者提供更加便捷、快速、个性化的购物体验&#xff0c;从而提升销售业绩、增加客户满意度。以下是超市微信小程序可以实现的一些主要功能。 一、商品展示与搜索…

linux centos出现No space left on device解决方案

问题是因为系统磁盘空间不足 解决方法: 找到那个磁盘不足问题 df -lh 发现/dev/mapper/cl-root磁盘已用50G,有如下 解决方案&#xff1a; 1、如果是虚拟机可以通过分配空间使其空间增加 2、将其他不常用磁盘空间分配给cl-root如&#xff08; /dev/mapper/cl-home &#…

XSS原理

原理&#xff1a; 这是一种将任意 Javascript 代码插入到其他Web用户页面里执行以达到攻击目的的漏洞。攻击者利用浏览器的动态展示数据功能&#xff0c;在HTML页面里嵌入恶意代码。当用户浏览改页时&#xff0c;这些潜入在HTML中的恶意代码会被执行&#xff0c;用户浏览器被攻…

6-3 递增的整数序列链表的插入 分数 5

List Insert(List L, ElementType X) {//创建结点List node (List)malloc(sizeof(List));node->Data X;node->Next NULL;List head L->Next; //定位real头指针//空链表 直接插入if (head NULL) {L->Next node;node->Next head;return L;}//插入数据比第…

哈夫曼树哈夫曼编码知识点

目录 前言 哈夫曼树 1.相关背景 2.基本概念 3.什么是哈夫曼树 3.特点 4.哈夫曼树的构造&#xff08;哈夫曼算法&#xff09; 5.带权路径长度计算 哈夫曼编码&#xff08;哈夫曼树的应用&#xff09; 1.基本概念 2.编码方式 3.编码依据 4.小试牛刀&#xff08;习题&am…

智慧楼宇3D数据可视化大屏交互展示实现了楼宇能源的高效、智能、精细化管控

智慧园区是指将物联网、大数据、人工智能等技术应用于传统建筑和基础设施&#xff0c;以实现对园区的全面监控、管理和服务的一种建筑形态。通过将园区内设备、设施和系统联网&#xff0c;实现数据的传输、共享和响应&#xff0c;提高园区的管理效率和运营效益&#xff0c;为居…

socket can查看详细信息 命令 ip -details -statistics link show can0

ip -details -statistics link show can0 ip -details link show can0 ip -statistics link show can0 也可以像第一行那样结合使用

Synchronized的实现和锁升级

1.JVM是如何处理和识别Synchronized的&#xff1f; 我们从字节码角度分析synchronized的实现&#xff1a; Synchronized(锁对象){}同步代码块底层实现方式是monitorenter和monitorexit指令。 修饰普通同步方法时底层实现方式是执行指令会检查方法是否设置ACC_SYNCHRONIZED&am…

登陆认证权限控制(1)——从session到token认证的变迁 session的问题分析 + CSRF攻击的认识

前言 登陆认证&#xff0c;权限控制是一个系统必不可少的部分&#xff0c;一个开放访问的系统能否在上线后稳定持续运行其实很大程度上取决于登陆认证和权限控制措施是否到位&#xff0c;不然可能系统刚刚上线就会夭折。 本篇博客回溯登陆认证的变迁历史&#xff0c;阐述sess…