聚类算法之MeanShift

时间:2019-11-27
本文章向大家介绍聚类算法之MeanShift,主要包括聚类算法之MeanShift使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

机器学习的研究方向主要分为三大类:聚类,分类与回归。

MeanShift作为聚类方法之一,在视觉领域有着广泛的应用,尤其是作为深度学习回归后的后处理模块而存在着。

接下来,我们先介绍下基本功能流程,然后会用代码的形式来分析。

一、功能流程:

    MeanShift是基于滑窗的算法,尝试找到数据点密集的区域。

    通过感兴趣区域内的数据密度变化计算中心点的漂移向量,从而移动中心点进行下一次迭代,直到到达密度最大处(中心点不变)。从每个数据点出发都可以进行该操作,统计出现在感兴趣区域内数据的次数,该参数将在最后作为分类的依据。

    步骤:

           1)在未被标记的数据点中随机选择一个点作为起始中心点center;

           2)找出以center为中心半径为radius的区域中出现的所有数据点,认为这些点同属于一个聚类C,同时在该聚类中记录数据点出现的次数加1;

           3)以center为中心点,计算从center开始到集合M中每个元素的向量,将这些向量相加,得到向量shift;

           4)center=center+shift,即center沿着shift方向移动,移动距离为||shift||;

           5)重复步骤2,3,4,直到shift很小,记得此时的center。注意,这个迭代过程中遇到的点都应该归类到簇C;

           6)如果收敛时当前簇C的center与其它已经存在的簇C2中心的距离小于阈值,那么把C2与C合并,数据点出现次数也对应合并。否则把C作为新的聚类;

           7)重复1,2,3,4,5直到所有点都被标记为已访问;

           8)分类:根据每个类,对每个点的访问频率,取访问频率最大的那个类,作为当前点集的所属类。

   通过在数据集中寻找被低密度区域分离的高密度区域,将分离出的高密度区域作为一个独立的类别。   

   优点:可自动决定类别的数目。

二、代码流程:

from sklearn.cluster import MeanShift

    @staticmethod
    def _cluster(prediction, bandwidth):
        ms = MeanShift(bandwidth, bin_seeding=True)
        # log.info('开始Mean shift聚类 ...')
        tic = time.time()
        try:
            ms.fit(prediction)
        except ValueError as err:
            log.error(err)
            return 0, [], []
        # log.info('Mean Shift耗时: {:.5f}s'.format(time.time() - tic))
        labels = ms.labels_
        cluster_centers = ms.cluster_centers_

        num_clusters = cluster_centers.shape[0]

        # log.info('聚类簇个数为: {:d}'.format(num_clusters))

        return num_clusters, labels, cluster_centers

原文地址:https://www.cnblogs.com/jimchen1218/p/11940340.html