使用Faiss进行K-Means聚类

📝 本文需要的前置知识:Faiss的基本使用

目录

  • 1. 源码剖析
    • 1.1 参数解释
  • 2. 聚类过程详解
    • 2.1 初始化聚类中心
    • 2.2 分配步骤(Assignment)
    • 2.3 更新步骤(Update)
    • 2.4 收敛与终止条件
  • 3. GPU 加速
    • 3.1 索引结构与 GPU
    • 3.2 GPU 训练过程
    • 3.3 多 GPU 训练
  • 4. 聚类后的操作
    • 4.1 获取聚类中心
    • 4.2 分配新数据点
    • 4.3 评估聚类效果
  • 5. 参数调优与最佳实践
    • 5.1 选择合适的簇数(k)
    • 5.2 调整迭代次数(niter)
    • 5.3 使用 GPU 的优化
    • 5.4 数据预处理
  • 6. 实际案例分析
    • 6.1 数据集准备
    • 6.2 聚类模型训练
    • 6.3 聚类结果分析
    • 6.4 使用聚类结果进行图像检索
  • 7. 常见问题与解决方案
    • 7.1 内存不足
    • 7.2 聚类效果不佳
    • 7.3 GPU 资源不足
  • 8. 高级用法与扩展
  • 9. 性能优化技巧
  • Ref

1. 源码剖析

如下是 Kmeans 的源码(摘自faiss 1.7.4版本):

class Kmeans:"""Object that performs k-means clustering and manages the centroids.The `Kmeans` class is essentially a wrapper around the C++ `Clustering` object.Parameters----------d : intDimension of the vectors to cluster.k : intNumber of clusters.gpu: bool or int, optionalFalse: don't use GPUTrue: use all GPUsnumber: use this many GPUsprogressive_dim_steps:Use a progressive dimension clustering (with that number of steps).Subsequent parameters are fields of the Clustering object. The most important are:niter: int, optionalClustering iterations.nredo: int, optionalRedo clustering this many times and keep the best.verbose: bool, optionalspherical: bool, optionalDo we want normalized centroids?int_centroids: bool, optionalRound centroids coordinates to integer.seed: int, optionalSeed for the random number generator."""def __init__(self, d, k, **kwargs):"""d: input dimension, k: nb of centroids. Additionalparameters are passed on the ClusteringParameters object,including niter=25, verbose=False, spherical=False."""self.d = dself.k = kself.gpu = Falseif "progressive_dim_steps" in kwargs:self.cp = ProgressiveDimClusteringParameters()else:self.cp = ClusteringParameters()for k, v in kwargs.items():if k == 'gpu':if v is True or v == -1:v = get_num_gpus()self.gpu = velse:# if this raises an exception, it means that it is a non-existent fieldgetattr(self.cp, k)setattr(self.cp, k, v)self.centroids = Nonedef train(self, x, weights=None, init_centroids=None):"""Perform k-means clustering.On output of the function call:- The centroids are in the centroids field of size (`k`, `d`).- The objective value at each iteration is in the array obj (size `niter`).- Detailed optimization statistics are in the array iteration_stats.Parameters----------x : array_likeTraining vectors, shape (n, d), `dtype` must be float32 and n shouldbe larger than the number of clusters `k`.weights : array_likeWeight associated to each vector, shape `n`.init_centroids : array_likeInitial set of centroids, shape (n, d).Returns-------final_obj: floatFinal optimization objective."""x = np.ascontiguousarray(x, dtype='float32')n, d = x.shapeassert d == self.dif self.cp.__class__ == ClusteringParameters:# Regular clusteringclus = Clustering(d, self.k, self.cp)if init_centroids is not None:nc, d2 = init_centroids.shapeassert d2 == dfaiss.copy_array_to_vector(init_centroids.ravel(), clus.centroids)if self.cp.spherical:self.index = IndexFlatIP(d)else:self.index = IndexFlatL2(d)if self.gpu:self.index = faiss.index_cpu_to_all_gpus(self.index, ngpu=self.gpu)clus.train(x, self.index, weights)else:# Not supported for progressive dimassert weights is Noneassert init_centroids is Noneassert not self.cp.sphericalclus = ProgressiveDimClustering(d, self.k, self.cp)if self.gpu:fac = GpuProgressiveDimIndexFactory(ngpu=self.gpu)else:fac = ProgressiveDimIndexFactory()clus.train(n, swig_ptr(x), fac)centroids = faiss.vector_float_to_array(clus.centroids)self.centroids = centroids.reshape(self.k, d)stats = clus.iteration_statsstats = [stats.at(i) for i in range(stats.size())]self.obj = np.array([st.obj for st in stats])# Copy all the iteration_stats objects to a Python arraystat_fields = 'obj time time_search imbalance_factor nsplit'.split()self.iteration_stats = [{field: getattr(st, field) for field in stat_fields}for st in stats]return self.obj[-1] if self.obj.size > 0 else 0.0def assign(self, x):"""Assign data points to the nearest cluster centroid.Parameters----------x : array_likeData points to assign, shape (n, d), `dtype` must be float32.Returns-------D : array_likeDistances of each data point to its nearest centroid.I : array_likeIndex of the nearest centroid for each data point."""x = np.ascontiguousarray(x, dtype='float32')assert self.centroids is not None, "Should train before assigning"self.index.reset()self.index.add(self.centroids)D, I = self.index.search(x, 1)return D.ravel(), I.ravel()

聚类时基本只会用到 ClusteringParameters(),以下是该类的源码:

class ClusteringParameters(object):r"""Class for the clustering parameters. Can be passed to theconstructor of the Clustering object."""thisown = property(lambda x: x.this.own(), lambda x, v: x.this.own(v), doc="The membership flag")__repr__ = _swig_reprniter = property(_swigfaiss.ClusteringParameters_niter_get, _swigfaiss.ClusteringParameters_niter_set, doc=r""" clustering iterations""")nredo = property(_swigfaiss.ClusteringParameters_nredo_get, _swigfaiss.ClusteringParameters_nredo_set, doc=r""" redo clustering this many times and keep best""")verbose = property(_swigfaiss.ClusteringParameters_verbose_get, _swigfaiss.ClusteringParameters_verbose_set)spherical = property(_swigfaiss.ClusteringParameters_spherical_get, _swigfaiss.ClusteringParameters_spherical_set, doc=r""" do we want normalized centroids?""")int_centroids = property(_swigfaiss.ClusteringParameters_int_centroids_get, _swigfaiss.ClusteringParameters_int_centroids_set, doc=r""" round centroids coordinates to integer""")update_index = property(_swigfaiss.ClusteringParameters_update_index_get, _swigfaiss.ClusteringParameters_update_index_set, doc=r""" re-train index after each iteration?""")frozen_centroids = property(_swigfaiss.ClusteringParameters_frozen_centroids_get, _swigfaiss.ClusteringParameters_frozen_centroids_set, doc=r"""use the centroids provided as input and do notchange them during iterations""")min_points_per_centroid = property(_swigfaiss.ClusteringParameters_min_points_per_centroid_get, _swigfaiss.ClusteringParameters_min_points_per_centroid_set, doc=r""" otherwise you get a warning""")max_points_per_centroid = property(_swigfaiss.ClusteringParameters_max_points_per_centroid_get, _swigfaiss.ClusteringParameters_max_points_per_centroid_set, doc=r""" to limit size of dataset""")seed = property(_swigfaiss.ClusteringParameters_seed_get, _swigfaiss.ClusteringParameters_seed_set, doc=r""" seed for the random number generator""")decode_block_size = property(_swigfaiss.ClusteringParameters_decode_block_size_get, _swigfaiss.ClusteringParameters_decode_block_size_set, doc=r""" how many vectors at a time to decode""")def __init__(self):r""" sets reasonable defaults"""_swigfaiss.ClusteringParameters_swiginit(self, _swigfaiss.new_ClusteringParameters())__swig_destroy__ = _swigfaiss.delete_ClusteringParameters

1.1 参数解释

Kmeans 初始化时的部分参数来源于 ClusteringParameters,以下是对常用参数的解释:

def __init__(self, d, k, **kwargs):"""Parameters----------d : int参与聚类的向量的维度k : int聚类后簇的个数gpu: bool or int, optionalFalse: 不使用GPUTrue: 使用所有GPUnumber: 使用number个GPU,number=-1时也代表使用所有GPU默认为Falseniter: int, optional聚类算法的迭代次数,默认为25verbose: bool, optional是否输出详细信息,默认为Falsespherical: bool, optional是否在每次迭代后归一化聚类中心,默认为Falsemin_points_per_centroid: int, optional每个簇中的最小点数,默认为39max_points_per_centroid: int, optional每个簇中的最大点数,默认为256seed: int, optional随机种子,默认为1234"""

n n n 为参与训练的向量个数, k k k 为簇数,一些注意事项总结如下:

  • n > max_points_per_centroid * k,则只会采样 max_points_per_centroid * k 个向量进行训练,默认是 256 k 256k 256k 个;若 n < min_points_per_centroid * kn < k,则会直接报错。理想情况是 min_points_per_centroid * k <= n <= max_points_per_centroid * k保险起见,通常会选择设置 min_points_per_centroid = 1max_points_per_centroid = n
  • 当迭代次数超过 20 20 20 次且 n > 1000 k n>1000k n>1000k 时,继续增加迭代次数或训练点数量并不会显著提高算法性能,所以faiss默认会选择下采样。

2. 聚类过程详解

在了解了 Kmeans 类和 ClusteringParameters 的基本结构与参数之后,接下来我们深入剖析其聚类过程。这一过程主要包括以下几个步骤:

  1. 初始化聚类中心(Initialization)
  2. 分配数据点到最近的聚类中心(Assignment)
  3. 更新聚类中心(Update)
  4. 迭代直到收敛或达到最大迭代次数

2.1 初始化聚类中心

初始化是 K-Means 算法中至关重要的一步,因为不良的初始化可能导致收敛到局部最优解。faissKmeans 类默认使用 k-means++ 初始化方法,这是一种改进的初始化策略,能够显著提高聚类的效果和收敛速度。

def initialize_centroids(self, x):"""初始化聚类中心,使用 k-means++ 算法。Parameters----------x : array_like训练向量,形状为 (n, d)。Returns-------centroids : array_like初始化后的聚类中心,形状为 (k, d)。"""n, d = x.shapecentroids = np.empty((self.k, d), dtype='float32')# 随机选择第一个聚类中心indices = np.random.choice(n)centroids[0] = x[indices]# 计算每个点到最近聚类中心的距离distances = np.full(n, np.inf)for i in range(1, self.k):distances = np.minimum(distances, np.linalg.norm(x - centroids[i-1], axis=1)**2)probabilities = distances / distances.sum()cumulative_probabilities = np.cumsum(probabilities)r = np.random.rand()next_index = np.searchsorted(cumulative_probabilities, r)centroids[i] = x[next_index]return centroids

上述代码展示了一个简单的 k-means++ 初始化过程。faiss 通过内部的优化和并行计算,实际实现可能更加高效。

2.2 分配步骤(Assignment)

在每次迭代中,算法需要将每个数据点分配到距离最近的聚类中心。faiss 使用了高效的索引结构来加速这一过程,特别是在高维数据和大规模数据集的情况下。

def assign(self, x):"""将数据点分配到最近的聚类中心。Parameters----------x : array_like数据点,形状为 (n, d),dtype 必须为 float32。Returns-------D : array_like每个数据点到最近聚类中心的距离。I : array_like每个数据点所属的聚类中心索引。"""x = np.ascontiguousarray(x, dtype='float32')assert self.centroids is not None, "Should train before assigning"self.index.reset()self.index.add(self.centroids)D, I = self.index.search(x, 1)return D.ravel(), I.ravel()

faiss 利用了 IndexFlatL2IndexFlatIP 索引,根据是否进行球面聚类(spherical)来选择不同的距离度量。使用 GPU 加速后,可以显著提升大规模数据的分配速度。

2.3 更新步骤(Update)

一旦所有数据点被分配到最近的聚类中心,下一步就是更新这些聚类中心的位置。新的聚类中心通常是分配到该簇的所有数据点的均值。

def update_centroids(self, x, assignments):"""更新聚类中心为每个簇中所有数据点的均值。Parameters----------x : array_like数据点,形状为 (n, d)。assignments : array_like每个数据点所属的聚类中心索引。Returns-------new_centroids : array_like更新后的聚类中心,形状为 (k, d)。"""new_centroids = np.zeros((self.k, self.d), dtype='float32')counts = np.bincount(assignments, minlength=self.k)np.add.at(new_centroids, assignments, x)new_centroids /= counts[:, np.newaxis]return new_centroids

faiss 中,更新步骤同样经过优化以适应大规模数据和高维空间的需求。

2.4 收敛与终止条件

K-Means 算法通过不断迭代分配和更新步骤,直到满足以下任一终止条件:

  • 达到最大迭代次数(niter
  • 聚类中心的变化低于某个阈值(收敛)

faiss 中,Clustering 对象会记录每次迭代的目标函数值(即总的平方误差),并通过比较相邻迭代的目标函数值来判断是否收敛。

def has_converged(self, old_obj, new_obj, threshold=1e-4):"""判断聚类是否收敛。Parameters----------old_obj : float前一次迭代的目标函数值。new_obj : float当前迭代的目标函数值。threshold : float收敛阈值。Returns-------converged : bool是否收敛。"""return abs(old_obj - new_obj) < threshold

faiss 通过内部的 iteration_stats 数组记录每次迭代的详细信息,包括目标函数值、时间消耗等,以便进行后续分析和调优。

3. GPU 加速

faiss 的一大优势在于其对 GPU 的支持,这使得在处理大规模、高维度的数据时,聚类过程能够显著加快。以下是 faissKmeans 类中如何利用 GPU 的一些关键点。

3.1 索引结构与 GPU

train 方法中,根据 spherical 参数的不同,faiss 选择不同的索引结构:

  • IndexFlatL2:用于欧氏距离(L2 距离)
  • IndexFlatIP:用于内积距离(通常用于球面聚类)
if self.cp.spherical:self.index = IndexFlatIP(d)
else:self.index = IndexFlatL2(d)

一旦索引结构确定,faiss 会将其转换为 GPU 索引:

if self.gpu:self.index = faiss.index_cpu_to_all_gpus(self.index, ngpu=self.gpu)

faiss.index_cpu_to_all_gpus 函数会自动将索引复制到所有可用的 GPU 上,充分利用 GPU 的并行计算能力。

3.2 GPU 训练过程

在 GPU 上训练 K-Means 时,faiss 通过以下方式优化计算:

  1. 并行计算距离:利用 GPU 的并行计算能力,快速计算所有数据点到聚类中心的距离。
  2. 高效内存管理:通过 CUDA 流和批处理,最大限度地减少数据传输时间。
  3. 优化的算法实现:利用高效的 CUDA 核函数,优化 K-Means 的各个步骤。

以下是一个使用 GPU 进行训练的示例:

import faiss
import numpy as np# 生成随机数据
d = 128  # 向量维度
k = 100  # 聚类中心数
n = 1000000  # 数据点数量x = np.random.random((n, d)).astype('float32')# 初始化 Kmeans 对象,使用所有可用的 GPU
kmeans = faiss.Kmeans(d=d, k=k, niter=20, verbose=True, gpu=True)# 训练聚类模型
kmeans.train(x)# 获取聚类中心
centroids = kmeans.centroids# 分配数据点到最近的聚类中心
D, I = kmeans.assign(x)

通过上述代码,可以看到 faiss 的 GPU 加速使用非常简洁,只需在初始化时设置 gpu=True 即可。

3.3 多 GPU 训练

对于极大规模的数据集,单个 GPU 可能无法承载全部计算需求。faiss 通过支持多 GPU 训练,进一步提升了聚类的效率。

# 使用指定数量的 GPU 进行训练
ngpu = 4  # 假设有4个 GPU
kmeans = faiss.Kmeans(d=d, k=k, niter=20, verbose=True, gpu=ngpu)
kmeans.train(x)

在多 GPU 环境下,faiss 会将数据和计算任务分配到多个 GPU 上,充分利用并行计算资源,显著缩短聚类时间。

4. 聚类后的操作

完成聚类后,faiss 提供了一些便捷的方法来进行后续操作,例如获取聚类中心、分配新数据点到最近的聚类中心等。

4.1 获取聚类中心

聚类完成后,聚类中心存储在 centroids 属性中,可以方便地进行访问和保存。

# 获取聚类中心
centroids = kmeans.centroids# 保存聚类中心到文件
np.save('centroids.npy', centroids)# 加载聚类中心
loaded_centroids = np.load('centroids.npy')

4.2 分配新数据点

使用训练好的聚类模型,可以将新的数据点快速分配到最近的聚类中心。

# 生成新的随机数据
new_x = np.random.random((10000, d)).astype('float32')# 分配新数据点
D, I = kmeans.assign(new_x)# D 是距离,I 是聚类中心索引

这种分配操作在许多应用场景中非常有用,例如在向量检索系统中,快速定位相似向量的簇。

4.3 评估聚类效果

faiss 记录了每次迭代的目标函数值和其他统计信息,可以用于评估聚类的效果和收敛情况。

# 获取目标函数值
final_obj = kmeans.obj[-1]
print(f"Final objective value: {final_obj}")# 获取详细的迭代统计信息
iteration_stats = kmeans.iteration_stats
for i, stats in enumerate(iteration_stats):print(f"Iteration {i}: Obj={stats['obj']}, Time={stats['time']}, "f"Imbalance={stats['imbalance_factor']}, Nsplit={stats['nsplit']}")

通过分析这些统计信息,可以了解聚类过程中的优化情况和可能的瓶颈。

5. 参数调优与最佳实践

为了获得最佳的聚类效果和性能,合理地调优 Kmeans 类的参数是至关重要的。以下是一些参数调优的建议和最佳实践:

5.1 选择合适的簇数(k)

选择合适的簇数是 K-Means 聚类中的一个关键问题。常用的方法包括:

  • 肘部法则(Elbow Method):绘制不同 k 值下的目标函数值(总的平方误差),选择拐点所在的 k 值。
  • 轮廓系数(Silhouette Coefficient):评估聚类的紧密度和分离度,选择轮廓系数最高的 k 值。
  • 业务需求:根据具体应用场景的需求,选择合适的簇数。
import matplotlib.pyplot as plt# 计算不同 k 值下的目标函数值
ks = range(10, 200, 10)
objs = []
for k in ks:kmeans = faiss.Kmeans(d=d, k=k, niter=20, verbose=False, gpu=True)kmeans.train(x)objs.append(kmeans.obj[-1])# 绘制肘部图
plt.plot(ks, objs, 'bx-')
plt.xlabel('Number of clusters (k)')
plt.ylabel('Objective value')
plt.title('Elbow Method for Optimal k')
plt.show()

5.2 调整迭代次数(niter)

niter 参数决定了聚类算法的最大迭代次数。默认值通常足够,但在某些情况下,增加迭代次数可以获得更好的聚类结果,尤其是在数据集较为复杂时。

# 增加迭代次数以提高聚类精度
kmeans = faiss.Kmeans(d=d, k=k, niter=100, verbose=True, gpu=True)
kmeans.train(x)

5.3 使用 GPU 的优化

在处理大规模数据时,合理配置 GPU 资源可以显著提升性能:

  • 选择合适的 GPU 数量:根据数据规模和硬件资源,选择合适的 GPU 数量进行并行计算。
  • 优化批处理大小:调整批处理大小以充分利用 GPU 的计算能力,避免内存不足或计算资源浪费。
  • 监控 GPU 利用率:使用工具如 nvidia-smi 监控 GPU 的利用率,确保计算资源的高效使用。
# 查看当前 GPU 使用情况
nvidia-smi

5.4 数据预处理

良好的数据预处理可以提升聚类效果和算法的收敛速度:

  • 标准化(Normalization):将数据标准化到相同的尺度,避免某些特征对距离度量的影响过大。
  • 降维(Dimensionality Reduction):使用 PCA 等方法降低数据维度,减少计算量,同时可能提升聚类效果。
  • 去除异常值(Outlier Removal):去除数据中的异常值,避免对聚类结果产生不利影响。
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA# 标准化数据
scaler = StandardScaler()
x_scaled = scaler.fit_transform(x)# 降维
pca = PCA(n_components=100)
x_pca = pca.fit_transform(x_scaled).astype('float32')

6. 实际案例分析

为了更好地理解 faiss 的 K-Means 实现,下面通过一个实际案例进行演示。假设我们有一个大规模的图像特征数据集,目标是将这些特征聚类为多个类别,以便于后续的图像检索或分类任务。

6.1 数据集准备

我们使用 faiss 自带的随机数据作为示例:

import faiss
import numpy as np# 设置随机种子以保证结果可重复
np.random.seed(42)# 生成随机图像特征,假设每个特征是 512 维
d = 512
k = 1000  # 预设的簇数
n = 10000000  # 1千万个数据点# 生成随机数据
x = np.random.random((n, d)).astype('float32')

6.2 聚类模型训练

使用 faissKmeans 类进行聚类训练:

# 初始化 Kmeans 对象,使用所有可用的 GPU
kmeans = faiss.Kmeans(d=d, k=k, niter=20, verbose=True, gpu=True)# 训练聚类模型
kmeans.train(x)# 获取聚类中心
centroids = kmeans.centroids

6.3 聚类结果分析

训练完成后,我们可以分析聚类结果,包括每个簇的大小、聚类中心的分布等。

# 获取每个簇的大小
cluster_sizes = np.bincount(kmeans.assign(x)[1], minlength=k)# 绘制簇大小分布
import matplotlib.pyplot as pltplt.hist(cluster_sizes, bins=50)
plt.xlabel('Cluster Size')
plt.ylabel('Number of Clusters')
plt.title('Distribution of Cluster Sizes')
plt.show()

通过上述分析,可以观察到聚类中心的分布是否均匀,以及是否存在某些簇过大或过小的情况。

6.4 使用聚类结果进行图像检索

聚类结果可以用于加速图像检索。例如,将查询图像的特征首先分配到最近的簇,然后只在该簇内进行详细的相似度计算,从而减少计算量。

def search(query, centroids, kmeans, top_n=10):"""使用聚类结果进行快速图像检索。Parameters----------query : array_like查询图像的特征,形状为 (d,)。centroids : array_like聚类中心,形状为 (k, d)。kmeans : faiss.Kmeans训练好的 Kmeans 对象。top_n : int返回最近的 top_n 个结果。Returns-------indices : array_like最近的图像索引。distances : array_like最近的图像距离。"""# 分配查询到最近的聚类中心D, I = kmeans.assign(query.reshape(1, -1))cluster_idx = I[0]# 获取该簇内所有数据点的索引cluster_mask = (kmeans.assign(x)[1] == cluster_idx)cluster_data = x[cluster_mask]# 计算查询与簇内数据点的距离index = faiss.IndexFlatL2(d)index.add(cluster_data)D, I = index.search(query.reshape(1, -1), top_n)return I.ravel(), D.ravel()# 示例查询
query = np.random.random((d,)).astype('float32')
indices, distances = search(query, centroids, kmeans, top_n=10)
print(f"Top 10 nearest indices: {indices}")
print(f"Top 10 nearest distances: {distances}")

通过这种方式,检索效率得到了显著提升,特别是在处理大规模数据集时。

7. 常见问题与解决方案

在使用 faiss 进行 K-Means 聚类时,可能会遇到一些常见问题。以下是一些常见问题及其解决方案:

7.1 内存不足

问题描述:在处理大规模数据集时,可能会遇到内存不足的问题,尤其是在 CPU 内存或 GPU 显存有限的情况下。

解决方案

  • 下采样:选择数据集的一部分进行聚类训练。
  • 增大 max_points_per_centroid:调整 ClusteringParameters 中的 max_points_per_centroid 参数,控制每个簇的最大数据点数。
  • 使用多 GPU:分散内存压力,利用多个 GPU 分担计算和存储。
# 使用部分数据进行训练
sample_size = 500000  # 50万数据点
indices = np.random.choice(n, sample_size, replace=False)
x_sample = x[indices]kmeans.train(x_sample)

7.2 聚类效果不佳

问题描述:聚类结果不理想,可能是因为聚类中心初始化不当、数据预处理不充分等原因。

解决方案

  • 调整初始化方法:尝试不同的初始化策略,如随机初始化或 k-means++。
  • 数据标准化:对数据进行标准化或归一化处理,确保各特征在相同尺度上。
  • 增加迭代次数:适当增加 niter 参数,允许算法有更多的迭代机会进行优化。
# 标准化数据
from sklearn.preprocessing import StandardScalerscaler = StandardScaler()
x_scaled = scaler.fit_transform(x)# 重新训练聚类模型
kmeans = faiss.Kmeans(d=d, k=k, niter=50, verbose=True, gpu=True)
kmeans.train(x_scaled)

7.3 GPU 资源不足

问题描述:在 GPU 资源有限的情况下,可能无法加载整个数据集或聚类模型。

解决方案

  • 分批处理:将数据集分成多个批次,逐批进行聚类训练。
  • 减少簇数:适当减少簇数,降低 GPU 的计算和存储压力。
  • 升级硬件:如果条件允许,升级 GPU 硬件以满足计算需求。
# 分批训练示例
batch_size = 1000000  # 每批 100万数据点
for i in range(0, n, batch_size):batch = x[i:i+batch_size]kmeans.train(batch)

8. 高级用法与扩展

faiss 不仅支持基本的 K-Means 聚类,还提供了许多高级功能和扩展,适用于更复杂的应用场景。

增量聚类(Incremental Clustering)

在某些应用中,数据是动态增长的,此时需要对新数据进行增量聚类,而不是重新训练整个模型。faiss 提供了相关的 API 支持增量聚类。

# 假设已有初始聚类模型
kmeans = faiss.Kmeans(d=d, k=k, niter=20, verbose=True, gpu=True)
kmeans.train(initial_x)# 增量训练新的数据
kmeans.train(new_x)

分布式聚类

对于极大规模的数据集,单机训练可能无法满足需求。faiss 通过分布式计算框架,支持在多台机器上进行聚类训练。

# 使用 faiss 的分布式 API 进行聚类
# 具体实现依赖于集群环境和分布式框架

自定义距离度量

除了默认的欧氏距离和内积距离,faiss 还支持自定义距离度量,以适应不同的应用需求。

# 定义自定义距离度量函数
def custom_distance(x, y):# 例如,使用曼哈顿距离return np.sum(np.abs(x - y), axis=1)# 在聚类过程中使用自定义距离
# 需要修改 faiss 的内部实现或扩展现有类

与其他算法结合

faiss 的聚类结果可以与其他算法结合,构建更复杂的机器学习管道。例如,将聚类结果作为分类器的输入特征,或结合深度学习模型进行特征学习。

from sklearn.linear_model import LogisticRegression# 使用聚类中心作为分类特征
cluster_features = kmeans.assign(x)[1]# 训练分类模型
clf = LogisticRegression()
clf.fit(cluster_features, labels)

9. 性能优化技巧

为了充分发挥 faiss 的性能,以下是一些性能优化的技巧:

数据存储格式

确保数据以连续的内存块存储,使用 float32 类型,这样可以提高数据访问和计算效率。

x = np.ascontiguousarray(x, dtype='float32')

并行计算

利用 faiss 的多线程和多 GPU 支持,充分利用计算资源。

import faiss# 设置线程数
faiss.omp_set_num_threads(8)# 使用多 GPU
kmeans = faiss.Kmeans(d=d, k=k, niter=20, verbose=True, gpu=4)

预分配内存

对于大规模数据集,预先分配内存可以减少内存碎片和分配时间。

# 预分配聚类中心
centroids = np.empty((k, d), dtype='float32')

缓存优化

确保数据在内存中的布局有利于缓存访问,减少缓存未命中的次数。

# 确保数据按行主序存储
x = np.ascontiguousarray(x, dtype='float32')

Ref

[1] https://github.com/facebookresearch/faiss/wiki/FAQ
[2] https://github.com/facebookresearch/faiss
[3] Faiss GitHub Repository
[4] K-Means++: The Advantages of Careful Seeding
[5] Scikit-learn Clustering Documentation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/430167.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++笔记21•C++11的新特性•

相比于 C98/03&#xff0c;C11则带来了数量可观的变化&#xff0c;其中包含了约140个新特性&#xff0c;以及对C03标准中约600个缺陷的修正&#xff0c;这使得C11更像是从C98/03中孕育出的一种新语言。相比较而言&#xff0c;C11能更好地用于系统开发和库开发、语法更加泛华和简…

如何合并pdf文件,四款软件,三步搞定!

在数字化办公的浪潮中&#xff0c;PDF文档因其跨平台兼容性和安全性&#xff0c;成为了我们日常工作中不可或缺的一部分。然而&#xff0c;面对多个PDF文件需要整合成一个文件时&#xff0c;不少小伙伴可能会感到头疼。别担心&#xff0c;今天我们就来揭秘四款高效PDF合并软件&…

演示:基于WPF的DrawingVisual开发的Chart图表和表格绘制

一、目的&#xff1a;基于WPF的DrawingVisual开发的Chart图表和表格绘制 二、预览 钻井井轨迹表格数据演示示例&#xff08;应用Table布局&#xff0c;模拟井轨迹深度的绘制&#xff09; 饼图表格数据演示示例&#xff08;应用Table布局&#xff0c;模拟多个饼状图组合显示&am…

尚品汇-秒杀商品定时任务存入缓存、Redis发布订阅实现状态位(五十一)

目录&#xff1a; &#xff08;1&#xff09;秒杀业务分析 &#xff08;2&#xff09;搭建秒杀模块 &#xff08;3&#xff09;秒杀商品导入缓存 &#xff08;4&#xff09;redis发布与订阅实现 &#xff08;1&#xff09;秒杀业务分析 需求分析 所谓“秒杀”&#xff0…

又到了金九银十,你的简历写好了吗?

又到了金九银十的招聘季&#xff0c;不过这几年求职环境越来越差&#xff0c;相比于跳槽找新机会&#xff0c;大家可能更倾向于守住自己手头的工作&#xff0c;稳字当头。当然&#xff0c;也有很多工作实在干烦了的朋友&#xff0c;想要换个新赛道试试。今天就给大家带来一个新…

django实现开发、测试、生产环境配置区分

文章目录 一、为什么要区分开发 (dev)、测试 (test) 和生产 (prod) 环境二、django项目如何通过配置实现环境配置的区分1、针对不同的环境创建不同的设置文件settings.py2、在设置文件中根据需要进行配置区分3、根据不同的环境运行使用不同的设置文件 任何实际的软件项目中都要…

【中级通信工程师】终端与业务(二):终端产品

【零基础3天通关中级通信工程师】 终端与业务(二)&#xff1a;终端产品 本文是中级通信工程师考试《终端与业务》科目第二章《终端产品》的复习资料和真题汇总。终端与业务是通信考试里最简单的科目&#xff0c;有效复习通过率可达90%以上&#xff0c;本文结合了高频考点和近几…

医学数据分析实训 项目三 关联规则分析作业--在线购物车分析--痹症方剂用药规律分析

文章目录 项目三 关联规则分析一、实践目的二、实践平台三、实践内容任务一&#xff1a;在线购物车分析&#xff08;一&#xff09;数据读入&#xff08;二&#xff09;数据理解&#xff08;三&#xff09;数据预处理&#xff08;四&#xff09;生成频繁项集&#xff08;五&…

基于微信小程序的美食外卖管理系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码 精品专栏&#xff1a;Java精选实战项目…

分享两个虚拟试衣工具,一个在线,一个离线,还有ComfyUI插件

SAM &#xff0c;对不住了&#xff01; 我没记错的话&#xff0c;OpenAI CEO&#xff0c;性别男&#xff0c;取向男&#xff0c;配偶男。 这又让我联想到了苹果CEO库克... 所以OpenAI和Apple可以一啪即合。 钢铁直男老马就和他们都不对付~~ 开个玩笑&#xff0c;聊…

WebGL入门(一)绘制一个点

源码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title><scr…

Web+Mysql——MyBatis

MyBatis 目标 能够完成Mybatis代理方式查询数据能够理解Mybatis核心配置文件的配置 1&#xff0c;Mybatis 1.1 Mybatis概述 1.1.1 Mybatis概念 MyBatis 是一款优秀的持久层框架&#xff0c;用于简化 JDBC 开发 MyBatis 本是 Apache 的一个开源项目iBatis, 2010年这个项目由…

高等代数笔记(2)————(弱/强)数学归纳法

数学归纳法的引入情景其实很简单&#xff0c;就是多米诺骨牌。 推倒所有多米诺骨牌的关键就是推倒第一块&#xff0c;以及确保第一块倒下后会带动第二块&#xff0c;第二块带动第三块&#xff0c;以此类推&#xff0c;也就是可以递推。由此我们可以归纳出所有的多米诺骨牌都可…

MySQL学习(索引)

文章目录 基本概念单列索引普通索引&#xff08;index&#xff09;唯一索引&#xff08;unique&#xff09;主键索引 组合索引全文索引&#xff08;fulltext&#xff09;空间索引&#xff08;spatial&#xff09;MySQL存储引擎 基本概念 通过某种算法&#xff0c;构建数据模型&…

LeetCode 2374.边积分最高的节点:模拟

【LetMeFly】2374.边积分最高的节点&#xff1a;模拟 力扣题目链接&#xff1a;https://leetcode.cn/problems/node-with-highest-edge-score/ 给你一个有向图&#xff0c;图中有 n 个节点&#xff0c;节点编号从 0 到 n - 1 &#xff0c;其中每个节点都 恰有一条 出边。 图…

k8s中pod的创建过程和阶段状态

管理k8s集群 kubectl k8s中有两种用户 一种是登录的 一种是/sbin/nologin linux可以用密码登录&#xff0c;也可以用证书登录 k8s只能用证书登录 谁拿到这个证书&#xff0c;谁就可以管理集群 在k8s中&#xff0c;所有节点都被网络组件calico设置了路由和通信 所以pod的ip是可以…

如何使用 maxwell 同步到 redis?

文章目录 1、MaxwellListener2、MxwObject1. 使用Maxwell捕获MySQL变更2. 将Maxwell的输出连接到消息系统3. 从消息系统读取数据并同步到Redis注意事项 1、MaxwellListener package com.atguigu.tingshu.album.listener;import com.alibaba.fastjson.JSON; import org.apache.…

mysql中的json查询

首先来构造数据 查询department里面name等于研发部的数据 查询语句跟普通的sql语句差不多&#xff0c;也就是字段名要用到path表达式 select * from user u where u.department->$.name 研发部 模糊查询 select * from user u where u.department->$.name like %研发%…

Go-知识recover

Go-知识recover 1. 介绍2. 工作机制2.1 recover 定义2.2 工作流程2.3 总结 3. 原理3.1 recover函数的真正逻辑3.2 恢复逻辑3.3 生效条件 4. 总结4.1 recover的返回值是什么&#xff1f;4.2 执行recover之后程序将从哪里继续运行&#xff1f;4.3 recover为什么一定要在defer中使…

无法删除选定的端口,不支持请求【笔记】

场景&#xff1a;在删除打印机端口时&#xff0c;提示&#xff1a;“无法删除选定的端口&#xff0c;不支持请求”&#xff0c;如下图所示。 以下以删除USB036端口为示例&#xff0c;操作步骤如下&#xff1a; 在注册表编辑器中&#xff0c;从以下注册表项中“计算机\HKEY_LO…