【python】sklearn基础教程及示例

【python】sklearn基础教程及示例


Scikit-learn(简称sklearn)是一个非常流行的Python机器学习库,提供了许多常用的机器学习算法和工具。以下是一个基础教程的概述:


 1. 安装scikit-learn


首先,确保你已经安装了Python和pip,然后使用以下命令安装scikit-learn:

pip install -U scikit-learn

2. 导入库

在你的Python脚本或Jupyter Notebook中,首先导入scikit-learn库:

import sklearn

3. 加载数据

你可以加载各种数据集,包括样本数据集和真实世界数据集。例如,加载经典的鸢尾花数据集:

from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data  # 特征矩阵
y = iris.target  # 目标向量

4. 数据预处理

在应用机器学习算法之前,通常需要进行一些数据预处理,例如特征缩放、特征选择、数据清洗等。以下是一些常用的数据预处理方法:

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

5. 数据拆分

将数据集拆分为训练集和测试集:

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

6. 建立模型

使用各种机器学习算法来建立模型,例如逻辑回归:

from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
model.fit(X_train, y_train)

7. 模型评估

在训练模型之后,评估模型的性能,例如使用准确度评估:

from sklearn.metrics import accuracy_score
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")

8. 交叉验证

使用交叉验证来评估模型的稳定性和泛化能力:

from sklearn.model_selection import cross_validate
result = cross_validate(model, X, y, cv=5)
print(result['test_score'])

sklearn示例

1.简单例子:鸢尾花分类

这是一个经典的机器学习任务,用于分类鸢尾花的种类。

load_iris 是一个经典的机器学习数据集,通常用于分类和聚类任务。这个数据集包含了三种不同种类的鸢尾花(Iris Setosa、Iris Versicolour 和 Iris Virginica)的信息,每种鸢尾花有四个特征:花萼长度、花萼宽度、花瓣长度和花瓣宽度。

具体来说,load_iris 数据集包含以下内容:

  • 150个样本:每种鸢尾花各50个样本。
  • 4个特征:花萼长度、花萼宽度、花瓣长度和花瓣宽度。
  • 目标标签:每个样本的目标类别标签,分别为0(Setosa)、1(Versicolour)和2(Virginica)。
  • StandardScaler 是 scikit-learn 库中的一个类,用于对数据进行标准化处理。标准化的目的是将数据的特征缩放到相同的尺度,通常是均值为0,标准差为1。这对于许多机器学习算法来说是非常重要的,特别是那些基于距离的算法(如K-近邻、支持向量机等)和需要计算协方差矩阵的算法(如PCA、线性回归等)。

# 导入必要的库
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target# 数据预处理
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 拆分数据集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)# 建立和训练模型
model = LogisticRegression()
model.fit(X_train, y_train)# 预测和评估
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")

2.复杂例子:手写数字识别

这个例子使用手写数字数据集,并应用支持向量机(SVM)进行分类。

load_digits 是 scikit-learn 提供的一个经典数据集,用于手写数字识别任务。这个数据集包含了 0 到 9 共 10 个数字的手写图像,每个图像是一个 8x8 的灰度图像。

  • 数据集内容 样本数量:1797 个手写数字图像。
  • 特征维度:每个图像有 64 个特征(8x8 像素)。
  • 特征值:每个特征值是一个整数,范围从 0 到 16,表示像素的灰度值。
  • 目标标签:每个样本对应一个目标标签,表示数字 0 到 9。

# 导入必要的库
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import classification_report# 加载数据集
digits = load_digits()
X = digits.data
y = digits.target# 数据预处理
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)# 拆分数据集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)# 使用网格搜索进行超参数调优
param_grid = {'C': [0.1, 1, 10, 100], 'gamma': [1, 0.1, 0.01, 0.001]}
grid = GridSearchCV(SVC(), param_grid, refit=True, verbose=2)
grid.fit(X_train, y_train)# 最佳参数和模型评估
print(f"Best Parameters: {grid.best_params_}")
y_pred = grid.predict(X_test)
print(classification_report(y_test, y_pred))

在这个复杂的例子中,我们使用了网格搜索(GridSearchCV)来找到支持向量机(SVM)的最佳超参数,并使用分类报告(classification_report)来评估模型的性能。

  • param_grid:这是一个字典,定义了要搜索的参数范围。在这个例子中,我们要调整两个参数:
    • C:正则化参数,控制模型的复杂度。较小的 C 值会使模型更简单,但可能欠拟合;较大的 C 值会使模型更复杂,但可能过拟合。
    • gamma:核函数系数,控制单个训练样本的影响范围。较大的 gamma 值会使模型更复杂,但可能过拟合;较小的 gamma 值会使模型更简单,但可能欠拟合。
  • GridSearchCV:这是 scikit-learn 提供的一个工具,用于通过交叉验证来搜索最佳参数组合。
    • SVC():支持向量机分类器。
    • param_grid:要搜索的参数网格。
    • refit=True:在找到最佳参数组合后,使用整个训练集重新训练模型。
    • verbose=2:设置详细程度,输出更多的搜索过程信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/385529.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

搜索引擎项目(四)

SearchEngine 王宇璇/submit - 码云 - 开源中国 (gitee.com) 基于Servlet完成前后端交互 WebServlet("/searcher") public class DocSearcherServlet extends HttpServlet {private static DocSearcher docSearcher new DocSearcher();private ObjectMapper obje…

Kettle下载安装

环境说明 虚拟机:Win7;MySql8.0 主机:Win11;JDK1.8;Kettle 9.4(Pentaho Data Integration 9.4)(下载方式见文末) 安装说明 【1】解压后运行Spoon.bat 【2】将jar包 复…

【Linux C | 网络编程】进程池退出的实现详解(五)

上一篇中讲解了在进程池文件传输的过程如何实现零拷贝,具体的方法包括使用mmap,sendfile,splice等等。 【Linux C | 网络编程】进程池零拷贝传输的实现详解(四) 这篇内容主要讲解进程池如何退出。 1.进程池的简单退…

聊聊基于Alink库的主成分分析(PCA)

概述 主成分分析(Principal Component Analysis,PCA)是一种常用的数据降维和特征提取技术,用于将高维数据转换为低维的特征空间。其目标是通过线性变换将原始特征转化为一组新的互相无关的变量,这些新变量称为主成分&…

7月24日JavaSE学习笔记

序列化版本控制 序列化:将内存对象转换成序列(流)的过程 反序列化:将对象序列读入程序,转换成对象的方式;反序列化的对象是一个新的对象。 serialVersionUID 是一个类的序列化版本号 private static fin…

算法通关:006_1二分查找

二分查找 查找一个数组里面是否存在num主要代码运行结果 详细写法自动生成数组和num,利用对数器查看二分代码是否正确 查找一个数组里面是否存在num 主要代码 /*** Author: ggdpzhk* CreateTime: 2024-07-27*/ public class cg {//二分查找public static boolean …

戴着苹果Vision Pro,如何吃花生米

6月底苹果Vision Pro国内开售,我早早到官网预订了一台。选择必要的配件,输入视力信息,定制符合自己视力的蔡司镜片。确实贵。把主要配件和镜片配齐,要3万6,比Pico、META的眼镜贵一个数量级。 Vision Pro出来后&#x…

C++的STL简介(一)

目录 1.什么是STL 2.STL的版本 3.STL的六大组件 4.string类 4.1为什么学习string类? 4.2string常见接口 4.2.1默认构造 ​编辑 4.2.2析构函数 Element access: 4.2.3 [] 4.2.4迭代器 ​编辑 auto 4.2.4.1 begin和end 4.2.4.2.regin和rend Capacity: 4.2.5…

Q238. 除自身以外数组的乘积

思路 一开始想到的是按位乘 看了题解,思路是存i左边的乘积和 与 i右边的乘积和 代码一: 需要三次循环,需要额外空间 left和right数组 代码: public int[] productExceptSelf(int[] nums) {int[] left new int[nums.length];int[] right …

【爱上C++】list用法详解、模拟实现

文章目录 一:list介绍以及使用1.list介绍2.基本用法①list构造方式②list迭代器的使用③容量④元素访问⑤插入和删除⑥其他操作image.png 3.list与vector对比 二:list模拟实现1.基本框架2.节点结构体模板3.__list_iterator 结构体模板①模板参数说明②构…

基于Xejen框架实现的C# winform鼠标点击器、电脑按键自动点击器的软件开发及介绍

功能演示 文章开始之前,仍然是先来个视频,以便用户知道鼠标连点器的基本功能 软件主界面 多功能鼠标连点器 快速点击: 痕即鼠标点击器可以设定每秒点击次数,让您轻松应对高频点击需求。 切换时长,即每次动作之间的间…

大数据的数据质量有效提升的研究

大数据的数据质量有效提升是一个涉及多个环节和维度的复杂过程。以下是从数据采集、处理、管理到应用等方面,对大数据数据质量有效提升的研究概述: 一、数据采集阶段 明确采集需求:在数据采集前,需明确数据需求,包括…

leetocde662. 二叉树最大宽度,面试必刷题,思路清晰,分点解析,附代码详解带你完全弄懂

leetocde662. 二叉树最大宽度 做此题之前可以先做一下二叉树的层序遍历。具体题目如下: leetcode102二叉树的层序遍历 我也写过题解,可以先看看学习一下,如果会做层序遍历了,那么这题相对来说会简单很多。 具体题目 给你一棵…

把 网页代码 嵌入到 单片机程序中 2 日志2024/7/26

之前不是说把 网页代码 嵌入到 单片机程序中 嘛! 目录 之前不是说把 网页代码 嵌入到 单片机程序中 嘛! 修改vs的tasks.json配置 然后 测试 结果是正常的,可以编译了 但是:当我把我都html代码都写上去之后 还是会报错!!! 内部被检测到了,没辙,只有手动更新了小工具代码 …

Python3网络爬虫开发实战(2)爬虫基础库

文章目录 一、urllib1. urlparse 实现 URL 的识别和分段2. urlunparse 用于构造 URL3. urljoin 用于两个链接的拼接4. urlencode 将 params 字典序列化为 params 字符串5. parse_qs 和 parse_qsl 用于将 params 字符串反序列化为 params 字典或列表6. quote 和 unquote 对 URL的…

JAVAWeb实战(前端篇)

项目实战一 0.项目结构 1.创建vue3项目,并导入所需的依赖 npm install vue-router npm install axios npm install pinia npm install vue 2.定义路由,axios,pinia相关的对象 文件(.js) 2.1路由(.js) import {cre…

HarmonyOS Next 省市区级联(三级联动)筛选框

效果图 完整代码 实例对象 export class ProvinceBean {id?: stringpid?: stringisSelect?: booleandeep?: objectextName?: stringchildren?: ProvinceBean[] }级联代码 import { MMKV } from tencent/mmkv/src/main/ets/utils/MMKV import { ProvinceBean } from ..…

nodeselector

1.概述 在创建pod资源是,k8s集群系统会给我们将pod资源随机分配到不同服务器上。我们通过配置nodeSelector可以将pod资源指定到拥有某个标签的服务器上 使用nodeselector前我们要先给每个节点打上标签,在编辑pod资源的时候选择该标签 2.示例 给节点打标…

数据科学统计面试问题 -40问

前 40 名数据科学统计面试问题 一、介绍 正如 Josh Wills 曾经说过的那样,“数据科学家是一个比任何程序员都更擅长统计、比任何统计学家都更擅长编程的人”。统计学是数据科学中处理数据及其分析的基本工具。它提供了工具和方法,可帮助数据科学家获得…