机器学习·最近邻方法(k-NN)

前言

上一篇简单介绍了决策树,而本篇讲解与决策树相近的 最近邻方法k-NN

机器学习·决策树-CSDN博客


一、算法原理对比
特性决策树最近邻方法(k-NN)
核心思想通过特征分割构建树结构,递归划分数据基于距离度量,用最近的k个样本投票预测
训练方式显式构建模型(预训练)惰性学习(无显式训练,预测时计算)
关键参数max_depthmin_samples_leafn_neighborsmetric(距离度量)
分割标准信息增益、基尼系数欧氏距离、曼哈顿距离、余弦相似度等
输出类型分类树(类别标签)、回归树(连续值)分类(多数投票)、回归(均值/中位数)

二、概念
  1. 决策树

    • 信息增益(Entropy)
      \( S = -\sum_{i=1}^N p_i \log_2 p_i \)
      选择分割时最大化信息增益,减少不确定性。

    • 基尼系数(Gini Index)
      \( G = 1 - \sum_{k} (p_k)^2 \)
      衡量数据不纯度,值越小分割越优。

    • 剪枝策略

      • 预剪枝:限制树深度(max_depth)、叶节点最小样本数(min_samples_leaf)。

      • 后剪枝:构建完整树后合并冗余节点。

  2. k-NN

    • 距离度量

      • 欧氏距离(默认):\( d(x,y) = \sqrt{\sum (x_i - y_i)^2} \)

      • 曼哈顿距离:\( d(x,y) = \sum |x_i - y_i| \)

      • 余弦相似度:衡量向量方向相似性。

    • 参数调优

      • n_neighbors:邻居数,小值易过拟合,大值易欠拟合。

      • weights:邻居权重(uniform均等权重,distance按距离反比加权)。


三、交叉验证与调优
  1. 交叉验证方法

    • k折交叉验证:数据分为k个子集,轮流用k-1个子集训练,1个子集验证,取平均性能。

    • 留出法:按比例(如70%-30%)划分训练集和验证集。

  2. GridSearchCV 参数调优

    from sklearn.model_selection import GridSearchCV# 决策树参数网格
    tree_params = {'max_depth': [3, 5, 7], 'max_features': [10, 20, 30]}
    tree_grid = GridSearchCV(DecisionTreeClassifier(), tree_params, cv=5)
    tree_grid.fit(X_train, y_train)# k-NN参数网格(需标准化)
    knn_pipe = Pipeline([('scaler', StandardScaler()), ('knn', KNeighborsClassifier())])
    knn_params = {'knn__n_neighbors': range(1, 10)}
    knn_grid = GridSearchCV(knn_pipe, knn_params, cv=5)
    knn_grid.fit(X_train, y_train)
     

四、实际应用与性能对比
  1. 客户流失预测任务

    • 数据集:电信客户流失数据(特征包括国际套餐、语音邮箱等)。

    • 结果对比

      模型留置集准确率交叉验证最佳准确率
      决策树(调优)94.6%94.0%
      k-NN(调优)89.0%88.5%
      随机森林95.3%93.5%
  2. MNIST手写数字识别

    • 数据集:8x8像素手写数字图片。

    • 结果对比

      模型留置集准确率交叉验证最佳准确率
      决策树(调优)84.4%66.6%
      k-NN(调优)98.7%97.6%
      随机森林93.4%-

五、优缺点总结
算法优点缺点
决策树1. 可解释性强,规则可视化
2. 支持类别/数值特征
3. 训练速度快
1. 对噪声敏感,易过拟合
2. 边界为轴平行,灵活性差
3. 无法外推
k-NN1. 简单易实现
2. 无需显式训练
3. 适应复杂边界(小k值)
1. 预测速度慢(大数据集)
2. 高维数据效果差(维度灾难)
3. 依赖距离度量

六、应用场景
  1. 选择决策树

    • 需要可解释性强的模型(如金融风控、医疗诊断)。

    • 数据特征存在明显分层逻辑(如年龄分段、阈值判断)。

    • 实时预测需求(快速推理)。

  2. 选择k-NN

    • 数据维度较低且分布复杂(如小规模图像分类)。

    • 需要快速原型验证(基线模型)。

    • 数据特征尺度一致(需标准化)。


七、结论
  1. 模型选择优先级

    • 优先尝试简单模型(如决策树、k-NN),再过渡到复杂模型(随机森林、神经网络)。

    • 决策树在结构化数据中表现优异,k-NN适合小规模非结构化数据。

  2. 调优核心

    • 决策树:控制深度(max_depth)和叶节点样本数(min_samples_leaf)。

    • k-NN:选择合适的邻居数(n_neighbors)和距离度量(metric)。

  3. 交叉验证必要性

    • 避免过拟合,确保模型泛化性,尤其在参数调优时不可或缺。


八、完整代码

1.客户流失预测任务

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_digits
from sklearn.tree import export_graphviz
import pydotplus
from io import StringIO
import matplotlib.pyplot as plt
from IPython.display import Image# 客户离网率预测任务
# 数据预处理
df = pd.read_csv('https://labfile.oss.aliyuncs.com/courses/1283/telecom_churn.csv')
df['International plan'] = pd.factorize(df['International plan'])[0]
df['Voice mail plan'] = pd.factorize(df['Voice mail plan'])[0]
df['Churn'] = df['Churn'].astype('int')
states = df['State']
y = df['Churn']
df.drop(['State', 'Churn'], axis=1, inplace=True)# 划分数据集
X_train, X_holdout, y_train, y_holdout = train_test_split(df.values, y, test_size=0.3, random_state=17)# 训练决策树和K近邻模型(随机参数)
tree = DecisionTreeClassifier(max_depth=5, random_state=17)
knn = KNeighborsClassifier(n_neighbors=10)
tree.fit(X_train, y_train)
knn.fit(X_train, y_train)# 模型评估
tree_pred = tree.predict(X_holdout)
print("决策树准确率(随机参数):", accuracy_score(y_holdout, tree_pred))
knn_pred = knn.predict(X_holdout)
print("K近邻准确率(随机参数):", accuracy_score(y_holdout, knn_pred))# 决策树交叉验证调优
tree_params = {'max_depth': range(5, 7),'max_features': range(16, 18)}
tree_grid = GridSearchCV(tree, tree_params, cv=5, n_jobs=-1, verbose=True)
tree_grid.fit(X_train, y_train)
print("决策树最佳参数:", tree_grid.best_params_)
print("决策树最佳分数:", tree_grid.best_score_)
print("决策树调优后准确率:", accuracy_score(y_holdout, tree_grid.predict(X_holdout)))# 绘制决策树
dot_data = StringIO()
export_graphviz(tree_grid.best_estimator_, feature_names=df.columns, out_file=dot_data, filled=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(value=graph.create_png())# K近邻交叉验证调优
knn_pipe = Pipeline([('scaler', StandardScaler()), ('knn', KNeighborsClassifier(n_jobs=-1))])
knn_params = {'knn__n_neighbors': range(6, 8)}
knn_grid = GridSearchCV(knn_pipe, knn_params, cv=5, n_jobs=-1, verbose=True)
knn_grid.fit(X_train, y_train)
print("K近邻最佳参数:", knn_grid.best_params_)
print("K近邻最佳分数:", knn_grid.best_score_)
print("K近邻调优后准确率:", accuracy_score(y_holdout, knn_grid.predict(X_holdout)))# 训练随机森林模型
forest = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=17)
print("随机森林交叉验证分数:", np.mean(cross_val_score(forest, X_train, y_train, cv=5)))
forest_params = {'max_depth': range(8, 10),'max_features': range(5, 7)}
forest_grid = GridSearchCV(forest, forest_params, cv=5, n_jobs=-1, verbose=True)
forest_grid.fit(X_train, y_train)
print("随机森林最佳参数:", forest_grid.best_params_)
print("随机森林最佳分数:", forest_grid.best_score_)
print("随机森林准确率:", accuracy_score(y_holdout, forest_grid.predict(X_holdout)))# 简单分类任务
# 生成数据
def form_linearly_separable_data(n=500, x1_min=0, x1_max=30, x2_min=0, x2_max=30):data, target = [], []for i in range(n):x1 = np.random.randint(x1_min, x1_max)x2 = np.random.randint(x2_min, x2_max)if np.abs(x1 - x2) > 0.5:data.append([x1, x2])target.append(np.sign(x1 - x2))return np.array(data), np.array(target)X, y = form_linearly_separable_data()
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='autumn', edgecolors='black')# 训练决策树并绘制分类边界
tree = DecisionTreeClassifier(random_state=17).fit(X, y)def get_grid(X):x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1), np.arange(y_min, y_max, 0.1))return xx, yyxx, yy = get_grid(X)
predicted = tree.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
plt.pcolormesh(xx, yy, predicted, cmap='autumn')
plt.scatter(X[:, 0], X[:, 1], c=y, s=100, cmap='autumn', edgecolors='black', linewidth=1.5)
plt.title('Easy task. Decision tree compexifies everything')# 可视化决策树
dot_data = StringIO()
export_graphviz(tree, feature_names=['x1', 'x2'], out_file=dot_data, filled=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(value=graph.create_png())# 训练K近邻模型
knn = KNeighborsClassifier(n_neighbors=1).fit(X, y)
xx, yy = get_grid(X)
predicted = knn.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
plt.pcolormesh(xx, yy, predicted, cmap='autumn')
plt.scatter(X[:, 0], X[:, 1], c=y, s=100, cmap='autumn', edgecolors='black', linewidth=1.5)
plt.title('Easy task, kNN. Not bad')# MNIST手写数字识别任务
# 加载数据
data = load_digits()
X, y = data.data, data.target# 绘制MNIST手写数字
f, axes = plt.subplots(1, 4, sharey=True, figsize=(16, 6))
for i in range(4):axes[i].imshow(X[i, :].reshape([8, 8]), cmap='Greys')# 划分数据集
X_train, X_holdout, y_train, y_holdout = train_test_split(X, y, test_size=0.3, random_state=17)# 训练决策树和K近邻模型(随机参数)
tree = DecisionTreeClassifier(max_depth=5, random_state=17)
knn_pipe = Pipeline([('scaler', StandardScaler()), ('knn', KNeighborsClassifier(n_neighbors=10))])
tree.fit(X_train, y_train)
knn_pipe.fit(X_train, y_train)# 模型预测与评估
tree_pred = tree.predict(X_holdout)
knn_pred = knn_pipe.predict(X_holdout)
print("MNIST任务中决策树准确率(随机参数):", accuracy_score(y_holdout, tree_pred))
print("MNIST任务中K近邻准确率(随机参数):", accuracy_score(y_holdout, knn_pred))# 决策树交叉验证调优
tree_params = {'max_depth': [10, 20, 30],'max_features': [30, 50, 64]}
tree_grid = GridSearchCV(tree, tree_params, cv=5, n_jobs=-1, verbose=True)
tree_grid.fit(X_train, y_train)
print("MNIST任务中决策树最佳参数:", tree_grid.best_params_)
print("MNIST任务中决策树最佳分数:", tree_grid.best_score_)# K近邻交叉验证调优
print("MNIST任务中K近邻交叉验证分数:", np.mean(cross_val_score(KNeighborsClassifier(n_neighbors=1), X_train, y_train, cv=5)))# 训练随机森林模型
print("MNIST任务中随机森林交叉验证分数:", np.mean(cross_val_score(RandomForestClassifier(random_state=17), X_train, y_train, cv=5)))# 最近邻方法复杂情形
# 生成数据
def form_noisy_data(n_obj=1000, n_feat=100, random_seed=17):np.seed = random_seedy = np.random.choice([-1, 1], size=n_obj)x1 = 0.3 * yx_other = np.random.random(size=[n_obj, n_feat - 1])return np.hstack([x1.reshape([n_obj, 1]), x_other]), yX, y = form_noisy_data()# 划分数据集
X_train, X_holdout, y_train, y_holdout = train_test_split(X, y, test_size=0.3, random_state=17)# 训练K近邻模型并绘制验证曲线
cv_scores, holdout_scores = [], []
n_neighb = [1, 2, 3, 5] + list(range(50, 550, 50))for k in n_neighb:knn_pipe = Pipeline([('scaler', StandardScaler()), ('knn', KNeighborsClassifier(n_neighbors=k))])cv_scores.append(np.mean(cross_val_score(knn_pipe, X_train, y_train, cv=5)))knn_pipe.fit(X_train, y_train)holdout_scores.append(accuracy_score(y_holdout, knn_pipe.predict(X_holdout)))plt.plot(n_neighb, cv_scores, label='CV')
plt.plot(n_neighb, holdout_scores, label='holdout')
plt.title('Easy task. kNN fails')
plt.legend()# 决策树训练与评估
tree = DecisionTreeClassifier(random_state=17, max_depth=1)
tree_cv_score = np.mean(cross_val_score(tree, X_train, y_train, cv=5))
tree.fit(X_train, y_train)
tree_holdout_score = accuracy_score(y_holdout, tree.predict(X_holdout))
print('Decision tree. CV: {}, holdout: {}'.format(tree_cv_score, tree_holdout_score))

2.MNIST手写数字识别

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score# 载入 MNIST 手写数字数据集
data = load_digits()
X, y = data.data, data.target# 查看第一个样本的 8x8 矩阵形式
print(X[0, :].reshape([8, 8]))# 绘制一些 MNIST 手写数字
f, axes = plt.subplots(1, 4, sharey=True, figsize=(16, 6))
for i in range(4):axes[i].imshow(X[i, :].reshape([8, 8]), cmap='Greys')
plt.show()# 分割数据集
X_train, X_holdout, y_train, y_holdout = train_test_split(X, y, test_size=0.3, random_state=17)# 使用随机参数训练决策树和 k-NN
tree = DecisionTreeClassifier(max_depth=5, random_state=17)
knn_pipe = Pipeline([('scaler', StandardScaler()),('knn', KNeighborsClassifier(n_neighbors=10))])tree.fit(X_train, y_train)
knn_pipe.fit(X_train, y_train)# 在留置集上做出预测并评估
tree_pred = tree.predict(X_holdout)
knn_pred = knn_pipe.predict(X_holdout)
tree_accuracy = accuracy_score(y_holdout, tree_pred)
knn_accuracy = accuracy_score(y_holdout, knn_pred)
print(f"决策树(随机参数)在留置集上的准确率: {tree_accuracy}")
print(f"k-NN(随机参数)在留置集上的准确率: {knn_accuracy}")# 使用交叉验证调优决策树模型
tree_params = {'max_depth': [10, 20, 30],'max_features': [30, 50, 64]}tree_grid = GridSearchCV(tree, tree_params,cv=5, n_jobs=-1, verbose=True)tree_grid.fit(X_train, y_train)# 查看交叉验证得到的最佳参数组合和相应的准确率
best_tree_params = tree_grid.best_params_
best_tree_score = tree_grid.best_score_
print(f"决策树最佳参数: {best_tree_params}")
print(f"决策树最佳交叉验证准确率: {best_tree_score}")# 使用交叉验证调优 k-NN 模型
knn_cv_score = np.mean(cross_val_score(KNeighborsClassifier(n_neighbors=1), X_train, y_train, cv=5))
print(f"调优后 k-NN 的交叉验证准确率: {knn_cv_score}")# 训练随机森林模型
forest_cv_score = np.mean(cross_val_score(RandomForestClassifier(random_state=17), X_train, y_train, cv=5))
print(f"随机森林的交叉验证准确率: {forest_cv_score}")

结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/18667.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Deesek:新一代数据处理与分析框架实战指南

Deesek:新一代数据处理与分析框架实战指南 引言 在大数据时代,高效处理和分析海量数据是企业和开发者面临的核心挑战。传统工具如Pandas、Spark等虽功能强大,但在实时性、易用性或性能上仍有提升空间。Deesek(假设名称&#xff…

【Vue】打包vue3+vite项目发布到github page的完整过程

文章目录 第一步:打包第二步:github仓库设置第三步:安装插件gh-pages第四步:两个配置第五步:上传github其他问题1. 路由2.待补充 参考文章: 环境: vue3vite windows11(使用终端即可&…

JVM内存模型详解

文章目录 1. 程序计数器(Program Counter Register)2. Java虚拟机栈(Java Virtual Machine Stacks)3. 本地方法栈(Native Method Stacks)4. Java堆(Java Heap)5. 方法区(…

KubeSphere 和 K8s 高可用集群离线部署全攻略

本文首发:运维有术,作者术哥。 今天,我们将一起探索如何在离线环境中部署 K8s v1.30.6 和 KubeSphere v4.1.2 高可用集群。对于离线环境的镜像仓库管理,官方推荐使用 Harbor 作为镜像仓库管理工具,它为企业级用户提供…

代码随想录-训练营-day30

今天我们要进入动态规划的背包问题,背包问题也是一类经典问题了。总的来说可以分为: 今天让我们先来复习0-1背包的题目,这也是所有背包问题的基础。所谓的0-1背包问题一般来说就是给一个背包带有最大容量,然后给一个物体对应的需要…

百问网(100ask)提供的烧写工具的原理和详解;将自己编译生成的u-boot镜像文件烧写到eMMC中

百问网(100ask)提供的烧写工具的原理 具体的实现原理见链接 http://wiki.100ask.org/100ask_imx6ull_tool 为了防止上面这个链接失效,我还对上面这个链接指向的页面保存成了mhtml文件,这个mhtml文件的百度网盘下载链接: https://pan.baidu.c…

【旋转框目标检测】基于YOLO11/v8深度学习的遥感视角船只智能检测系统设计与实现【python源码+Pyqt5界面+数据集+训练代码】

《------往期经典推荐------》 一、AI应用软件开发实战专栏【链接】 项目名称项目名称1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】3.【手势识别系统开发】4.【人脸面部活体检测系统开发】5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】7.【…

侯捷 C++ 课程学习笔记:C++ 面向对象开发的艺术

在侯捷老师的 C 系列课程中,《C 面向对象开发》这门课程让我对面向对象编程有了更深入的理解。面向对象编程(OOP)是现代软件开发中最重要的编程范式之一,而 C 作为支持 OOP 的语言,提供了强大的工具和特性。侯捷老师通…

神经网络常见激活函数 12-Swish函数

Swish 函数导函数 Swish函数 S w i s h ( x ) x ⋅ σ ( β x ) x 1 e − β x \begin{aligned} \rm Swish(x) & x \cdot \sigma(\beta x) \\ & \frac{x}{1 e^{-\beta x}} \end{aligned} Swish(x)​x⋅σ(βx)1e−βxx​​ Swish函数导数 d d x S w i s h ( x…

CF 137B.Permutation(Java 实现)

题目分析 输入n个样本,将样本调整为从1到n的包含,需要多少此更改 思路分析 由于样本量本身就是n,无论怎么给数据要么是重复要么不在1到n的范围,只需要遍历1到n判断数据组中有没有i值即可。 代码 import java.util.*;public clas…

web第三次作业

弹窗案例 1.首页代码 <!DOCTYPE html><html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>综合案例</title><st…

go语言简单快速的按顺序遍历kv结构(map)

文章目录 需求描述用map实现按照map的key排序用二维切片实现用结构体实现 需求描述 在go语言中&#xff0c;如果需要对map遍历&#xff0c;每次输出的顺序是不固定的&#xff0c;可以考虑存储为二维切片或结构体。 假设现在需要在页面的下拉菜单中展示一些基础的选项&#xff…

Unity 命令行设置运行在指定的显卡上

设置运行在指定的显卡上 -force-device-index

分享一个使用的音频裁剪chrome扩展-Ringtone Maker

一、插件简介 铃声制作器是一个简单易用的 Chrome 扩展&#xff0c;专门用于制作手机铃声。它支持裁剪音频文件的特定片段&#xff0c;并将其下载为 WAV 格式&#xff0c;方便我们在手机上使用。无论是想从一段长音频中截取精彩部分作为铃声&#xff0c;还是对现有的音频进行个…

数据开放共享和平台整合优化取得实质性突破的智慧物流开源了

智慧物流视频监控平台是一款功能强大且简单易用的实时算法视频监控系统。它的愿景是最底层打通各大芯片厂商相互间的壁垒&#xff0c;省去繁琐重复的适配流程&#xff0c;实现芯片、算法、应用的全流程组合&#xff0c;从而大大减少企业级应用约95%的开发成本可通过边缘计算技术…

预留:大数据Hadoop之——部署hadoop+hive+Mysql环境(Linux)

传送门目录 前期准备 一、JDK的安装 1、安装jdk 2、配置Java环境变量 3、加载环境变量 4、进行校验 二、hadoop的集群搭建 1、hadoop的下载安装 2、配置文件设置 2.1. 配置 hadoop-env.sh 2.2. 配置 core-site.xml 2.3. 配置hdfs-site.xml 2.4. 配置 yarn-site.xm…

《Spring实战》(第6版)第1章 Spring起步

第1部分 Spring基础 第1章 Spring起步 1.1 什么是Spring Spring的核心是提供一个容器(container)。 称为Spring应用上下文(Spring application context)。 创建和管理应用的组件(bean)&#xff0c;与上下文装配在一起。 Bean装配通过依赖注入(Dependency Injection,DI)。…

DesignCon2019 Paper分享--Automotive 芯片封装的SIPI优化

本期分享一篇intel在DesignCon2019上发表的介绍汽车芯片封装SIPI优化的paper--《Signal/Power Integrity Optimizations In An IoT Automotive Package》,文章主要介绍汽车芯片在SIPI上面临的挑战并提出了一些优化措施。 汽车芯片的发展趋势 如今&#xff0c;消费者对于车内用…

技术评测:MaxCompute MaxFrame——阿里云自研分布式计算框架的Python编程接口

引言 随着大数据和人工智能技术的发展&#xff0c;数据处理的需求日益增长。阿里云推出的MaxCompute MaxFrame&#xff08;简称“MaxFrame”&#xff09;是一个专为Python开发者设计的分布式计算框架&#xff0c;它不仅支持Python编程接口&#xff0c;还能直接利用MaxCompute的…

优选算法《位运算》

在本篇当中我们将会复习之前在C语言阶段学习的各种位运算&#xff0c;并且在复习当中将再补充一些在算法题当中没有进行总结的位运算的使用方法&#xff0c;再总结完常见的位运算使用方法之和接下来还是和之前的算法篇章一样通过几道算法题来对这些位运算的方法技巧进行巩固。在…