数据聚类:Mean-Shift和EM算法


目录

  • 1. 高斯混合分布
  • 2. Mean-Shift算法
  • 3. EM算法
  • 4. 数据聚类
  • 5. 源码地址


1. 高斯混合分布

在高斯混合分布中,我们假设数据是由多个高斯分布组合而成的。每个高斯分布被称为一个“成分”(component),这些成分通过加权和的方式来构成整个混合分布。

高斯混合分布的公式可以表示为:

p ( x ) = ∑ i = 1 K π i N ( x ∣ μ i , Σ i ) p(x) = \sum^K_{i=1} \pi_i N(x|\mu_i, \Sigma_i) p(x)=i=1KπiN(xμi,Σi)

其中:

  • p ( x ) p(x) p(x)是观测数据点 x x x的概率密度函数,
  • K K K是高斯分布的数量,
  • π i \pi_i πi是第 i i i个高斯分布的混合系数,满足 ∑ i = 1 K π i = 1 \sum^K_{i=1} \pi_i = 1 i=1Kπi=1,
  • μ i \mu_i μi是第 i i i个高斯分布的均值向量,
  • Σ i \Sigma_i Σi是第 i i i个高斯分布的协方差矩阵。

为了简单呈现结果,我们选取 K = 2 K=2 K=2个高斯分布。下图为一个高斯混合分布的采样散点图,其中两个高斯分布的 μ i \mu_i μi分别为 [ 0 , 0 ] , [ 5 , 5 ] [0,0], [5,5] [0,0],[5,5],协方差矩阵均为:
[ 1 0 0 1 ] \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} [1001]

在这里插入图片描述

Fig. 1. 高斯混合分布的采样散点图

2. Mean-Shift算法

Mean-Shift是一种非参数化的密度估计和聚类算法,用于将数据点组织成具有相似特征的群集。它是一种迭代算法,通过计算数据点的梯度信息来寻找数据点在特征空间中的密度极值点,从而确定聚类中心。

算法的核心思想是通过不断地更新数据点的位置,将它们移向密度估计函数梯度的最大方向,直到达到收敛条件。具体来说,Mean-Shift算法包括以下步骤:

  • 初始化:选择一个数据点作为初始聚类中心,或者随机选择一个点作为初始中心。
  • 确定梯度向量:对于每个数据点,计算其与其他数据点之间的距离,并根据一定的核函数(如高斯核)计算梯度向量。梯度向量的方向指向密度估计函数增加最快的方向。
  • 移动数据点:将每个数据点移动到梯度向量的方向上,即向密度估计函数增加最快的方向移动一定的步长。
  • 更新聚类中心:对于移动后的每个数据点,重新计算它们周围数据点的梯度向量,并更新它们的位置。重复这个过程直到达到收敛条件,比如聚类中心的移动距离小于某个阈值。
  • 形成聚类:最终,根据收敛后的聚类中心,将数据点分配到最近的聚类中心,形成最终的聚类结果。

Mean-Shift算法的优点是不需要事先指定聚类的个数,且能够自适应地调整聚类中心的数量和形状。它在处理非线性和非凸形状的数据集时表现出良好的聚类效果。然而,该算法对于大规模数据集的计算复杂度较高,且对初始聚类中心的选择敏感。Mean-Shift算法的具体实现见代码片:

class MeanShift:def __init__(self, bandwidth=1.0, max_iterations=100):self.min_shift = 1self.n_clusters_ = Noneself.cluster_centers_ = Noneself.labels_ = Noneself.bandwidth = bandwidthself.max_iterations = max_iterationsdef euclidean_distance(self, x1, x2):return np.sqrt(np.sum((x1 - x2) ** 2))def gaussian_kernel(self, distance, bandwidth):return (1 / (bandwidth * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((distance / bandwidth) ** 2))def shift_point(self, point, data, bandwidth):shift_x = 0.0shift_y = 0.0total_weight = 0.0for i in range(len(data)):distance = self.euclidean_distance(point, data[i])weight = self.gaussian_kernel(distance, bandwidth)shift_x += data[i][0] * weightshift_y += data[i][1] * weighttotal_weight += weightshift_x /= total_weightshift_y /= total_weightreturn np.array([shift_x, shift_y])def fit(self, data):centroids = np.copy(data)for _ in range(self.max_iterations):shifts = np.zeros_like(centroids)for i, centroid in enumerate(centroids):distances = cdist([centroid], data)[0]weights = self.gaussian_kernel(distances, self.bandwidth)shift = np.sum(weights[:, np.newaxis] * data, axis=0) / np.sum(weights)shifts[i] = shiftshift_distances = cdist(shifts, centroids)centroids = shiftsif np.max(shift_distances) < self.min_shift:breakunique_centroids = np.unique(np.around(centroids, 3), axis=0)self.cluster_centers_ = unique_centroidsself.labels_ = np.argmin(cdist(data, unique_centroids), axis=1)self.n_clusters_ = len(unique_centroids)

3. EM算法

EM算法是一种迭代的数值优化算法,用于求解包含隐变量的概率模型参数的最大似然估计。它在统计学和机器学习领域被广泛应用,尤其在聚类问题中有着重要的应用。其基于观测数据和隐变量之间的概率模型,通过交替进行两个步骤:E步骤(Expectation Step)和M步骤(Maximization Step)来迭代地优化模型参数。下面是EM算法的基本步骤:

  • 初始化:选择一组初始参数来开始迭代过程。
  • E步骤:根据当前的参数估计,计算隐变量的后验概率,即给定观测数据下隐变量的条件概率分布。
  • M步骤:使用在E步骤中计算得到的后验概率,对参数进行更新,以最大化对数似然函数。
  • 重复步骤2-3至收敛:重复执行E步骤和M步骤,直到参数的变化很小或满足收敛条件。

在聚类问题中,EM算法可以用于估计混合高斯模型的参数,从而实现数据的聚类。EM算法在聚类中的应用优点是能够处理具有隐变量的概率模型,适用于灵活的聚类问题。然而,EM算法对于初始参数的选择敏感,可能会陷入局部最优解,并且在处理大规模数据集时可能会面临计算复杂度的挑战。EM算法(包含正则化步骤)的具体实现见代码片:

class RegularizedEMClustering:def __init__(self, n_clusters, max_iterations=100, epsilon=1e-4, regularization=1e-6):self.labels_ = Noneself.X = Noneself.n_features = Noneself.n_samples = Noneself.cluster_probs_ = Noneself.cluster_centers_ = Noneself.n_clusters = n_clustersself.max_iterations = max_iterationsself.epsilon = epsilonself.regularization = regularizationdef fit(self, X):self.X = Xself.n_samples, self.n_features = X.shapeself.cluster_centers_ = X[np.random.choice(self.n_samples, self.n_clusters, replace=False)]self.cluster_probs_ = np.ones((self.n_samples, self.n_clusters)) / self.n_clusters# EMfor iteration in range(self.max_iterations):# E-stepprev_cluster_probs = self.cluster_probs_self._update_cluster_probs()# M-stepself._update_cluster_centers()delta = np.linalg.norm(self.cluster_probs_ - prev_cluster_probs)if delta < self.epsilon:breakself.labels_ = np.argmax(self.cluster_probs_, axis=1)def _update_cluster_probs(self):distances = np.linalg.norm(self.X[:, np.newaxis, :] - self.cluster_centers_, axis=2)# Calculate the cluster probabilities with regularizationnumerator = np.exp(-distances) + self.regularizationdenominator = np.sum(numerator, axis=1, keepdims=True)self.cluster_probs_ = numerator / denominatordef _update_cluster_centers(self):self.cluster_centers_ = np.zeros((self.n_clusters, self.n_features))for k in range(self.n_clusters):self.cluster_centers_[k] = np.average(self.X, axis=0, weights=self.cluster_probs_[:, k])def predict(self, X):distances = np.linalg.norm(X[:, np.newaxis, :] - self.cluster_centers_, axis=2)return np.argmin(distances, axis=1)

4. 数据聚类

Mean-Shift和EM算法的聚类结果分别如图2的a-b子图所示,由于MoG比较简单,两种算法均可以合理且完整地实现聚类,聚类中心也没有显著差异。

在这里插入图片描述

Fig. 2. Mean-Shift(a)和EM(b)算法的聚类结果

5. 源码地址

如果对您有用的话可以点点star哦~

https://github.com/Jurio0304/cs-math/blob/main/hw3_clustering.ipynb
https://github.com/Jurio0304/cs-math/blob/main/func.py


创作不易,麻烦点点赞和关注咯!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/315185.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Bert基础(十八)--Bert实战:NER命名实体识别

1、命名实体识别介绍 1.1 简介 命名实体识别&#xff08;NER&#xff09;是自然语言处理&#xff08;NLP&#xff09;中的一项关键技术&#xff0c;它的目标是从文本中识别出具有特定意义或指代性强的实体&#xff0c;并对这些实体进行分类。这些实体通常包括人名、地名、组织…

小龙虾优化算法(Crayfish Optimization Algorithm,COA)

小龙虾优化算法&#xff08;Crayfish Optimization Algorithm&#xff0c;COA&#xff09; 前言一、小龙虾优化算法的实现1.初始化阶段2.定义温度和小龙虾的觅食量3.避暑阶段&#xff08;探索阶段&#xff09;4.竞争阶段&#xff08;开发阶段&#xff09;5.觅食阶段&#xff08…

高速风筒电源IC辉芒微FT8440E /FT8440A

深圳市三佛科技有限公司分享相关资料&#xff0c;高精度、高效率、低成本离线式功率开关 FT8440x是-款高性能、高精度、低成本的非隔离PWM功率开关。它包含一个专门]的电流模PWM控制器和一个高压功率开关管。内置的误差放大器经过优化保证优越的动态响应。高精度的内部分压电阻…

深度学习-线性代数

目录 标量向量矩阵特殊矩阵特征向量和特征值 标量由只有一个元素的张量表示将向量视为标量值组成的列表通过张量的索引来访问任一元素访问张量的长度只有一个轴的张量&#xff0c;形状只有一个元素通过指定两个分量m和n来创建一个形状为mn的矩阵矩阵的转置对称矩阵的转置逻辑运…

【LLMOps】小白详细教程,在Dify中创建并使用自定义工具

文章目录 博客详细讲解视频点击查看高清脑图 1. 搭建天气查询http服务1.1. flask代码1.2. 接口优化方法 2. 生成openapi json schema2.1. 测试接口2.2. 生成openapi schema 3. 在dify中创建自定义工具3.1. 导入schema3.2. 设置工具认证信息3.3. 测试工具 4. 调用工具4.1. Agent…

【JavaWeb】Day51.Mybatis动态SQL(一)

什么是动态SQL 在页面原型中&#xff0c;列表上方的条件是动态的&#xff0c;是可以不传递的&#xff0c;也可以只传递其中的1个或者2个或者全部。 而在我们刚才编写的SQL语句中&#xff0c;我们会看到&#xff0c;我们将三个条件直接写死了。 如果页面只传递了参数姓名name 字…

Multitouch 1.27.28 免激活版 mac电脑多点触控手势增强工具

Multitouch 应用程序可让您将自定义操作绑定到特定的魔术触控板或鼠标手势。例如&#xff0c;三指单击可以执行粘贴。通过执行键盘快捷键、控制浏览器的选项卡、单击鼠标中键等来改进您的工作流程。 Multitouch 1.27.28 免激活版下载 强大的手势引擎 精心打造的触控板和 Magic …

iOS 模拟请求 (本地数据调试)

简介 在iOS 的日常开发中经常会遇到一下情况&#xff1a;APP代码已编写完成&#xff0c;但后台的接口还无法使用&#xff0c;这时 APP开发就可能陷入停滞。此时iOS 模拟请求就派上用场了&#xff0c;使用模拟请求来调试代码&#xff0c;如果调试都通过了&#xff0c;等后台接口…

迁移学习基础知识

简介 使用迁移学习的优势&#xff1a; 1、能够快速的训练出一个理想的结果 2、当数据集较小时也能训练出理想的效果。 注意&#xff1a;在使用别人预训练的参数模型时&#xff0c;要注意别人的预处理方式。 原理&#xff1a; 对于浅层的网络结构&#xff0c;他们学习到的…

matplotlib绘图

matplotlib版本&#xff1a;3.7.5 numpy版本&#xff1a;1.24.3 pandas版本&#xff1a;2.0.3 导包构造数据 import matplotlib.pyplot as plt import numpy as np import pandas as pd# %matplotlib inlinea np.linspace(0, 2 * np.pi, 100) asin np.sin(a) acos np.cos(…

c++理论篇(一) ——浅谈tcp缓存与tcp的分包与粘包

介绍 在网络通讯中,Linux系统为每一个socket创建了接收缓冲区与发送缓冲区,对于TCP协议来说,这两个缓冲区是必须的.应用程序在调用send/recv函数时,Linux内核会把数据从应用进程拷贝到socket的发送缓冲区中,应用程序在调用recv/read函数时,内核把接收缓冲区中的数据拷贝到应用…

Bert语言大模型基础

一、Bert整体模型架构 基础架构是transformer的encoder部分&#xff0c;bert使用多个encoder堆叠在一起。 主要分为三个部分&#xff1a;1、输入部分 2、注意力机制 3、前馈神经网络 bertbase使用12层encoder堆叠在一起&#xff0c;6个encoder堆叠在一起组成编码端&#xf…

CSS——前端笔记

CSS 1、选择器1.1、基础选择器1.2、复合选择器1.2.4、伪类选择器 1.3、属性选择器1.4、结构伪类选择器1.5、伪元素选择器 2、CSS的元素显示模式2.1、块元素2.2、行内元素2.3、行内块元素2.4、元素显示模式转换 3、字体属性3.1、font-family 字体3.2、font-size 字体大小3.3、fo…

vue echarts 柱状图 堆叠柱状图

echarts堆叠柱状图&#xff08;效果图在文章末尾&#xff09; 1、默认只显示 月度的 数据&#xff0c;手动点击 legend 季度的 数据才会显示&#xff1b; 2、监听左侧菜单栏的宽度变化&#xff0c;图表宽度自适应展示 <template><div><div id"barChart&q…

Java学习1:java入门

java入门 1.介绍Java java——sun公司——被甲骨文收购 一开始叫Oak&#xff0c;后期改名为java; java之父詹姆斯高斯林 企业级应用开发 JavaSE JavaEE JavaME 高级编程语言 2.搭建开发环境 JDK8&#xff0c;JDK11&#xff0c;JDK17————>LTS长期支持版 ps:在…

【Linux】基础指令

文章目录 基础指令1. pwd 指令2. cd 指令3. ls 指令4. touch 指令5. mkdir 指令6. rmdir 和 rm 指令7. man 指令8. cp 指令9. mv 指令10. cat 指令11. more 和 less 指令12. head 和 tail 指令13. date 指令14. cal 指令15. find 指令16. grep 指令18. zip 和 unzip 指令19. ta…

IPv4 NAT(含Cisco配置)

IPv4 NAT&#xff08;含Cisco配置&#xff09; IPv4私有空间地址 类RFC 1918 内部地址范围前缀A10.0.0.0 - 10.255.255.25510.0.0.0/8B172.16.0.0 - 172.31.255.255172.16.0.0/12C192.168.0.0 - 192.168.255.255192.168.0.0/16 这些私有地址可在企业或站点内使用&#xff0c…

jupyter notebook设置代码自动补全

jupyter notebook设置代码自动补全 Anaconda Prompt窗口执行 pip install jupyter_contrib_nbextensionsjupyter contrib nbextensions install --userpip install jupyter_nbextensions_configuratorjupyter nbextensions_configurator enable --user按如下图片设置 卸载jed…

创建Vue3项目遇到的问题 - TypeError: (0 , import_node_util.parseArgs) is not a function

印象中想要创建vue3项目&#xff0c;需要安装16.0或更高版本的Node.js&#xff0c;于是第一步检查现在所用node版本。 显示 v16.20.0。前置条件符合&#xff0c;开始愉快的创建项目。npm init vuelatest&#xff0c;报错了。 查了一下&#xff0c;发现官网已经改成了需要18.3或…

AI赋能分层模式,解构未来,智领风潮

​&#x1f308; 个人主页&#xff1a;danci_ &#x1f525; 系列专栏&#xff1a;《设计模式》 &#x1f4aa;&#x1f3fb; 制定明确可量化的目标&#xff0c;坚持默默的做事。 &#x1f680; 转载自热榜文章&#x1f525;&#xff1a;探索设计模式的魅力&#xff1a;AI赋能分…