sklearn中使用决策树

1.示例

criterion可以是信息熵,entropy,可以是基尼系数gini

# -*-coding:utf-8-*-
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
wine=load_wine()# print ( wine.feature_names )
#(178, 13)
print(wine.data.shape)Xtrain,Xtest,Ytrain,Ytest=train_test_split(wine.data,wine.target,test_size=0.3)#random_state=30:输入任意整数,会一直长同一棵树,让模型稳定下来
clf=tree.DecisionTreeClassifier(criterion="entropy",random_state=30,splitter="best")
# clf=tree.DecisionTreeClassifier(criterion="entropy")
clf=clf.fit(Xtrain,Ytrain)
#返回预测准确度accuracy
score=clf.score(Xtest,Ytest)print( score )import graphviz
dot_data=tree.export_graphviz(clf,feature_names=wine.feature_names,class_names=["wine1","wine2","wine3"],filled=True,rounded=True)
graph=graphviz.Source(dot_data)
#生成pdf文件
graph.render(view=True, format="pdf", filename="tree_pdf")
print ( graph )
#feature_importances_:每个特征在决策树中的重要成都
print(clf.feature_importances_)
print ( [*zip(wine.feature_names,clf.feature_importances_)] )

决策树生成的pdf 

 2.示例

max_depth:这参数用来控制决策树的最大深度。以下示例,构建1~10深度的决策时,看哪个深度的决策树的精确率(score)高

# -*-coding:utf-8-*-
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as pltplt.switch_backend("TkAgg")wine=load_wine()# print ( wine.feature_names )
#(178, 13)
print(wine.data.shape)import pandas as pd
# print (pd.concat([pd.DataFrame(wine.data),pd.DataFrame(wine.target)],axis=1))
#所有的train,test必须是二维矩阵
Xtrain,Xtest,Ytrain,Ytest=train_test_split(wine.data,wine.target,test_size=0.3)test=[]
bestScore=-1
bestClf=None
for i in range(10):clf=tree.DecisionTreeClassifier(max_depth=i+1,criterion="entropy",random_state=30,splitter="random")clf=clf.fit(Xtrain,Ytrain)score=clf.score(Xtest,Ytest)test.append(score)if score>bestScore:bestScore=scorebestClf=clf
print(test)
print(test.index(bestScore))
#predict返回每个测试样本的分类/回归结果
predicted=bestClf.predict(Xtest)
print(predicted)#返回每个测试样本的叶子节点的索引
leaf=bestClf.apply(Xtest)
print(leaf)plt.plot(range(1,11),test,color="red",label="max_depth")
plt.legend()
plt.show()

结果:

(178, 13)
[0.5555555555555556, 0.8148148148148148, 0.9444444444444444, 0.9259259259259259, 0.8518518518518519, 0.8333333333333334, 0.8333333333333334, 0.8333333333333334, 0.8333333333333334, 0.8333333333333334]
2
[0 1 0 1 2 0 1 1 1 2 2 0 0 2 0 1 1 0 0 0 0 1 1 0 2 1 0 2 2 1 2 1 1 1 1 0 12 2 0 1 1 2 0 2 1 1 0 1 1 2 1 2 2]
[12  7 12 11  3 12  7  7  4  3  3 12 12  3 12  9  7 12 12 12 12  7  9 123  9 12  3  3  4  3  4  7  7  7 12  7  3  3 12  9  9  3 12  3  7  7 127  7  3  7  3  3]

3.交叉熵验证的示例 

# -*-coding:utf-8-*-
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor
import sklearn
from sklearn.datasets import fetch_california_housinghousing=fetch_california_housing()
# print(housing)
# print(housing.data)
# print(housing.target)regressor=DecisionTreeRegressor(random_state=0)#cv=10,10次交叉验证,default:cv=5
#scoring="neg_mean_squared_error",评价指标是负的均方误差
cross_res=cross_val_score(regressor,housing.data,housing.target,scoring="neg_mean_squared_error",cv=10)
print(cross_res)
[-1.30551334 -0.78405711 -0.72809865 -0.50413232 -0.79683323 -0.83698199-0.56591889 -1.03621067 -1.02786488 -0.51371889]

4.Titanic生存者预测

数据来源:

Titanic - Machine Learning from Disaster | Kaggle

数据预处理

读取数据 

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
#---------设置pd,在pycharm中显示完全表格-------
pd.set_option('display.max_columns', 1000)
pd.set_option('display.width', 1000)
pd.set_option('display.max_colwidth', 1000)
#----------------------------------------
data=pd.read_csv("./data.csv")
print (data.head(5))
print(data.info())
PassengerId  Survived  Pclass                                                 Name     Sex   Age  SibSp  Parch            Ticket     Fare Cabin Embarked
0            1         0       3                              Braund, Mr. Owen Harris    male  22.0      1      0         A/5 21171   7.2500   NaN        S
1            2         1       1  Cumings, Mrs. John Bradley (Florence Briggs Thayer)  female  38.0      1      0          PC 17599  71.2833   C85        C
2            3         1       3                               Heikkinen, Miss. Laina  female  26.0      0      0  STON/O2. 3101282   7.9250   NaN        S
3            4         1       1         Futrelle, Mrs. Jacques Heath (Lily May Peel)  female  35.0      1      0            113803  53.1000  C123        S
4            5         0       3                             Allen, Mr. William Henry    male  35.0      0      0            373450   8.0500   NaN        S
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):#   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  0   PassengerId  891 non-null    int64  1   Survived     891 non-null    int64  2   Pclass       891 non-null    int64  3   Name         891 non-null    object 4   Sex          891 non-null    object 5   Age          714 non-null    float646   SibSp        891 non-null    int64  7   Parch        891 non-null    int64  8   Ticket       891 non-null    object 9   Fare         891 non-null    float6410  Cabin        204 non-null    object 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
NoneProcess finished with exit code 0

筛选特征

data.drop(["Cabin","Name","Ticket"],inplace=True,axis=1)
print(data.head())
print(data.info())
   PassengerId  Survived  Pclass     Sex   Age  SibSp  Parch     Fare Embarked
0            1         0       3    male  22.0      1      0   7.2500        S
1            2         1       1  female  38.0      1      0  71.2833        C
2            3         1       3  female  26.0      0      0   7.9250        S
3            4         1       1  female  35.0      1      0  53.1000        S
4            5         0       3    male  35.0      0      0   8.0500        S
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 9 columns):#   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  0   PassengerId  891 non-null    int64  1   Survived     891 non-null    int64  2   Pclass       891 non-null    int64  3   Sex          891 non-null    object 4   Age          714 non-null    float645   SibSp        891 non-null    int64  6   Parch        891 non-null    int64  7   Fare         891 non-null    float648   Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(2)
memory usage: 62.8+ KB
None

处理缺失值

#年龄用均值填补
data["Age"]=data["Age"].fillna(data["Age"].mean())
print(data.info())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 9 columns):#   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  0   PassengerId  891 non-null    int64  1   Survived     891 non-null    int64  2   Pclass       891 non-null    int64  3   Sex          891 non-null    object 4   Age          891 non-null    float645   SibSp        891 non-null    int64  6   Parch        891 non-null    int64  7   Fare         891 non-null    float648   Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(2)
memory usage: 62.8+ KB
None
#删除有缺失值的行,Embarked缺了两行
data=data.dropna()
print(data.info())
<class 'pandas.core.frame.DataFrame'>
Int64Index: 889 entries, 0 to 890
Data columns (total 9 columns):#   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  0   PassengerId  889 non-null    int64  1   Survived     889 non-null    int64  2   Pclass       889 non-null    int64  3   Sex          889 non-null    object 4   Age          889 non-null    float645   SibSp        889 non-null    int64  6   Parch        889 non-null    int64  7   Fare         889 non-null    float648   Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(2)
memory usage: 69.5+ KB
None

处理非数值的列

查看非数值列的所有值

print(data["Embarked"].unique())
print(data["Sex"].unique())#------------结果如下----------
['S' 'C' 'Q']
['male' 'female']
labels=data["Embarked"].unique().tolist()
#x代表data[Embarked]的每一行的值,S-->0,C-->1,Q-->2
data["Embarked"]=data["Embarked"].apply(lambda x:labels.index(x))#把条件为True的转为int行
#也可以这样写:data.loc[:,"Sex"]=(data["Sex"]=="male").astype("int")
#male-->0,female-->1
data["Sex"]=(data["Sex"]=="male").astype("int")

提取数据

x=data.iloc[:, data.columns!="Survived"]
y=data.iloc[:,data.columns=="Survived"]#Xtrain:(622, 8)
#划分数据集和测试集
from sklearn.model_selection import train_test_split
Xtrain,Xtest,Ytrain,Ytest=train_test_split(x,y,test_size=0.3)#把索引变为从0~622
for i in [Xtrain,Xtest,Ytrain,Ytest]:i.index=range(i.shape[0])

第一种方法构建决策树

# clf=DecisionTreeClassifier(random_state=25)
# clf=clf.fit(Xtrain,Ytrain)
# score=clf.score(Xtest,Ytest)
# print(score)
from sklearn.model_selection import cross_val_score
# clf=DecisionTreeClassifier(random_state=25)
# score=cross_val_score(clf,x,y,cv=10).mean()
# print(score)tr=[]
te=[]
for i in range(10):clf=DecisionTreeClassifier(random_state=25,max_depth=i+1,criterion="entropy")clf=clf.fit(Xtrain,Ytrain)score_tr=clf.score(Xtrain,Ytrain)score_te=cross_val_score(clf,x,y,cv=10).mean()tr.append(score_tr)te.append(score_te)
print(max(te))
plt.plot(range(1,11),tr,color="red",label="train")
plt.plot(range(1,11),te,color="blue",label="test")
#1~10全部显示
plt.xticks(range(1,11))
plt.legend()
plt.show()

不同深度的决策树的测试集和训练集的表现 

 第二种方法构建决策树

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
plt.switch_backend("TkAgg")
from sklearn.model_selection import GridSearchCV
import numpy as np#---------设置pd,在pycharm中显示完全表格-------
pd.set_option('display.max_columns', 1000)
pd.set_option('display.width', 1000)
pd.set_option('display.max_colwidth', 1000)
#----------------------------------------
data=pd.read_csv("./data.csv")
# print (data.head(5))
# print(data.info())#去掉姓名、Cabin、票号的特征
data.drop(["Cabin","Name","Ticket"],inplace=True,axis=1)
# print(data.head())
# print(data.info())#处理缺失值
#年龄用均值填补
data["Age"]=data["Age"].fillna(data["Age"].mean())
# print(data.info())#删除有缺失值的行,Embarked缺了两行,所有的数据去掉不完整的行
data=data.dropna()
# print(data.info())# print(data["Embarked"].unique())
# print(data["Sex"].unique())labels=data["Embarked"].unique().tolist()
#x代表data[Embarked]的每一行的值,S-->0,C-->1,Q-->2
data["Embarked"]=data["Embarked"].apply(lambda x:labels.index(x))#把条件为True的转为int行
#也可以这样写:data.loc[:,"Sex"]=(data["Sex"]=="male").astype("int")
#male-->0,female-->1
data["Sex"]=(data["Sex"]=="male").astype("int")x=data.iloc[:, data.columns!="Survived"]
y=data.iloc[:,data.columns=="Survived"]#Xtrain:(622, 8)
#划分数据集和测试集
from sklearn.model_selection import train_test_split
Xtrain,Xtest,Ytrain,Ytest=train_test_split(x,y,test_size=0.3)#把索引变为从0~622
for i in [Xtrain,Xtest,Ytrain,Ytest]:i.index=range(i.shape[0])from sklearn.model_selection import cross_val_scoreclf=DecisionTreeClassifier(random_state=25)
#GridSearchCV:满足fit,score,交叉验证三个功能
#parameters:一串参数和这些参数对应的,我们希望网格搜索来搜索对应的参数的取值范围
parameters={"criterion":("gini","entropy"),"splitter":("best","random"),"max_depth":[*range(1,10)],"min_samples_leaf":[*range(1,50,5)],"min_impurity_decrease":[*np.linspace(0,0.5,20)]
}
GS=GridSearchCV(clf,parameters,cv=10)
gs=GS.fit(Xtrain,Ytrain)#从输入的参数和参数取值中,返回最佳组合
print(gs.best_params_)#网格搜索后的模型的评判标准
print(gs.best_score_)
{'criterion': 'entropy', 'max_depth': 3, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 1, 'splitter': 'best'}
0.8297235023041475

这种方法构建的决策树的准确率比第一种的还低

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/85232.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【uniapp】uniapp打包H5(网页端):

文章目录 一、设置appid&#xff1a;二、设置router&#xff1a;三、打包&#xff1a;【1】[CLI 发行uni-app到H5&#xff1a;https://hx.dcloud.net.cn/cli/publish-h5](https://hx.dcloud.net.cn/cli/publish-h5)【2】HBuilderX 四、最终效果&#xff1a; 一、设置appid&…

(三)Node.js - 模块化

1. Node.js中的模块化 Node.js中根据模块来源不同&#xff0c;将模块分为了3大类&#xff0c;分别是&#xff1a; 内置模块&#xff1a;内置模块由Node.js官方提供的&#xff0c;例如fs、path、http等自定义模块&#xff1a;用户创建的每个.js文件&#xff0c;都是自定义模块…

Flink-串讲面试题

1. 概念 有状态的流式计算框架 可以处理源源不断的实时数据&#xff0c;数据以event为单位&#xff0c;就是一条数据。 2. 开发流程 先获取执行环境env&#xff0c;然后添加source数据源&#xff0c;转换成datastream&#xff0c;然后使用各种算子进行计算&#xff0c;使用s…

【从零学习python 】12.Python字符串操作与应用

文章目录 学习目标字符串介绍字符串表示方式小总结转义字符 下标和切片一、 下标/索引1. 如果想取出部分字符&#xff0c;那么可以通过下标的方法&#xff0c;&#xff08;注意在计算机中&#xff0c;下标从 0 开始&#xff09;2. 遍历3. 切片 进阶案例 学习目标 字符串的表示…

深度学习和OpenCV的对象检测(MobileNet SSD图像识别)

基于深度学习的对象检测时,我们主要分享以下三种主要的对象检测方法: Faster R-CNN(后期会来学习分享)你只看一次(YOLO,最新版本YOLO3,后期我们会分享)单发探测器(SSD,本节介绍,若你的电脑配置比较低,此方法比较适合R-CNN是使用深度学习进行物体检测的训练模型; 然而,…

JDBC数据库连接

目录 引言 一&#xff0c;基本概念 二&#xff0c;常用操作步骤 三&#xff0c;连接操作 引言 JDBC(Java DataBase Connectivity,java数据库连接)是一种用于执行SQL语句的Java API&#xff0c;可以为多种 关系数据库提供统一访问&#xff0c;它由一组用Java语言编写的类和接口…

【GPT-3 】创建能写博客的AI工具

一、说明 如何使用OpenAI API&#xff0c;GPT-3和Python创建AI博客写作工具。 在本教程中&#xff0c;我们将从 OpenAI API 中断的地方继续&#xff0c;并创建我们自己的 AI 版权工具&#xff0c;我们可以使用它使用 GPT-3 人工智能 &#xff08;AI&#xff09; API 创建独特的…

js玩儿爬虫

前言 提到爬虫可能大多都会想到python&#xff0c;其实爬虫的实现并不限制任何语言。 下面我们就使用js来实现&#xff0c;后端为express&#xff0c;前端为vue3。 实现功能 话不多说&#xff0c;先看结果&#xff1a; 这是项目链接&#xff1a;https://gitee.com/xi1213/w…

时序预测 | MATLAB实现BO-LSTM贝叶斯优化长短期记忆神经网络时间序列预测

时序预测 | MATLAB实现BO-LSTM贝叶斯优化长短期记忆神经网络时间序列预测 目录 时序预测 | MATLAB实现BO-LSTM贝叶斯优化长短期记忆神经网络时间序列预测效果一览基本介绍模型搭建程序设计参考资料 效果一览 基本介绍 MATLAB实现BO-LSTM贝叶斯优化长短期记忆神经网络时间序列预…

无涯教程-Perl - lock函数

描述 此函数将咨询锁放在共享变量或THING中包含的引用对象上,直到该锁超出范围。 lock()是一个"弱关键字":这意味着,如果您在调用该函数之前已通过该名称定义了该函数,则将改为调用该函数。 语法 以下是此函数的简单语法- lock THING返回值 此函数不返回任何值…

校对软件在司法系统中的应用:加强刑事文书审查

校对软件在司法系统中的应用可以加强刑事文书审查&#xff0c;提高文书的准确性和可靠性。 以下是校对软件在刑事文书审查方面的应用&#xff1a; 1.语法和拼写检查&#xff1a;校对软件可以自动检查刑事文书中的语法错误和拼写错误。这包括句子结构、主谓一致、动词形式等方面…

Nginx启动报错- Failed to start The nginx HTTP and reverse proxy server

根据日志&#xff0c;仍然出现 “bind() to 0.0.0.0:8888 failed (13: Permission denied)” 错误。这意味着 Nginx 仍然无法绑定到 8888 端口&#xff0c;即使使用 root 权限。 请执行以下操作来进一步排查问题&#xff1a; 确保没有其他进程占用 8888 端口&#xff1a;使用以…

【Tomcat】tomcat的多实例和动静分离

多实例&#xff1a; 在一台服务器上有多台Tomcat&#xff1b;就算是多实例 安装telnet服务&#xff0c;可以用来测试端口通信是否正常 yum -y install telnettelnet 192.168.220.112 80 tomcat的日志文件 cd /usr/local/tomcat/logsvim catalina.out Tomcat多实例部署&…

[免费在线] 将 PDF 转换为 Excel 或 Excel 转换为 PDF | 5 工具

有了免费的在线 PDF 转换器&#xff0c;您可以轻松免费在线将 PDF 转换为 Excel 或 Excel 转换为 PDF。这篇文章为您筛选了 5 个最常用的工具。要从存储介质恢复错误删除或丢失的 PDF 文档、Excel 电子表格、Word 文件或任何其他文件&#xff0c;您可以使用免费的数据恢复程序 …

vscode-启动cljs

打开vscode &#xff0c;打开cljs项目文件 先npm installvscode安装插件Calva: Clojure & ClojureScript启动REPL 选择Start yout project with a REPL and connect(a.k.a. jack) 后选择shadow-cljs&#xff0c;然后选择shadow&#xff0c;如果需要选择build的话&#xf…

海外电子商务源代码跨境系统开发,Java现成代码全开源

海外电子商务跨境系统的开发是一个复杂的过程&#xff0c;而利用现成的Java代码进行开发可以节省时间和成本。下面是海外电子商务跨境系统开发的全开源步骤。 第一步&#xff1a;需求分析和规划 在开发海外电子商务跨境系统之前&#xff0c;需要进行需求分析和规划。这包括确定…

MySQL多表连接查询3

目录 表结构 创建表 表数据 查询需求&#xff1a; 1.查询student表的所有记录 2.查询student表的第2条到4条记录 3.从student表查询所有学生的学号&#xff08;id&#xff09;、姓名&#xff08;name&#xff09;和院系&#xff08;department&#xff09;的信息 4.从s…

React使用antd的图片预览组件,点击哪个图片就预览哪个的设置

使用了官方推荐的相册模式的预览&#xff0c;但是点击预览之后&#xff0c;每次都是从图片列表的第一张开始预览&#xff0c;而不是点击哪张就从哪张开始预览&#xff1a; 所以这里我就封装了一下&#xff0c;对初始化预览的列表进行了逻辑处理&#xff1a; 当点击开始预览的…

竞赛项目 深度学习的水果识别 opencv python

文章目录 0 前言2 开发简介3 识别原理3.1 传统图像识别原理3.2 深度学习水果识别 4 数据集5 部分关键代码5.1 处理训练集的数据结构5.2 模型网络结构5.3 训练模型 6 识别效果7 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 深度学习…

MongoDB 备份与恢复

1.1 MongoDB的常用命令 mongoexport / mongoimport mongodump / mongorestore 有以上两组命令在备份与恢复中进行使用。 1.1.1 导出工具mongoexport Mongodb中的mongoexport工具可以把一个collection导出成JSON格式或CSV格式的文件。可以通过参数指定导出的数据项&#xff0c…