Scikit-LearnTensorFlow机器学习实用指南(三):一个完整的机器学习项目【下】

机器学习实用指南(三):一个完整的机器学习项目【下】

作者:LeonG
本文参考自:《Hands-On Machine Learning with Scikit-Learn & TensorFlow 机器学习实用指南》,感谢中文AI社区ApacheCN提供翻译。

本文全部代码和数据集保存在我的github-----LeonG的github

1.回顾

在上一节,我们从网络上获取了数据:housing

然后将数据分为训练集strat_train_set和测试集strat_test_set

将训练标签也就是房价单独分离housing_labels

最后分析了训练集的一些规律,针对这个数据集制作了一个数据整理工具full_pipeline,将训练集strat_train_set转为housing_prepared

经过这些步骤,我们的训练模型只需要输入训练数据housing_prepared和训练标签housing_labels,就可以得到训练好的模型了。

2.训练模型

接下来我们要尝试几种机器学习的算法模型:线性回归模型、决策树模型、随机森林模型。

提示:这些算法模型的具体原理和细节在以后的章节会详细解析,本章只是简单的使用,不用担心不看懂。

2.1线性回归模型

我们先来训练一个线性回归模型,借助sklearn中的LinearRegression类来实现:

from sklearn.linear_model import LinearRegression
lin_reg = LinearRegression()
#输入训练数据进行训练
lin_reg.fit(housing_prepared,housing_labels)

只需要这样简单的三行操作就能训练完一个线性回归模型。

现在使用训练集中的前五行来验证:

#取前五行数据
some_data = housing.iloc[:5]
some_labels = housing_labels.iloc[:5]
#对这些数据进行预测(代入到训练好的模型中计算出预测房价)
some_data_prepared = full_pipeline.transform(some_data)
#模型的预测值
print("Predictions:\t", lin_reg.predict(some_data_prepared))
#数据集的标签值
print("Labels:\t\t", list(some_labels))
Predictions: [210644.60 317768.80 210956.43 59218.98 189747.55]
Labels:      [286600.0, 340600.0, 196900.0, 46300.0, 254500.0]

可以看出数据之间差距还是比较大的,我们计算一下这个回归模型的RMSE

RMSE是均方根误差:
\sqrt{\sum_{i=1}^{n}\frac{1}{n}{(f(x_i)-y_i)}^{2}}
均方根误差的意义大概可以理解为:预测值和实际值的平均差距

当然,我们也不需要手动写公式,直接让sklearn来帮我们算:

from sklearn.metrics import mean_squared_error
#计算出训练集的所有预测值
housing_predictions = lin_reg.predict(housing_prepared)
#计算线性模型的预测值和实际值的均方误差
lin_mse = mean_squared_error(housing_labels, housing_predictions)
#均方根误差=均方误差开方
lin_rmse = np.sqrt(lin_mse)
lin_rmse
68628.19819848923

计算得出线性回归模型的均方根误差很大,这样肯定不行。

结论:对于该数据集来说,线性回归模型是一个欠拟合模型。

2.2决策树模型

修复欠拟合的主要方法是选择一个更强大的模型,接下来试试决策树模型。

决策树模型可以发现数据中复杂的非线性关系。借助sklearn中的DecisionTreeRegressor类来实现:

from sklearn.tree import DecisionTreeRegressor
tree_reg = DecisionTreeRegressor()
#输入训练数据进行训练
tree_reg.fit(housing_prepared, housing_labels)

这次不取前五行测试了,直接计算这个模型的RMSE:

#计算出训练集的所有预测值
housing_predictions = tree_reg.predict(housing_prepared)
#计算决策树模型的预测值和实际值的均方误差
tree_mse = mean_squared_error(housing_labels, housing_predictions)
#均方根误差=均方误差开方
tree_rmse = np.sqrt(tree_mse)
tree_rmse
0.0

咦,没有误差?这个模型是绝对完美的吗?不对,这是因为模型严重的拟合数据,任何一条训练数据都能得到对应的训练标签。

结论:对于该数据集来说,决策树模型是一个过拟合模型。

3.交叉验证

如何验证模型的真实水平呢?在确定模型之前,我们都不要碰测试集,所以需要用训练集的部分数据来做训练,接下来使用交叉验证法。

3.1K折交叉验证法

K折交叉验证法:将数据集分为K份,称为折,每次用其中一个折作为测试集来计算误差,经过K次计算后求出一组长度为K的误差值,这组误差的平均值就是交叉验证得出的误差值。

我们将K设置为10:

借助sklearn可以很简单的实现验证:

from sklearn.model_selection import cross_val_score
#总共有五个参数,第一个是模型,2和3是数据,scoring指定了计算方式,cv是K值
scores = cross_val_score(tree_reg, housing_prepared, housing_labels,scoring="neg_mean_squared_error", cv=10)
#score是效用函数计算得出的,实际上和均方误差相反,所以要加上负号
tree_rmse_scores = np.sqrt(-scores)

设置一个输出函数来查看具体情况:

def display_scores(scores):#均方误差,一共十个print("Scores:", scores) #平均的均方误差print("Mean:", scores.mean())#均方误差的标准差print("Standard deviation:", scores.std())
display_scores(tree_rmse_scores)
Scores: [67649.82 67698.67 71079.28 69445.09 71808.23 73827.59 71111.46 71243.31 75630.03 70498.20]
Mean: 70999.17217565424
Standard deviation: 2344.261017051602

可以看出,决策树模型并没有那么好用,甚至比线性回归模型还糟糕。

交叉验证不仅可以得到模型性能的评估,还能测量评估的准确性(标准差)。决策树的误差大概是71000,波动幅度±2300。

3.2随机森林模型

上述两个模型误差都很大,现在使用随机森林模型。

这个模型的名字很有意思,如果将决策树看做是一棵树的话,随机森林就是随机组合一些属性来训练许多决策树,在其他多个模型之上建立模型成为集成学习。借助sklearn中的RandomForestRegressor类来实现:

from sklearn.ensemble import RandomForestRegressorforest_reg = RandomForestRegressor()
#输入训练数据进行训练
forest_reg.fit(housing_prepared,housing_labels)
#计算交叉验证误差
scores = cross_val_score(forest_reg, housing_prepared, housing_labels,scoring="neg_mean_squared_error", cv=10)
#score是效用函数计算得出的,实际上和均方误差相反,所以要加上负号
forest_rmse_scores = np.sqrt(-scores)
display_scores(forest_rmse_scores)
Scores: [51066.82 50166.73 52755.29 55534.63 51963.35 54194.03 52341.03 50770.49 54823.32 52582.53]
Mean: 52619.825856487405
Standard deviation: 1679.5101421709217

看起来效果比上面两个模型都要好,实际上我们应该多测试几个模型,比如不同核心的支持向量机、神经网络等等,目标是列出可以使用模型的列表。做完之后就是对模型的微调了。

4.模型微调

调整什么呢,调整超参数,超参数是什么,超参数就是不能通过学习来自动调整的参数。比如学习率,神经网络的层数等等。

在机器学习中,超参数是在开始学习过程之前设置值的参数,而不是通过训练得到的参数数据。通常情况下,需要对超参数进行优化,给学习机选择一组最优超参数,以提高学习的性能和效果。

超参数的一些示例:

  • 树的数量或树的深度
  • 矩阵分解中潜在因素的数量
  • 学习率(多种模式)
  • 深层神经网络隐藏层数
  • k均值聚类中的簇数

4.1网格搜索

网格搜索的意思很简单,为几个超参数设定一个范围取值,逐一搜索最佳组合的方式就是网格搜索。

我们借助SKlearn中的GridSearchCV来自动完成搜索工作。

你所需要做的是告诉 GridSearchCV 要试验有哪些超参数,要试验什么值, GridSearchCV 就能用交叉验证试验所有可能超参数值的组合。例如,下面的代码搜索了随机森林模型超参数值的最佳组合:

from sklearn.model_selection import GridSearchCV
param_grid = [
#字典1:尝试3×4=12种组合
{'n_estimators': [3, 10, 30], 'max_features': [2, 4, 6, 8]},
#字典2:尝试1×2×3=6种组合
{'bootstrap': [False], 'n_estimators': [3, 10], 'max_features': [2, 3, 4]},
]
forest_reg = RandomForestRegressor()
#定义一个网格搜索,采用5折交叉验证法,判断标准是均方误差
grid_search = GridSearchCV(forest_reg, param_grid, cv=5,scoring='neg_mean_squared_error')
#使用这些组合来训练随机森林
grid_search.fit(housing_prepared, housing_labels)
#输出最佳超参数组合
grid_search.best_params_
#输出最佳模型
grid_search.best_estimator_
{'max_features': 8, 'n_estimators': 30}
RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,max_features=8, max_leaf_nodes=None,min_impurity_decrease=0.0, min_impurity_split=None,min_samples_leaf=1, min_samples_split=2,min_weight_fraction_leaf=0.0, n_estimators=30,n_jobs=None, oob_score=False, random_state=None,verbose=0, warm_start=False)

两个搜索组合,第一个组合有12种情况,第二个组合有6种情况,一共18种情况,因为是五折交叉验证,所以,一共要进行18×5=90轮训练!完成后,就能返回超参数的最佳组合best_params_最佳模型best_estimator_

我们还可以查看网格搜索中每一个属性组合的得分情况:

cvres = grid_search.cv_results_
#输出超参数组合对应的得分情况
for mean_score, params in zip(cvres["mean_test_score"], cvres["params"]):print(np.sqrt(-mean_score), params)
65092.02010981559 {'max_features': 2, 'n_estimators': 3}
55359.4496463199  {'max_features': 2, 'n_estimators': 10}
52556.54068544674 {'max_features': 2, 'n_estimators': 30}
59605.3135515846  {'max_features': 4, 'n_estimators': 3}
53164.81127753654 {'max_features': 4, 'n_estimators': 10}
50676.9400303366  {'max_features': 4, 'n_estimators': 30}
59477.454981964   {'max_features': 6, 'n_estimators': 3}
52493.72219849825 {'max_features': 6, 'n_estimators': 10}
50129.58037046355 {'max_features': 6, 'n_estimators': 30}
58975.88486428257 {'max_features': 8, 'n_estimators': 3}
51999.19103564533 {'max_features': 8, 'n_estimators': 10}
49948.7230892116  {'max_features': 8, 'n_estimators': 30}
61188.21752051931 {'bootstrap': False, 'max_features': 2, 'n_estimators': 3}
54352.52617644768 {'bootstrap': False, 'max_features': 2, 'n_estimators': 10}
60598.83975833867 {'bootstrap': False, 'max_features': 3, 'n_estimators': 3}
52874.9435481527  {'bootstrap': False, 'max_features': 3, 'n_estimators': 10}
59428.43938347719 {'bootstrap': False, 'max_features': 4, 'n_estimators': 3}
52007.58232013096 {'bootstrap': False, 'max_features': 4, 'n_estimators': 10}

可以看出{'max_features': 8, 'n_estimators': 30}这个超参数组合得分最高,我们成功的使用网格搜索法调整了超参数值,将误差值从52619降低到49948

4.2其他方法

网格搜索看起来就是穷举法,穷举的方式在组合数少的情况下还能用,组合多的话最好使用随机搜索RandomizedSearchCV,虽然不会尝试所有的组合,但是能抽取更多完全不同的组合情况,还能方便的设定搜索次数,控制计算量。

还有一种方法是集成法,将几个不同的最佳模型组合起来使用,这个方法在后面的章节深入讲解。

5.测试模型

现在,我们可以测试一下调整好的模型了。

测试集也要进行处理,类比本章第一节的操作:

  1. 将测试数据分为两个测试数据X_test和测试标签y_test

  2. 将测试数据进行数据整理将X_test转为X_test_prepared,注意,这里使用tranform函数而不是fit_transform函数。

最后对测试集进行预测,计算预测值和实际值的均方根误差就能得到最终的误差效果。

#获得最佳的模型
final_model = grid_search.best_estimator_
#分割测试数据和测试标签
X_test = strat_test_set.drop("median_house_value", axis=1)
y_test = strat_test_set["median_house_value"].copy()
#将测试数据进行整理(使用transform函数)
X_test_prepared = full_pipeline.transform(X_test)
#计算预测值
final_predictions = final_model.predict(X_test_prepared)
#计算预测值和测试标签的均方误差
final_mse = mean_squared_error(y_test, final_predictions)
#计算最终的均方根误差
final_rmse = np.sqrt(final_mse)
48154.525254070046

得出该模型的均方根误差为48154。

至此,我们机器学习项目的开发阶段就算告一段落了。

最后,就是项目的预上线了,我们需要向万达集团展示具体实施方案,然后给自己倒上一杯卡布奇诺。

希望这一章能告诉你机器学习项目是什么样的,你能用学到的工具训练一个好系统。

你已经看到,大部分的工作是数据准备步骤、搭建监测工具、建立人为评估的流水线和自动化定期模型训练。

当然,最好能了解整个过程、熟悉三或四种算法,而不是在探索高级算法上浪费全部时间,导致在全局上的时间不够。 因此,如果你还没做,现在最好拿起台电脑,选择一个感兴趣的数据集,将整个流程从头到尾完成一遍。

讲实话,能坚持学到这里的朋友,真的很优秀。学完了这一章,你的机器学习之路已经成功了一半。

接下来我们会对机器学习的各种算法进行具体的学习和实践。


欢迎来我的博客留言讨论,我的博客主页:LeonG的博客

本文参考自:《Hands-On Machine Learning with Scikit-Learn & TensorFlow机器学习实用指南》,感谢中文AI社区ApacheCN提供翻译。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

最后编辑于:2024-08-25 10:33:47


喜欢的朋友记得点赞、收藏、关注哦!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/432633.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Rust - 字符串:str 与 String

在其他语言中,字符串通常都会比较简单,例如 “hello, world” 就是字符串章节的几乎全部内容了。 但是Rust中的字符串与其他语言有所不同,若带着其他语言的习惯来学习Rust字符串,将会波折不断。 所以最好先忘记脑中已有的关于字…

LiveNVR监控流媒体Onvif/RTSP功能-支持电子放大拉框放大直播视频拉框放大录像视频流拉框放大电子放大

LiveNVR监控流媒体Onvif/RTSP功能-支持电子放大拉框放大直播视频拉框放大录像视频流拉框放大电子放大 1、视频广场2、录像回看3、RTSP/HLS/FLV/RTMP拉流Onvif流媒体服务 1、视频广场 视频广场 -》播放 ,左键单击可以拉取矩形框,放大选中的范围&#xff…

汽车总线之----FlexRay总线

Introduction 随着汽车智能化发展,车辆开发的ECU数量不断增加,人们对汽车系统的各个性能方面提出了更高的需求,比如更多的数据交互,更高的传输带宽等。现如今人们广泛接受电子功能来提高驾驶安全性,像ABS防抱死系统&a…

网络安全 DVWA通关指南 DVWA Weak Session IDs(弱会话)

DVWA Weak Session IDs(弱会话) 文章目录 DVWA Weak Session IDs(弱会话)Low LevelMedium LevelHigh LevelImpossible Level 参考文献 WEB 安全靶场通关指南 相关阅读 Brute Force (爆破) Command Injection(命令注入…

SpringSecurity-用户认证

1、用户认证 1.1 用户认证核心组件 我们系统中会有许多用户,确认当前是哪个用户正在使用我们系统就是登录认证的最终目的。这里我们就提取出了一个核心概念:当前登录用户/当前认证用户。整个系统安全都是围绕当前登录用户展开的,这个不难理…

基于Spring JDBC AbstractRoutingDataSource 实现动态数据源

AbstractRoutingDataSource 实现动态数据源 AbstractRoutingDataSource 即抽象的路由数据源,提供了动态数据源切换的机制。你可以通过实现它的 determineCurrentLookupKey() 方法,根据不同的条件返回对应的数据源 key,基于这点可以根据外部输…

C语言 fwirte 函数 - C语言零基础入门教程

目录 一.fwirte 函数简介二.fwirte 函数使用三.猜你喜欢 零基础 C/C 学习路线推荐 : C/C 学习目录 >> C 语言基础入门 一.fwirte 函数简介 C 语言文件读写,fread 函数用于读取文件中的数据到指定缓冲区中,而 fwrite 函数用于把缓冲区数据写入到文件…

从1岁活到80岁很平凡 chatgpt 到底能干啥

有人说:一个人从1岁活到80岁很平凡,但如果从80岁倒着活,那么一半以上的人都可能不凡。 生活没有捷径,我们踩过的坑都成为了生活的经验,这些经验越早知道,你要走的弯路就会越少。 Introduction ChatGPT是一款基于人工智能技术的聊天机器人,可以自动回复用户的问题和提供…

【算法题】72. 编辑距离-力扣(LeetCode)

【算法题】72. 编辑距离-力扣(LeetCode) 1.题目 下方是力扣官方题目的地址 72. 编辑距离 给你两个单词 word1 和 word2, 请返回将 word1 转换成 word2 所使用的最少操作数 。 你可以对一个单词进行如下三种操作: 插入一个字符删除一个字符替换一个…

公交IC卡收单管理系统 多处 SQL注入致RCE漏洞复现

0x01 产品简介 公交IC卡收单管理系统是城市公共交通领域中不可或缺的一部分,它通过集成先进的集成电路技术(IC卡)实现了乘客便捷的支付方式,并有效提高了公共交通运营效率。系统集成了发卡、充值、消费、数据采集、查询和注销等多个功能模块,为公交公司和乘客提供了全面、…

使用shardingsphere实现mysql数据库分片

在大数据时代,随着业务数据量的不断增长,单一的数据库往往难以承载大规模的数据处理需求。数据库分片(Sharding)是一种有效的数据库扩展技术,通过将数据分布到多个数据库实例上,提高系统的性能和可扩展性。…

详细解读,F5服务器负载均衡的技术优势

在现代大规模、高流量的网络使用场景中,为应对高并发和海量数据的挑战,服务器负载均衡技术应运而生。但凡知道服务器负载均衡这一名词的,基本都对F5有所耳闻,因为负载均衡正是F5的代表作,换句通俗易懂的话来说&#xf…

曲面构件的布尔运算

1.前言 布尔运算算法有多种,可以根据几何数据表达方式分为Brep布尔运算、CSG布尔运算、网格布尔运算等,而网格布尔运算又又多种,如BSP方式、八叉树方式,博主实现过Brep布尔运算、BSP和八叉树两种网格布尔运算。详细可参考博主文章…

threejs加载高度图渲染点云,不支持tiff

问题点 使用的point来渲染高度图点云&#xff0c;大数据图片无效渲染点多&#xff08;可以通过八叉树过滤掉无效点增加效率&#xff0c;这个太复杂&#xff09;&#xff0c;但是胜在简单能用 效果图 code 代码可运行&#xff0c;无需npm <!DOCTYPE html> <html la…

Springboot + netty + rabbitmq + myBatis+mysql流量消峰

目录 0.为什么用消息队列1.代码文件创建结构2.pom.xml文件3.三个配置文件开发和生产环境4.Rabbitmq 基础配置类 TtlQueueConfig5.建立netty服务器 + rabbitmq消息生产者6.建立常规队列的消费者 Consumer7.建立死信队列的消费者 DeadLetterConsumer8.建立mapper.xml文件9.建立ma…

使用 Higress AI 插件对接通义千问大语言模型

前言 什么是 AI Gateway AI Gateway 的定义是 AI Native 的 API Gateway&#xff0c;是基于 API Gateway 的能⼒来满⾜ AI Native 的需求。例如&#xff1a; 将传统的 QPS 限流扩展到 token 限流。将传统的负载均衡/重试/fallback 能力延伸&#xff0c;支持对接多个大模型厂…

Xcode16 iOS18 编译问题适配

问题1&#xff1a;ADClient编译报错问题 报错信息 Undefined symbols for architecture arm64:"_OBJC_CLASS_$_ADClient", referenced from:in ViewController.o ld: symbol(s) not found for architecture arm64 clang: error: linker command failed with exit co…

【Redis】初识 Redis

&#x1f970;&#x1f970;&#x1f970;来都来了&#xff0c;不妨点个关注叭&#xff01; &#x1f449;博客主页&#xff1a;欢迎各位大佬!&#x1f448; 文章目录 1. Redis是什么2. 浅谈分布式3. Redis的特性3.1 在内存中存储3.2 可编程性3.3 扩展性3.4 持久化3.5 集群3.6 …

C++ 刷题 使用到的一些有用的容器和函数

优先队列 c优先队列priority_queue&#xff08;自定义比较函数&#xff09;_c优先队列自定义比较-CSDN博客 373. 查找和最小的 K 对数字 - 力扣&#xff08;LeetCode&#xff09; 官方题解&#xff1a; class Solution { public:vector<vector<int>> kSmallestP…

如何检测并阻止机器人活动

恶意机器人流量逐年增加&#xff0c;占 2023 年所有互联网流量的近三分之一。恶意机器人会访问敏感数据、实施欺诈、窃取专有信息并降低网站性能。新技术使欺诈者能够更快地发动攻击并造成更大的破坏。机器人的无差别和大规模攻击对所有行业各种规模的企业都构成风险。 但您的…