神经网络进行波士顿房价预测

前言

前一阵学校有五一数模节校赛,和朋友一起参加做B题,波士顿房价预测,算是第一次自己动手实现一个简单的小网络吧,虽然很简单,但还是想记录一下。

题目介绍

波士顿住房数据由哈里森和鲁宾菲尔德于1978年Harrison and Rubinfeld1收集。它包括了波士顿大区每个调查行政区的506个观察值。1980年Belsley et al.2曾对此数据做过分析。

数据一共14列,每一列的含义分别如下:

英文简称    详细含义
CRIM    城镇的人均犯罪率
ZN    大于25,000平方英尺的地块的住宅用地比例。
INDUS    每个镇的非零售业务英亩的比例。
CHAS    查尔斯河虚拟变量(如果环河,则等于1;否则等于0)
NOX    一氧化氮的浓度(百万分之几)
RM    每个住宅的平均房间数
AGE    1940年之前建造的自有住房的比例
DIS    到五个波士顿就业中心的加权距离
RAD    径向公路通达性的指标
TAX    每一万美元的全值财产税率
PTRATIO    各镇的师生比率
B    计算方法为 $1000(B_k-0.63)^2$,其中Bk是按城镇划分的非裔美国人的比例
LSTAT    底层人口的百分比(%)
price    自有住房数的中位数,单位(千美元)
基于上述数据,请完成以下问题:

建立波士顿房价预测模型并对预测结果进行评价。

问题分析

首先这道题目的很明确,数据一共是 $506×14$ 的一个矩阵,有十三维的自变量,通过建立一个模型来拟合回归出最终的因变量 price,即户主拥有住房价值的中位数。这是一个回归问题,综合考虑有以下两个思路

通过各种回归算法(GradientBoostingRegressor,RandomForestRegressor,ExtraTreesRegressor,LinearRegressor等)结合全部或部分自变量来回归最终的price

建立前馈神经网络模型,根据通用逼近定理,我们可以拟合此回归模型。

我们对上述模型来进行实现并确定评估标准来对他们进行比较,选择最优的模型作为预测模型。

算法流程

传统的回归算法

自变量的选择

首先,考虑到数据集中13列自变量其中某一些可能和最终的房价并无强相关性,如果全部使用进行预测可能会对模型引入噪声,因此我们首先计算了房价price与各个自变量之间的相关系数 $r$ ,其中 $r$ 计算公式如下: $$ r = \frac{\sum(x_i-\bar{x})(y_i-\bar{y})}{\sqrt{\sum(x_i-\bar{x})^2\sum(y_i-\bar{y})^2}} $$ 其中 $x_i,y_i$ 为数据的每个分量,$\bar{x},\bar{y}$ 为数据的均值

该系数反映了两变量之间的相关性,$r$ 的绝对值介于 $[0,1]$ 区间内,$|r|$ 越接近1,表示两数据相关性越高,反之越低。计算后结果如下:


观察结果可以发现,在给定的十三个变量中,LSTAT 与 price 的相关程度最高$(|r|>0.7)$,其次是 RM 与PTRATIO $(|r|>0.5)$,再者是 TAX,INDUS,NOX $(|r|>0.4)$,除上述之外的七个变量都与 price 无较强的相关性,因此我们考虑使用六个相关性较强变量和十三个变量分别来对房价进行预测,并对他们进行对比,来寻找最优的回归模型。

模型的构建

首先我们使用了sklearn中自带的 boston 数据集,并将整体数据集随机划分为了训练集和测试集两部分,所占比例分别为80%和20%。

然后,我们利用Linear,Ridge,Lasso,ElasticNet,DecisionTree,GradientBoosting,RandomForest,ExtraTrees八种模型通过训练集对其进行训练。

接下来,我们利用训练集拟合得到的模型,使用测试集对其进行测试,与 Ground Truth 进行对比,并通过 $R^2$ 来评价该预测结果,其中 $R^2$ 计算公式如下,其是衡量回归模型好坏的常见指标,其值一般处于[0,1]之间,$R^2$ 越接近1,说明模型的性能越好。 $$ R^2 = 1-\frac{\sum(\hat{y_i}-y_i)^2}{\sum(\bar{y}-y_i)^2} $$

最后,考虑到模型的训练及预测可能具有偶然性,因此我们对于每一个模型进行20次训练及预测,利用20次的结果对其进行综合评价。利用得到的结果绘制 箱线图 所得结果如下:

分析最终结果可以发现,无论是使用六个相关性较强变量还是十三个变量来进行预测,GradientBoost(梯度提升决策树)回归模型都是最好的,此外,我们可以发现,利用十三个变量要比利用六个主要变量来进行预测比有着更好的效果。

前馈神经网络

模型的构建
近年来,神经网络理论不断发展,前馈神经网络(多层感知机、全连接神经网络)越来越多的被利用到数据分析中,因此考虑使用前馈神经网络来解决此问题。

前馈神经网络(全连接神经网络)的网络结构一般由三部分构成,输入层,隐藏层,以及输出层,输入层与输出层一般只有一层,隐藏层可有多层。中间利用非线性函数作为激活函数可以使得网络具有拟合非线性函数的能力

根据通用近似定理:

通用近似定理

对于具有线性输出层和至少一个使用“挤压”性质的激活函数的隐藏层组成的前馈神经网络,只要其隐藏层神经元的数量足够,它可以以任意精度来近似任何从一个定义在实数空间中的有界闭集函数。

只要隐藏层网络维度够高,就可以拟合任意的函数。

考虑到我们的模型有六维or十三维的数据输入,因此我们建立两层前馈神经网络,中间具有一层隐藏层,维度为1000维,激活函数使用Relu,Relu函数有以下优点:

Relu相比于传统的Sigmoid、Tanh,导数更加好求,反向传播就是不断的更新参数的过程,因为其导数不复杂形式简单,可以使得网络训练更快速。

此外,当数值过大或者过小,Sigmoid,Tanh的导数接近于0,Relu为非饱和激活函数则不存在这种现象,可以很好的解决梯度消失的问题

Relu函数及网络结构图如图所示:

$$ Relu:f(x) = max(0,x) $$

具体实现

利用流行的深度学习框架 Pytorch 来对模型进行实现。

首先,将数据集随机划分为训练集和测试集两部分,分别占80%和20%,并将其转化为Pytorch中的张量形式。
然后,利用MinMaxScaler对输入数据进行归一化,利用下列公式将其统一归一化为 $[0,1]$ 之间,以求模型能够更快的收敛。
$$ MinMaxScaler:x^{*} = \frac{x-min(x)}{max(x)-min(x)} $$

接下来,构建网络模型,利用 mseloss 作为损失函数,在训练过程中利用反向传播使其最终收敛为0。
$$ MseLoss = \frac{1}{2n}\sum||y(x)-a^L(x)||^2 $$

最后,我们设置网络的学习率为0.01,训练10000个epoch,发现其loss最终降低到0.3%左右,我们利用上文提到的 $R^2$ 对结果进行评估并与回归模型进行对比,通过观察图片可以发现,前馈神经网络相比于传统的回归模型有着更好的拟合效果, 20次预测得到的$R^2$平均值达到了0.95,此外中位数,最大值,最小值也要比回归模型更加优秀,因此我们采用前馈神经网络模型来对最后的房价进行预测。

最终预测

最终我们利用构建的前馈神经网络模型进行预测,利用测试集对其进行对比,绘制预测如下:

可以看到其中很多点都覆盖的很好,即预测准确。

通过理论对模型进行量化分析,计算预测的 $R^2$ $$ R^2 = 1-\frac{\sum(\hat{y_i}-y_i)^2}{\sum(\bar{y}-y_i)^2} = 1-0.01357 = 0.98643=98.643% $$ 可以发现 $R^2$ 十分接近1,说明回归模型性能良好,符合要求。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/461327.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如果要用示波器测量电路中某处电压与电流的相位差,应如何实现?

使用示波器测量电路中某处电压与电流的相位差,可以通过以下步骤实现: 1. 准备和连接 所需设备 示波器(双通道)电流探头(或电阻分压器用于间接测量电流)电压探头 连接探头到待测信号 电压探头&#xff…

恋爱脑学Rust之闭包三Traits:Fn,FnOnce,FnMut

在Rust中,FnOnce、FnMut和Fn是三个用于表示闭包(closure)类型的trait。闭包是一种特殊的函数,它可以捕获其环境变量,即在其定义时所处的作用域中的变量。以下是关于这三个trait的详细介绍: 1. FnOnce&#…

【Linux:TCP通信流程】

网络字节序: 计算机中存在两种存储字节的方式,分别是:大端存储和小端存储,TCP/IP协议规定,网络数据字节流应采用大端字节序。如果当前发送的主机是小端机就需要将数据转化为大端,再发送。 小端存储&#…

HTB:Cicada[WriteUP]

目录 连接至HTB服务器并启动靶机 使用nmap对靶机进行开放端口扫描 使用nmap对靶机开放端口进行脚本、服务信息扫描 首先尝试空密码连接靶机SMB服务 由于不知道账户名,这里我们使用crackmapexec对smb服务进行用户爆破 通过该账户连接至靶机SMB服务器提取敏感信…

17. 云计算和分布式计算

文章目录 第17章 云计算和分布式计算17.1 云基础17.2 云中的故障超时长尾延迟 17.3 利用多个实例提升性能和可用性分布式计算和负载均衡器分布式系统中的状态管理分布式系统中的时间协调分布式系统中的数据协调自动扩展:实例的自动创建和销毁自动扩展虚拟机自动缩放…

【GESP】C++一级知识点研究,cout和printf性能差异分析

一道简单循环输出练习题(BCQM3148,循环输出),由于cout的代码超时问题,让我注意到二者在使用上的差异,遂查阅研究如下。 全文详见:https://www.coderli.com/gesp-knowledge-cout-printf/【GESP】C一级知识点研究&#…

入门 | Kafka数据使用vector消费到Loki中使用grafana展示

一、Loki的基本介绍 1、基本介绍 Loki 是由 Grafana Labs 开发的一款水平可扩展、高性价比的日志聚合系统。它的设计初衷是为了有效地处理和存储大量的日志数据,与 Grafana 生态系统紧密集成,方便用户在 Grafana 中对日志进行查询和可视化操作。 从架构…

Golang | Leetcode Golang题解之第515题在每个树行中找最大值

题目: 题解: func largestValues(root *TreeNode) (ans []int) {if root nil {return}q : []*TreeNode{root}for len(q) > 0 {maxVal : math.MinInt32tmp : qq nilfor _, node : range tmp {maxVal max(maxVal, node.Val)if node.Left ! nil {q …

Vue3的router和Vuex的学习笔记整理

一、路由的基本搭建 1、安装 npm install vue-router --registryhttps://registry.npmmirror.com 2、配置路由模块 第一步:src/router/index.js创建文件 第二步:在src/view下面创建两个vue文件,一个叫Home.vue和About.vue 第三步&#x…

vue插件清除 所有console.log()

一、作用 1、提升性能console.log() 语句会消耗一定的性能,尤其是在频繁调用的情况下。在生产环境中移除这些语句可以提高应用的运行效率。 2、减少信息泄露console.log() 可以输出敏感信息(如用户数据、API 响应等)。在生产环境中&#xf…

vue项目中如何在路由变化时增加一个进度条

在 Vue.js 项目中,使用路由(如 Vue Router)时,为了提升用户体验,你可能会想要在路由变化时显示一个进度条。这可以通过多种方式实现,其中一种流行的做法是使用第三方库,如 vue-loading-bar 或 n…

python 模块和包、类和对象

模块 模块是包含 Python 代码的文件,通常用于组织相关的函数、类和其他语句。模块可以被导入并在其他 Python 文件中使用。 创建模块 假设你创建了一个名为 mymodule.py 的文件,内容如下: # mymodule.pydef greet(name): return f"…

图书管理系统汇报

【1A536】图书管理系统汇报 项目介绍1.用户登录注册功能1. 1用户角色管理2.图书管理功能2.1 添加图书2.2 编辑图书2.3 删除图书 3.图书搜索和筛选3.1 图书搜索3.2 图书筛选 4.图书借阅、图书归还4.1 图书借阅4.2 图书归还 5.用户信息管理5.1上传头像5.2修改头像5.3 修改密码 项…

执行Django项目的数据库迁移命令时报错:(1050, “Table ‘django_session‘ already exists“);如何破?

一、问题描述: 当我们写Django时,由于自己的操作不当,导致执行数据库迁移命令时报错,报错的种类有很多,例如: 迁移文件冲突:可能你有多个迁移文件试图创建同一个表。数据库状态与迁移文件不同…

蓝牙BLE开发——红米手机无法搜索蓝牙设备?

解决 红米手机,无法搜索附近蓝牙设备 具体型号当时忘记查看了,如果你遇到有以下选项,记得打开~ 设置权限

java设计模式之创建者模式(5种)

设计模式 软件设计模式,又称为设计模式,是一套被反复利用,代码设计经验的总结,他是在软件设计过程中的一些不断发生的问题,以及该问题的解决方案。 **创建者模式又分为以下五个模式:**用来描述怎么“将对象…

SpringSecurity框架(入门)

简介: Spring Security 是一个用于构建安全的 Java 应用程序的框架,尤其适用于基于Spring的应用程序。它提供了全面的安全控制,从认证(Authentication)到授权(Authorization),以及…

广东网站设计提升你网站在搜索引擎中的排名

在当今网络盛行的时代,拥有一个设计优良的网站,对企业的在线发展至关重要。特别是对于广东地区的企业来说,网站设计不仅仅是美观的问题,更直接影响着搜索引擎中的排名。因此,精心策划和设计的网站,能够显著…

Cuebric:用AI重新定义3D创作的未来

一、简介 Cuebric 是一家成立于2022年夏天的好莱坞创新公司,致力于为电影、电视、游戏和时尚等行业提供先进的AI多模态SaaS平台。自2024年1月正式推出以来,Cuebric 已经在市场上获得了广泛的认可和积极的反馈。目前,该平台正处于1.0版本的beta测试阶段,已募集约50万美元的…

计算机的错误计算(一百四十)

摘要 探讨 MATLAB 中函数 的计算精度。 从计算机的错误计算(一百三十九)知,对于对数运算,当真数在 1 附近时,计算机的输出会出现较大误差。为此,IEEE 754-2019 中专门定义有函数 其目的就是当自变量在 …