1. 原理
2. 算法步骤
3. 目标函数
4. 优缺点
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import torch.nn as nn# ====== 数据准备 ======
# 生成数据:100 个张量,每个张量是 5 维向量
torch.manual_seed(42)data1 = torch.randn(50, 5) + 2 # 第一类# 其中包含标准正态分布的随机数(均值为 0,标准差为 1)
# + 2 将张量中的每个值加上 2,使数据的中心(均值)移动到 2,相当于对整个分布进行平移。data2 = torch.randn(30, 5) - 2 # 第二类
data3 = torch.randn(20, 5) # 第三类
data = torch.cat([data1, data2, data3], dim=0)# 转为 NumPy 数组
data_np = data.numpy()# ====== 方法 1: K-Means 聚类 ======
# 聚类为 3 类
kmeans = KMeans(n_clusters=3, random_state=42)# 创建一个 K-Means 聚类模型实例, random_state=42 固定随机数种子,保证聚类结果的可复现性。labels_kmeans = kmeans.fit_predict(data_np)# 将数据 data_np 输入到 K-Means 模型中,完成聚类并返回每个数据点的类别标签。
# labels_kmeans 是一个长度为 100 的数组,其中每个元素是对应数据点的聚类标签(0、1 或 2)。# 使用 PCA 将数据降维到 2D 用于可视化
pca = PCA(n_components=2)# 创建一个主成分分析(PCA)模型实例。
# n_components=2:将高维数据降维到 2 维。data_pca = pca.fit_transform(data_np)# 对数据 data_np 进行 PCA 降维,并返回降维后的数据。
# data_pca 是一个形状为 (100, 2) 的二维数组,每行是原始数据在降维后空间中的坐标。# 可视化 K-Means 聚类结果
plt.figure(figsize=(8, 6))
plt.scatter(data_pca[:, 0], data_pca[:, 1], c=labels_kmeans, cmap='viridis', s=50)
plt.title("K-Means Clustering Results (PCA 2D)")
plt.xlabel("PCA Dimension 1")
plt.ylabel("PCA Dimension 2")
plt.colorbar(label='Cluster')
plt.show()
PS:聚类数据生成方法
data, true_labels = make_blobs(n_samples=n_samples, centers=n_clusters, cluster_std=[1.0, 2.5, 0.5], random_state=42)
- n_samples 指定要生成的数据点总数。
如果是一个整数(如 n_samples=300),表示生成 300 个数据点。
如果是一个列表或数组(如 [100, 150, 50]),则指定每个簇分别生成的数据点数量。- centers 指定簇的数量或簇中心的具体坐标。
如果是一个整数(如 n_clusters=3),表示生成 3 个簇,簇中心的坐标会自动随机生成。
如果是一个数组(如 [[0, 0], [3, 3], [1, -2]]),则指定每个簇的中心坐标。- cluster_std 指定每个簇的标准差(数据点的离散程度)。
如果是一个单一值(如 cluster_std=1.0),表示所有簇的标准差相同。
如果是一个列表(如 [1.0, 2.5, 0.5]),表示每个簇有不同的标准差,分别为 1.0、2.5 和 0.5。- random_state 控制随机数生成器,用于保证数据可复现。
如果设置为固定的整数值(如 random_state=42),每次生成的数据点相同。
如果不指定或设置为 None,每次生成的数据点可能不同。
返回值:
- data: 生成的样本数据。ndarray,形状为 (n_samples, n_features)。
- true_labels: 每个样本对应的真实类别标签。ndarray,形状为 (n_samples,)。
每个元素是一个整数,表示该样本所属簇的索引(如 0 表示第一个簇,1 表示第二个簇,以此类推)。
实例:
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs# 参数设置
n_samples = 300
n_clusters = 3
cluster_std = [1.0, 2.5, 0.5]# 生成数据
data, true_labels = make_blobs(n_samples=n_samples,centers=n_clusters,cluster_std=cluster_std,random_state=42
)# 可视化数据
plt.figure(figsize=(8, 6))
plt.scatter(data[:, 0], data[:, 1], c=true_labels, cmap='viridis', s=50)
plt.title("Generated Data with make_blobs")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.colorbar(label='Cluster Index')
plt.show()
这里只是生成了三个簇的数据,并没有进行k-means聚类,只是用不同颜色对应true_labels。
其中0/1/2对应三种不同的簇
5. 优化策略
5.1 数据生成与可视化
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.datasets import make_blobs# 生成模拟数据
np.random.seed(42)
X, _ = make_blobs(n_samples=500, centers=4, cluster_std=1.0, random_state=42)# 数据分布可视化
plt.scatter(X[:, 0], X[:, 1], s=30, c='gray', alpha=0.5)
plt.title("Generated Data")
plt.show()
# 使用 KMeans++ 初始化
kmeans = KMeans(n_clusters=4, init='k-means++', random_state=42)
labels = kmeans.fit_predict(X)# 聚类结果可视化
plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis', s=30, alpha=0.5)
plt.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1], s=200, c='red', marker='X', label='center node')
plt.title("K-Means ++")
plt.legend()
plt.show()
# 肘部法则确定簇数
wcss = []
for k in range(1, 10):kmeans = KMeans(n_clusters=k, init='k-means++', random_state=42)kmeans.fit(X) # 对数据集 X 进行 k-means 聚类。wcss.append(kmeans.inertia_) # 将当前 k 值下的 WCSS 添加到列表中。# 绘制肘部法则图
plt.plot(range(1, 10), wcss, marker='o')
plt.title("Elbow Method")
plt.xlabel("Cluster")
plt.ylabel("WCSS")
plt.show()# 计算轮廓系数
for k in range(2, 7):kmeans = KMeans(n_clusters=k, random_state=42)labels = kmeans.fit_predict(X) # 聚类并获取每个样本的标签(簇分配)。score = silhouette_score(X, labels) # 计算轮廓系数。print(f"Cluster {k}: Silhouette Coefficient = {score:.2f}")
这两种方法最终获得的簇标签 labels 是相同的,选择哪个取决于具体需求:
- 如果还需要访问其他模型属性,fit 是更灵活的选择。
- 如果只关心簇标签,fit_predict 更便捷。
6. K-means++
6.1 原理