3.22【机器学习】决策树作业代码实现

4.1由于决策树只在样本同属于一类或者所有特征值都用完或缺失时生成叶节点,同一节点的样本,在路径上的特征值都相同,而训练集中又没有冲突数据,所以必定存在训练误差为0的决策树

4.2使用最小训练误差会导致过拟合,使得学习模型泛化能力下降。

4.3

青绿0,乌黑1,浅白2

蜷缩0,稍蜷1,硬挺2

浊响0,沉闷1,清脆2

清晰0,烧糊1,模胡2

凹陷0,稍凹1,平坦2

硬滑0,软粘1

import numpy as np
import treePlotter
np.random.seed(100)
class DecisionTreeClassifier:def __init__(self,tree_type='ID3',k_classes=2):self.tree_type=tree_typeself.k_classes=k_classesif tree_type=='ID3':self.gain_func=self.Gainelif tree_type=='CART':self.gain_func=self.GiniIndexelif tree_type=='C45':self.gain_func=self.GainRatioelse:raise ValueError('must be ID3 or CART or C45')self.tree=Nonedef fit(self,X,y):D={}D['X']=XD['y']=yA=np.arange(X.shape[1])aVs={}for a in A:aVs[a]=np.unique(X[:,a])self.tree=self.TreeGenerate(D,A,aVs)def predict(self,X):if self.tree is None:raise RuntimeError('cant predict before fit')y_pred=[]for i in range(X.shape[0]):tree = self.treex=X[i]while True:if not isinstance(tree,dict):y_pred.append(tree)breaka=list(tree.keys())[0]tree=tree[a]if isinstance(tree,dict):val = x[a]tree = tree[val]else:y_pred.append(tree)breakreturn np.array(y_pred)def TreeGenerate(self,D,A,aVs):X=D['X']y=D['y']# 情形1unique_classes=np.unique(y)if len(unique_classes)==1:return unique_classes[0]flag=Truefor a in A:if(len(np.unique(X[:,a]))>1):flag=Falsebreak# 情形2if flag:return np.argmax(np.bincount(y))gains=np.zeros((len(A),))if self.tree_type=='C45':gains=np.zeros((len(A),2))for i in range(len(A)):gains[i]=self.gain_func(D,A[i])#print(gains)subA=Noneif self.tree_type=='CART':a_best=A[np.argmin(gains)]subA=np.delete(A,np.argmin(gains))elif self.tree_type=='ID3':a_best=A[np.argmax(gains)]subA=np.delete(A,np.argmax(gains))elif self.tree_type=='C45':gain_mean=np.mean(gains[:,0])higher_than_mean_indices=np.where(gains[:,0]>=gain_mean)higher_than_mean=gains[higher_than_mean_indices,1][0]index=higher_than_mean_indices[0][np.argmax(higher_than_mean)]a_best=A[index]subA=np.delete(A,index)tree={a_best:{}}for av in aVs[a_best]:indices=np.where(X[:,a_best]==av)Dv={}Dv['X']=X[indices]Dv['y']=y[indices]if len(Dv['y'])==0:tree[a_best][av]=np.argmax(np.bincount(y))else:tree[a_best][av]=self.TreeGenerate(Dv,subA,aVs)return tree@classmethoddef Ent(cls,D):y=D['y']bin_count=np.bincount(y)total=len(y)ent=0.for k in range(len(bin_count)):p_k=bin_count[k]/totalif p_k!=0:ent+=p_k*np.log2(p_k)return -ent@classmethoddef Gain(cls,D,a):X=D['X']y=D['y']aV=np.unique(X[:,a])sum=0.for v in range(len(aV)):Dv={}indices=np.where(X[:,a]==aV[v])Dv['X']=X[indices]Dv['y']=y[indices]ent=cls.Ent(Dv)sum+=(len(Dv['y'])/len(y)*ent)gain=cls.Ent(D)-sumreturn gain@classmethoddef Gini(cls,D):y = D['y']bin_count = np.bincount(y)total = len(y)ent = 0.for k in range(len(bin_count)):p_k = bin_count[k] / totalent+=p_k**2return 1-ent@classmethoddef GiniIndex(cls,D,a):X = D['X']y = D['y']aV = np.unique(X[:, a])sum = 0.for v in range(len(aV)):Dv = {}indices = np.where(X[:, a] == aV[v])Dv['X'] = X[indices]Dv['y'] = y[indices]ent = cls.Gini(Dv)sum += (len(Dv['y']) / len(y) * ent)gain = sumreturn gain@classmethoddef GainRatio(cls,D,a):X = D['X']y = D['y']aV = np.unique(X[:, a])sum = 0.intrinsic_value=0.for v in range(len(aV)):Dv = {}indices = np.where(X[:, a] == aV[v])Dv['X'] = X[indices]Dv['y'] = y[indices]ent = cls.Ent(Dv)sum += (len(Dv['y']) / len(y) * ent)intrinsic_value+=(len(Dv['y'])/len(y))*np.log2(len(Dv['y'])/len(y))gain = cls.Ent(D) - sumintrinsic_value=-intrinsic_valuegain_ratio=gain/intrinsic_valuereturn np.array([gain,gain_ratio])if __name__=='__main__':watermelon_data = np.array([[0, 0, 0, 0, 0, 0], [1, 0, 1, 0, 0, 0],[1, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0],[2, 0, 0, 0, 0, 0], [0, 1, 0, 0, 1, 1],[1, 1, 0, 1, 1, 1], [1, 1, 0, 0, 1, 0],[1, 1, 1, 1, 1, 0], [0, 2, 2, 0, 2, 1],[2, 2, 2, 2, 2, 0], [2, 0, 0, 2, 2, 1],[0, 1, 0, 1, 0, 0], [2, 1, 1, 1, 0, 0],[1, 1, 0, 0, 1, 1], [2, 0, 0, 2, 2, 0],[0, 0, 1, 1, 1, 0]])label = np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])X_test=np.array([[0, 0, 1, 0, 0, 0], [1, 0, 1, 0, 0, 0],[1, 1, 0, 1, 1, 0], [1, 0, 1, 1, 1, 0],[1, 1, 0, 0, 1, 1], [2, 0, 0, 2, 2, 0],[0, 0, 1, 1, 1, 0]])decision_clf=DecisionTreeClassifier(tree_type='ID3')decision_clf.fit(watermelon_data,label)print(decision_clf.tree)treePlotter.createPlot(decision_clf.tree)y_pred=decision_clf.predict(X_test)print('y_pred:',y_pred)
import matplotlib.pyplot as plt
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['FangSong'] 
mpl.rcParams['axes.unicode_minus'] = False 
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")def plotNode(nodeTxt, centerPt, parentPt, nodeType):createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \xytext=centerPt, textcoords='axes fraction', \va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)def getNumLeafs(myTree):numLeafs = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':numLeafs += getNumLeafs(secondDict[key])else:numLeafs += 1return numLeafsdef getTreeDepth(myTree):maxDepth = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':thisDepth = getTreeDepth(secondDict[key]) + 1else:thisDepth = 1if thisDepth > maxDepth:maxDepth = thisDepthreturn maxDepthdef plotMidText(cntrPt, parentPt, txtString):xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]createPlot.ax1.text(xMid, yMid, txtString)def plotTree(myTree, parentPt, nodeTxt):numLeafs = getNumLeafs(myTree)depth = getTreeDepth(myTree)firstStr = list(myTree.keys())[0]cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalw, plotTree.yOff)plotMidText(cntrPt, parentPt, nodeTxt)plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalDfor key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':plotTree(secondDict[key], cntrPt, str(key))else:plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalwplotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalDdef createPlot(inTree):fig = plt.figure(1, facecolor='white')fig.clf()axprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)plotTree.totalw = float(getNumLeafs(inTree))plotTree.totalD = float(getTreeDepth(inTree))plotTree.xOff = -0.5 / plotTree.totalwplotTree.yOff = 1.0plotTree(inTree, (0.5, 1.0), '')plt.show()

4.5

import numpy as np
import treePlotter
import sklearn.datasets as datasets
from sklearn.metrics import mean_squared_error
import sklearn.tree as tree
import graphvizclass DecisionTreeRegressor:def __init__(self, min_samples_split=3,min_samples_leaf=1,random_state=False):self.min_samples_split=min_samples_splitself.min_samples_leaf=min_samples_leafself.random=random_stateself.tree = Nonedef fit(self, X, y):D = {}D['X'] = XD['y'] = yA = np.arange(X.shape[1])self.tree = self.TreeGenerate(D, A)def predict(self, X):if self.tree is None:raise RuntimeError('cant predict before fit')y_pred = []for i in range(X.shape[0]):tree = self.treex = X[i]while True:if not isinstance(tree, dict):y_pred.append(tree)breaka = list(tree.keys())[0]tree = tree[a]if isinstance(tree, dict):val = x[a]split_val=float(list(tree.keys())[0][1:])if val<=split_val:tree=tree[list(tree.keys())[0]]else:tree=tree[list(tree.keys())[1]]else:y_pred.append(tree)breakreturn np.array(y_pred)def TreeGenerate(self, D, A):X = D['X']y = D['y']if len(y)<=self.min_samples_split:return np.mean(y)split_j=Nonesplit_s=Nonemin_val=1.e10select_A=Aif self.random is True:d=len(A)select_A=np.random.choice(A,size=int(d//2),replace=False)for j in select_A:for s in np.unique(X[:,j]):left_indices=np.where(X[:,j]<=s)[0]right_indices=np.where(X[:,j]>s)[0]if len(left_indices)<self.min_samples_leaf or len(right_indices)<self.min_samples_leaf:continueval=np.sum((y[left_indices]-np.mean(y[left_indices]))**2)+np.sum((y[right_indices]-np.mean(y[right_indices]))**2)if val<min_val:split_j=jsplit_s=smin_val=valif split_j is None:return np.mean(y)tree = {split_j: {}}left_indices=np.where(X[:,split_j]<=split_s)[0]right_indices=np.where(X[:,split_j]>split_s)[0]D_left, D_right = {},{}D_left['X'],D_left['y'] = X[left_indices],y[left_indices]D_right['X'],D_right['y']=X[right_indices],y[right_indices]tree[split_j]['l'+str(split_s)]=self.TreeGenerate(D_left,A)tree[split_j]['r'+str(split_s)]=self.TreeGenerate(D_right,A)# 当前节点值tree[split_j]['val']=np.mean(y)return treeif __name__=='__main__':breast_data = datasets.load_boston()X, y = breast_data.data, breast_data.targetX_train, y_train = X[:200], y[:200]X_test, y_test = X[200:], y[200:]decisiontree_reg=DecisionTreeRegressor(min_samples_split=20,min_samples_leaf=5)decisiontree_reg.fit(X_train,y_train)print(decisiontree_reg.tree)treePlotter.createPlot(decisiontree_reg.tree)y_pred=decisiontree_reg.predict(X_test)print('tinyml mse:',mean_squared_error(y_test,y_pred))sklearn_reg=tree.DecisionTreeRegressor(min_samples_split=20,min_samples_leaf=5,random_state=False)sklearn_reg.fit(X_train,y_train)print(sklearn_reg.feature_importances_)sklearn_pred=sklearn_reg.predict(X_test)print('sklearn mse:',mean_squared_error(y_test,sklearn_pred))dot_data=tree.export_graphviz(sklearn_reg,out_file=None)graph=graphviz.Source(dot_data)

4.9

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/481290.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言:C语言实现对MySQL数据库表增删改查功能

基础DOME可以用于学习借鉴&#xff1b; 具体代码 #include <stdio.h> #include <mysql.h> // mysql 文件&#xff0c;如果配置ok就可以直接包含这个文件//宏定义 连接MySQL必要参数 #define SERVER "localhost" //或 127.0.0.1 #define USER "roo…

Spark基本命令详解

文章目录 Spark基本命令详解一、引言二、Spark Core 基本命令1、Transformations&#xff08;转换操作&#xff09;1.1、groupBy(func)1.2、filter(func) 2、Actions&#xff08;动作操作&#xff09;2.1、distinct([numTasks])2.2、sortBy(func, [ascending], [numTasks]) 三、…

github webhooks 实现网站自动更新

本文目录 Github Webhooks 介绍Webhooks 工作原理配置与验证应用云服务器通过 Webhook 自动部署网站实现复制私钥编写 webhook 接口Github 仓库配置 webhook以服务的形式运行 app.py Github Webhooks 介绍 Webhooks是GitHub提供的一种通知方式&#xff0c;当GitHub上发生特定事…

蓝桥杯模拟题不知名题目

题目:p是一个质数&#xff0c;但p是n的约数。将p称为是n的质因数。求2024最大质因数。 #include<iostream> #include<algorithm> using namespace std; bool fun(int x) {for(int i 2 ; i * i < x ; i){if(x % i 0)return false;}return true; } int main() …

【从零开始的LeetCode-算法】3264. K 次乘运算后的最终数组 I

给你一个整数数组 nums &#xff0c;一个整数 k 和一个整数 multiplier 。 你需要对 nums 执行 k 次操作&#xff0c;每次操作中&#xff1a; 找到 nums 中的 最小 值 x &#xff0c;如果存在多个最小值&#xff0c;选择最 前面 的一个。将 x 替换为 x * multiplier 。 请你…

Python 爬虫指定数据提取【Xpath】

Xpath 是一个非常好用的解析方法&#xff0c;使用前需要安装对应的库&#xff0c;这个自行搜索&#xff0c;很简单&#xff01; 示例代码 from lxml import etree text <div><ul><li class"item-0"><a href"link1.html">first …

C++设计模式(观察者模式)

一、介绍 1.动机 在软件构建过程中&#xff0c;我们需要为某些对象建立一种“通知依赖关系”&#xff0c;即一个对象的状态发生改变&#xff0c;所有的依赖对象&#xff08;观察者对象&#xff09;都将得到通知。如果这样的依赖关系过于紧密&#xff0c;将使软件不能很好地抵…

排序算法2

排序算法1-CSDN博客 排序算法1中提及的是较为基础(暴力实现&#xff0c;复杂度较高)的排序算法&#xff0c;不适合于数据量较大的场景&#xff0c;比如序列长度达到1e5 接下来以蓝桥另一道题目来理解其它的排序算法 蓝桥3226 蓝桥账户中心 样例 5 1 5 9 3 7 4、快速排序 快速排…

go结构体匿名“继承“方法冲突时继承优先顺序

在 Go 语言中&#xff0c;匿名字段&#xff08;也称为嵌入字段&#xff09;可以用来实现继承的效果。当你在一个结构体中匿名嵌入另一个结构体时&#xff0c;嵌入结构体的方法会被提升到外部结构体中。这意味着你可以直接通过外部结构体调用嵌入结构体的方法。 如果多个嵌入结…

Ubuntu Server 22.04.5 从零到一:详尽安装部署指南

文章目录 Ubuntu Server 22.04.5 从零到一&#xff1a;详尽安装部署指南一、部署环境二、安装系统2.1 安装2.1.1 选择安装方式2.1.2 选择语言2.1.3 选择不更新2.1.4 选择键盘标准2.1.5 选择安装版本2.1.6 设置网卡2.1.7 配置代理2.1.8 设置镜像源2.1.9 选择装系统的硬盘2.1.10 …

鸿蒙征文|鸿蒙技术分享:使用到的开发框架和技术概览

目录 每日一句正能量前言正文1. 开发环境搭建关键技术&#xff1a;2. 用户界面开发关键技术&#xff1a;3. 应用逻辑开发关键技术&#xff1a;4. 应用测试关键技术&#xff1a;5. 应用签名和打包关键技术&#xff1a;6. 上架流程关键技术&#xff1a;7. 后续维护和更新关键技术…

C++初阶—C++入门

第一章&#xff1a;C关键字(C98) C 总计 63个关键字&#xff0c;下面只是看一下 C 有多少关键字&#xff0c;不对关键字进行具体的讲解。 第二章&#xff1a;命名空间 #include <stdio.h> #include <stdlib.h>int rand 0;int main() {printf("%d\n", r…

马斯克的 AI 游戏工作室:人工智能与游戏产业的融合新纪元

近日&#xff0c;马斯克在 X 平台&#xff08;前身为 Twitter&#xff09;发文称&#xff0c;“太多游戏工作室被大型企业所拥有&#xff0c;xAI 将启动一个 AI 游戏工作室&#xff0c;让游戏再次变得精彩”。这一言论不仅展示了马斯克对游戏行业现状的不满&#xff0c;也揭示了…

数据库期末复习题库

1. Mysql日志功能有哪些? 记录日常操作和错误信息&#xff0c;以便了解Mysql数据库的运行情况&#xff0c;日常操作&#xff0c;错误信息和进行相关的优化。 2. 数据库有哪些备份方法 完全备份&#xff1a;全部都备份一遍表备份&#xff1a;只提取数据库中的数据&#xff0…

opencv 区域提取三种算法

opencv 区域提取三种算法 1.轮廓查找 findContours()函数&#xff0c;得到轮廓的点集集合 cv::vector<cv::vector<Point>> contours;threshold(roiMat,binImg,m_pPara.m_nMinGray,m_pPara.m_nMaxGray,THRESH_BINARY);//膨胀处理Mat dilaElement getStructuringE…

如何快速上手UPR ---查看资源检测报告

上一章说了如何快速使用资源检测器 那么如何修复我们 的不规范资源呢&#xff1f; 我们都知道一些常规的美术资源优化&#xff0c;但是还是会有一些没有注意到的点 导致我们游戏的性能降低 可以看到我们的Animation 的的不规范 检查动画曲线精度 &#xff0c;其实我觉得他是…

摄影相关常用名词

本文介绍与摄影相关的常用名词。 曝光 Exposure 感光元件接收光线的过程&#xff0c;决定图像的明暗程度和细节表现。 光圈 Aperture 控制镜头进光量的孔径大小&#xff0c;用 F 值&#xff08;f-stop&#xff09; 表示。 光圈越大&#xff08;F 值越小&#xff09;&#xff0c…

NeuIPS 2024 | YOCO的高效解码器-解码器架构

该研究提出了一种新的大模型架构&#xff0c;名为YOCO&#xff08;You Only Cache Once&#xff09;&#xff0c;其目的是解决长序列语言模型推理中的内存瓶颈。YOCO通过解码器-解码器结构的创新设计&#xff0c;显著减少推理时的显存占用并提升了长序列的处理效率。 现有大模…

webrtc视频会议学习(三)

文章目录 关联&#xff1a;源码搭建coturn服务器nginx配置ice配置需服务器要开放的端口 效果 关联&#xff1a; webrtcP2P音视频通话&#xff08;一&#xff09; webrtcP2P音视频通话&#xff08;二&#xff09; webrtc视频会议学习&#xff08;三&#xff09; 源码 WebRTC…

C++ 红黑树 【内含代码】

1. 红黑树 1.1 红黑树的概念 红黑树&#xff0c;是一种二叉搜索树&#xff0c;但在每个节点上增加一个存储为表示节点的颜色&#xff0c;可以使Red或Black。通过对任何一条从根到叶子的路径上各个节点着色方式的限制&#xff0c;红黑树确保没有一条路径会比其他路径长出两倍&…