Spark MLlib: KMeans Clustering Algorithm Explained
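Guide questions
1. What is Spark MLlib?
2. What categories of problems does Spark MLlib cover?
3. What is the basic idea of the KMeans algorithm?
4. What does the Spark MLlib KMeans source code contain?
I have long wanted to study Spark's machine learning library, so today I am summarizing what I have learned.
1. What is Spark MLlib
MLlib is Spark's library of implementations of common machine learning algorithms, together with related tests and data generators.
2. Spark MLlib categories
MLlib currently supports four common types of machine learning problems: binary classification, regression, clustering and collaborative filtering. It also includes gradient descent as an underlying optimization primitive.
With these categories in mind, this article focuses on clustering.
3. The basic idea of KMeans
KMeans starts from K cluster centers chosen at random and assigns every sample point to the cluster whose center is nearest. It then recomputes each cluster's centroid as the mean of its points, which gives the new centers. This repeats until the centers move less than a given threshold.
The K-means algorithm has three main steps:
(1) choose initial cluster centers for the points to be clustered;
(2) compute the distance from each point to every cluster center and assign the point to the nearest cluster;
(3) compute the mean coordinates of all points in each cluster and use that mean as the new cluster center.
Steps (2) and (3) repeat until the cluster centers no longer move significantly or the required number of iterations is reached.
4. Process demonstration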
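The figure below shows the result of K-means clustering on n sample points with k = 2:
(a) the initial, unclustered point set;
(b) two points chosen at random as cluster centers;
(c) compute each point's distance to the cluster centers and assign it to the nearest cluster;
(d) compute the mean coordinates of all points in each cluster and use that mean as the new cluster center;
(e) repeat (c): compute each point's distance to the cluster centers and assign it to the nearest cluster;
(f) repeat (d): compute the mean coordinates of all points in each cluster and use that mean as the new cluster center.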
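The three steps and the walkthrough above can be sketched in plain Scala, without Spark. This is a minimal, single-machine sketch of Lloyd's algorithm under a few assumptions. Points are dense `Array[Double]` values. The initial centers are passed in explicitly instead of being drawn at random, so the result is deterministic. `LloydSketch` and its methods are hypothetical names, not MLlib API.

```scala
object LloydSketch {
  type Point = Array[Double]

  // Squared Euclidean distance (no square root, as in MLlib's comparisons).
  def sqDist(a: Point, b: Point): Double =
    a.indices.map(i => (a(i) - b(i)) * (a(i) - b(i))).sum

  // Step (2): index of the center nearest to p.
  def closest(centers: Array[Point], p: Point): Int =
    centers.indices.minBy(j => sqDist(centers(j), p))

  // Repeat steps (2) and (3) until no center moves by more than epsilon
  // (compared as squared distance against epsilon * epsilon) or maxIterations is hit.
  def kmeans(points: Seq[Point], initial: Array[Point],
             maxIterations: Int = 20, epsilon: Double = 1e-4): Array[Point] = {
    var centers: Array[Point] = initial.map(_.clone())
    var iteration = 0
    var changed = true
    while (iteration < maxIterations && changed) {
      // Step (2): group the points by their nearest center.
      val groups = points.groupBy(p => closest(centers, p))
      // Step (3): each new center is the coordinate-wise mean of its group.
      val newCenters: Array[Point] = centers.indices.map { j =>
        groups.get(j) match {
          case Some(members) =>
            Array.tabulate(members.head.length)(d => members.map(_(d)).sum / members.size)
          case None => centers(j) // empty cluster: keep the previous center
        }
      }.toArray
      changed = centers.indices.exists(j => sqDist(centers(j), newCenters(j)) > epsilon * epsilon)
      centers = newCenters
      iteration += 1
    }
    centers
  }
}
```

Consider the six sample points from the example in section 6, with the first two points used as initial centers. The loop stops after three iterations, with centers at approximately (0.1, 0.1, 0.1) and (9.1, 9.1, 9.1).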
5. Spark MLlib KMeans source code analysis
class KMeans private (
private var k: Int,
private var maxIterations: Int,
private var runs: Int,
private var initializationMode: String,
private var initializationSteps: Int,
private var epsilon: Double,
private var seed: Long) extends Serializable with Logging {
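// KMeans class parameters:
// k: number of clusters, default 2; maxIterations: maximum number of iterations, default 20; runs: number of runs executed in parallel (the run with the lowest cost is kept), default 1;
// initializationMode: center initialization algorithm, default "k-means||"; initializationSteps: number of k-means|| initialization steps, default 5; epsilon: convergence threshold on center movement, default 1e-4; seed: random seed.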
/**
* Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
* initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}.
*/
def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong())
// Parameter setters
/** Set the number of clusters to create (k). Default: 2. */
def setK(k: Int): this.type = {
this.k = k
this
}
// ... setters for the remaining parameters omitted ...
// run: the main entry point of KMeans
/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
*/
def run(data: RDD[Vector]): KMeansModel = {
if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
// Compute squared norms and cache them.
// Compute the L2 norm of each row: data[Vector] => data[(Vector, norm)] (wrapped as VectorWithNorm), where norm is the vector's L2 norm, ||x||_2 = sqrt(x_1^2 + x_2^2 + ... + x_n^2).
val norms = data.map(Vectors.norm(_, 2.0))
norms.persist()
val zippedData = data.zip(norms).map { case (v, norm) =>
new VectorWithNorm(v, norm)
}
val model = runAlgorithm(zippedData)
norms.unpersist()
// Warn at the end of the run as well, for increased visibility.
if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data was not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
model
}
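A small sketch shows why run computes the norms up front. It assumes plain Scala collections in place of RDDs. The hypothetical `PointWithNorm` below plays the role of MLlib's `VectorWithNorm`. Each point's L2 norm is computed once and then reused by every distance computation in every iteration.

```scala
object NormSketch {
  // Hypothetical stand-in for MLlib's VectorWithNorm: a dense point plus its cached L2 norm.
  final case class PointWithNorm(vector: Array[Double], norm: Double)

  // L2 norm: square root of the sum of squared components.
  def l2Norm(v: Array[Double]): Double = math.sqrt(v.map(x => x * x).sum)

  // Computed once per point before iterating, so every later distance call can reuse it.
  def withNorms(data: Seq[Array[Double]]): Seq[PointWithNorm] =
    data.map(v => PointWithNorm(v, l2Norm(v)))
}
```

In the Spark code, the norms live in a separate RDD that is persisted and then zipped back onto `data`. `zip` lines up correctly there because `norms` is produced by a `map` over `data`, which keeps the same partitioning and element order.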
// runAlgorithm: the actual KMeans implementation
/**
* Implementation of K-Means algorithm.
*/
private def runAlgorithm(data: RDD[VectorWithNorm]): KMeansModel = {
val sc = data.sparkContext
val initStartTime = System.nanoTime()
val centers = if (initializationMode == KMeans.RANDOM) {
initRandom(data)
} else {
initKMeansParallel(data)
}
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
" seconds.")
val active = Array.fill(runs)(true)
val costs = Array.fill(runs)(0.0)
var activeRuns = new ArrayBuffer[Int] ++ (0 until runs)
var iteration = 0
val iterationStartTime = System.nanoTime()
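// KMeans iterations: find the center each sample belongs to, accumulate the sum and count of the samples per center, recompute each center from its samples, and compare with the previous centers to decide whether to stop. runs is the number of runs executed in parallel.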
// Execute iterations of Lloyd's algorithm until all runs have converged
while (iteration < maxIterations && !activeRuns.isEmpty) {
type WeightedPoint = (Vector, Long)
def mergeContribs(x: WeightedPoint, y: WeightedPoint): WeightedPoint = {
axpy(1.0, x._1, y._1)
(y._1, x._2 + y._2)
}
val activeCenters = activeRuns.map(r => centers(r)).toArray
val costAccums = activeRuns.map(_ => sc.accumulator(0.0))
val bcActiveCenters = sc.broadcast(activeCenters)
// Find the sum and count of points mapping to each center
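// Find the samples belonging to each center and accumulate their sum per center.
// runs: number of parallel runs; k: number of centers; sums: per-center sum of samples; counts: per-center sample count.
// contribs: ((run i, center j), (sum of the samples of center j, count of the samples of center j)).
// findClosest: finds the cluster center closest to a given point.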
val totalContribs = data.mapPartitions { points =>
val thisActiveCenters = bcActiveCenters.value
val runs = thisActiveCenters.length
val k = thisActiveCenters(0).length
val dims = thisActiveCenters(0)(0).vector.size
val sums = Array.fill(runs, k)(Vectors.zeros(dims))
val counts = Array.fill(runs, k)(0L)
points.foreach { point =>
(0 until runs).foreach { i =>
val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point)
costAccums(i) += cost
val sum = sums(i)(bestCenter)
axpy(1.0, point.vector, sum)
counts(i)(bestCenter) += 1
}
}
val contribs = for (i <- 0 until runs; j <- 0 until k) yield {
((i, j), (sums(i)(j), counts(i)(j)))
}
contribs.iterator
}.reduceByKey(mergeContribs).collectAsMap()
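// Update the centers: new center = sum / count;
// a run keeps iterating while some center moved by more than epsilon, i.e. while fastSquaredDistance(newCenter, oldCenter) > epsilon * epsilon.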
// Update the cluster centers and costs for each active run
for ((run, i) <- activeRuns.zipWithIndex) {
var changed = false
var j = 0
while (j < k) {
val (sum, count) = totalContribs((i, j))
if (count != 0) {
scal(1.0 / count, sum)
val newCenter = new VectorWithNorm(sum)
if (KMeans.fastSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
changed = true
}
centers(run)(j) = newCenter
}
j += 1
}
if (!changed) {
active(run) = false
logInfo("Run " + run + " finished in " + (iteration + 1) + " iterations")
}
costs(run) = costAccums(i).value
}
activeRuns = activeRuns.filter(active(_))
iteration += 1
}
val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
logInfo(s"Iterations took " + "%.3f".format(iterationTimeInSeconds) + " seconds.")
if (iteration == maxIterations) {
logInfo(s"KMeans reached the max number of iterations: $maxIterations.")
} else {
logInfo(s"KMeans converged in $iteration iterations.")
}
val (minCost, bestRun) = costs.zipWithIndex.min
logInfo(s"The cost for the best run is $minCost.")
new KMeansModel(centers(bestRun).map(_.vector))
}
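The per-partition accumulation and the update step in runAlgorithm can be imitated with ordinary collections. This sketch assumes a single run (no `runs` dimension) and no accumulators, with in-memory `Seq`s standing in for RDD partitions. `partitionContribs` plays the role of the `mapPartitions` block, `merge` plays the role of `mergeContribs`, and the `reduce` over partitions plays the role of `reduceByKey`. All names are hypothetical.

```scala
object AggregateSketch {
  type Point = Array[Double]

  def sqDist(a: Point, b: Point): Double =
    a.indices.map(i => (a(i) - b(i)) * (a(i) - b(i))).sum

  // Per-partition pass (the role of the mapPartitions block, for a single run):
  // for every center j, the sum of the points assigned to it and their count.
  def partitionContribs(points: Seq[Point], centers: Array[Point]): Map[Int, (Point, Long)] = {
    val dims = centers(0).length
    val sums = Array.fill(centers.length)(new Array[Double](dims))
    val counts = Array.fill(centers.length)(0L)
    points.foreach { p =>
      val best = centers.indices.minBy(j => sqDist(centers(j), p))
      for (d <- 0 until dims) sums(best)(d) += p(d)
      counts(best) += 1
    }
    centers.indices.map(j => j -> ((sums(j), counts(j)))).toMap
  }

  // Like mergeContribs: add partial sums and counts (but allocating a new array
  // instead of updating one in place with axpy).
  def merge(x: (Point, Long), y: (Point, Long)): (Point, Long) =
    (x._1.zip(y._1).map { case (a, b) => a + b }, x._2 + y._2)

  // One iteration: merge the partials of all partitions (the reduceByKey step),
  // divide each sum by its count, and flag whether any center moved more than epsilon.
  def step(partitions: Seq[Seq[Point]], centers: Array[Point],
           epsilon: Double = 1e-4): (Array[Point], Boolean) = {
    val total = partitions.map(partitionContribs(_, centers)).reduce { (m1, m2) =>
      m1.map { case (j, c) => j -> merge(c, m2(j)) }
    }
    var changed = false
    val newCenters: Array[Point] = centers.indices.map { j =>
      val (sum, count) = total(j)
      if (count != 0) {
        val c = sum.map(_ / count)
        if (sqDist(c, centers(j)) > epsilon * epsilon) changed = true
        c
      } else centers(j) // no points assigned: keep the old center, as the Spark code does
    }.toArray
    (newCenters, changed)
  }
}
```

Unlike mergeContribs, which adds into an existing vector with axpy to avoid allocation, `merge` builds a new array. That is simpler to read but allocates more per merge.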
// findClosest: finds the cluster center nearest to a point
/**
* Returns the index of the closest center to the given point, as well as the squared distance.
*/
private[mllib] def findClosest(
centers: TraversableOnce[VectorWithNorm],
point: VectorWithNorm): (Int, Double) = {
var bestDistance = Double.PositiveInfinity
var bestIndex = 0
var i = 0
centers.foreach { center =>
// Since `|a - b| >= ||a| - |b||`, we can use this lower bound to avoid unnecessary
// distance computation.
var lowerBoundOfSqDist = center.norm - point.norm
lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
if (lowerBoundOfSqDist < bestDistance) {
val distance: Double = fastSquaredDistance(center, point)
if (distance < bestDistance) {
bestDistance = distance
bestIndex = i
}
}
i += 1
}
(bestIndex, bestDistance)
}
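In findClosest, consider these two lines:
var lowerBoundOfSqDist = center.norm - point.norm
lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
If the center is (a1, b1) and the point being compared is (a2, b2), then lowerBoundOfSqDist is:
$$\text{lowerBoundOfSqDist} = \left(\sqrt{a_1^2+b_1^2}-\sqrt{a_2^2+b_2^2}\right)^2$$
Below are its expansion and, second, the squared Euclidean distance, i.e. the Euclidean distance without the square root. (When searching for the nearest center, the square root is unnecessary: comparing the expressions under the root gives the same ordering, and MLlib does exactly this.)
$$\left(\sqrt{a_1^2+b_1^2}-\sqrt{a_2^2+b_2^2}\right)^2 = a_1^2+b_1^2+a_2^2+b_2^2-2\sqrt{(a_1^2+b_1^2)(a_2^2+b_2^2)}$$
$$(a_1-a_2)^2+(b_1-b_2)^2 = a_1^2+b_1^2+a_2^2+b_2^2-2(a_1a_2+b_1b_2)$$
By the Cauchy-Schwarz inequality, $a_1a_2+b_1b_2 \le \sqrt{(a_1^2+b_1^2)(a_2^2+b_2^2)}$, so the first expression is always less than or equal to the second. When comparing distances, findClosest therefore first computes the cheap lowerBoundOfSqDist. If even this bound is not smaller than the smallest distance found so far, bestDistance, then the true squared distance cannot be smaller either. In that case the exact computation is skipped, which saves a lot of work.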
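The pruning argument can be checked with a standalone version of findClosest. This plain-Scala sketch makes several assumptions. It works on dense points, uses a hypothetical `PointWithNorm` in place of `VectorWithNorm`, and uses a naive squared distance in place of `fastSquaredDistance`. It also counts how many exact distance computations were performed, which the MLlib method does not do.

```scala
object FindClosestSketch {
  // Hypothetical stand-in for VectorWithNorm.
  final case class PointWithNorm(vector: Array[Double], norm: Double)

  def withNorm(v: Array[Double]): PointWithNorm =
    PointWithNorm(v, math.sqrt(v.map(x => x * x).sum))

  def sqDist(a: Array[Double], b: Array[Double]): Double =
    a.indices.map(i => (a(i) - b(i)) * (a(i) - b(i))).sum

  // Returns (closest center index, squared distance, number of exact distance
  // computations performed); the third value only exists to show the pruning.
  def findClosest(centers: Seq[PointWithNorm], point: PointWithNorm): (Int, Double, Int) = {
    var bestDistance = Double.PositiveInfinity
    var bestIndex = 0
    var computed = 0
    centers.zipWithIndex.foreach { case (center, i) =>
      // (|c| - |p|)^2 <= |c - p|^2, so this cheap value is a lower bound.
      val normDiff = center.norm - point.norm
      val lowerBoundOfSqDist = normDiff * normDiff
      if (lowerBoundOfSqDist < bestDistance) {
        val distance = sqDist(center.vector, point.vector)
        computed += 1
        if (distance < bestDistance) {
          bestDistance = distance
          bestIndex = i
        }
      }
    }
    (bestIndex, bestDistance, computed)
  }
}
```

Take the point (0.5, 0) and centers (0, 0), (10, 0) and (20, 0). The first center gives a squared distance of 0.25. The lower bounds for the other two are 9.5² = 90.25 and 19.5² = 380.25, both at least 0.25, so only one exact distance is computed.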
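If lowerBoundOfSqDist is smaller than bestDistance, the distance is computed by calling fastSquaredDistance. That method calls the fastSquaredDistance method in MLUtils.scala to compute the exact squared Euclidean distance. The code is: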
/**
* Returns the squared Euclidean distance between two vectors. The following formula will be used
* if it does not introduce too much numerical error:
* <pre>
* |a - b|_2^2 = |a|_2^2 + |b|_2^2 - 2 a^T b.
* </pre>
* When both vector norms are given, this is faster than computing the squared distance directly,
* especially when one of the vectors is a sparse vector.
*
* @param v1 the first vector
* @param norm1 the norm of the first vector, non-negative
* @param v2 the second vector
* @param norm2 the norm of the second vector, non-negative
* @param precision desired relative precision for the squared distance
* @return squared distance between v1 and v2 within the specified precision
*/
private[mllib] def fastSquaredDistance(
v1: Vector,
norm1: Double,
v2: Vector,
norm2: Double,
precision: Double = 1e-6): Double = {
val n = v1.size
require(v2.size == n)
require(norm1 >= 0.0 && norm2 >= 0.0)
val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
val normDiff = norm1 - norm2
var sqDist = 0.0
/*
* The relative error is
* <pre>
 * EPSILON * ( |a|_2^2 + |b|_2^2 + 2 |a^T b|) / ( |a - b|_2^2 ),
* </pre>
* which is bounded by
* <pre>
* 2.0 * EPSILON * ( |a|_2^2 + |b|_2^2 ) / ( (|a|_2 - |b|_2)^2 ).
* </pre>
* The bound doesn't need the inner product, so we can use it as a sufficient condition to
* check quickly whether the inner product approach is accurate.
*/
val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
if (precisionBound1 < precision) {
sqDist = sumSquaredNorm - 2.0 * dot(v1, v2)
} else if (v1.isInstanceOf[SparseVector] || v2.isInstanceOf[SparseVector]) {
val dotValue = dot(v1, v2)
sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0)
val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) /
(sqDist + EPSILON)
if (precisionBound2 > precision) {
sqDist = Vectors.sqdist(v1, v2)
}
} else {
sqDist = Vectors.sqdist(v1, v2)
}
sqDist
}
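fastSquaredDistance first computes a precision bound, val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON). If this bound shows the precision is sufficient, the squared distance is computed as sqDist = sumSquaredNorm - 2.0 * dot(v1, v2), where sumSquaredNorm is $\|a\|_2^2 + \|b\|_2^2$ and 2.0 * dot(v1, v2) is $2a^Tb$. This is the payoff of computing the norms in advance: only the dot product remains to be computed.
If the precision requirement is not met, the method falls back to the original distance formula $\|a-b\|_2^2 = \sum_i (a_i-b_i)^2$, i.e. it calls Vectors.sqdist(v1, v2). When either vector is sparse, it first tries the dot-product formula and checks a second bound, precisionBound2, before falling back.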
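The choice between the norm-based formula and the direct sum can be reproduced for dense vectors. This sketch makes three assumptions. It drops the sparse-vector branch. It computes `EPSILON` as the machine epsilon for Double via `java.lang.Math.ulp(1.0)`, assumed here to equal the machine epsilon that MLUtils uses. The object name is hypothetical.

```scala
object FastSqDistSketch {
  // Machine epsilon for Double (2^-52).
  val EPSILON: Double = java.lang.Math.ulp(1.0)

  def dot(a: Array[Double], b: Array[Double]): Double =
    a.indices.map(i => a(i) * b(i)).sum

  def sqdist(a: Array[Double], b: Array[Double]): Double =
    a.indices.map(i => (a(i) - b(i)) * (a(i) - b(i))).sum

  // Dense-only version: use |a|^2 + |b|^2 - 2 a.b when precisionBound1 says the
  // cancellation error is small enough, otherwise sum (a_i - b_i)^2 directly.
  def fastSquaredDistance(v1: Array[Double], norm1: Double,
                          v2: Array[Double], norm2: Double,
                          precision: Double = 1e-6): Double = {
    require(v1.length == v2.length)
    require(norm1 >= 0.0 && norm2 >= 0.0)
    val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
    val normDiff = norm1 - norm2
    val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
    if (precisionBound1 < precision) sumSquaredNorm - 2.0 * dot(v1, v2)
    else sqdist(v1, v2)
  }
}
```

For a = (3, 4) and b = (0, 0), the norms differ a lot, so precisionBound1 is tiny and the shortcut 25 - 2 * 0 = 25 is used. For two identical vectors, normDiff is 0 and the bound becomes 2 * sumSquaredNorm, far above 1e-6. The method then falls back to the direct sum, which avoids the cancellation error of subtracting two nearly equal large numbers.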
6. MLlib KMeans example
1. Data
Data format: feature1 feature2 feature3, one point per line:
0.0 0.0 0.0
0.1 0.1 0.1
0.2 0.2 0.2
9.0 9.0 9.0
9.1 9.1 9.1
9.2 9.2 9.2
2. Code
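// Run in spark-shell, where `sc` (the SparkContext) is predefined.
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// 1. Read the sample data: one space-separated point per line
val data_path = "/home/jb-huangmeiling/kmeans_data.txt"
val data = sc.textFile(data_path)
val examples = data.map { line =>
Vectors.dense(line.split(' ').map(_.toDouble))
}.cache()
val numExamples = examples.count()
println(s"numExamples = $numExamples.")
// 2. Train the model
val k = 2
val maxIterations = 20
val runs = 2
val initializationMode = "k-means||"
// This train overload (with `runs`) is from the Spark 1.x API; later releases deprecated the runs parameter.
val model = KMeans.train(examples, k, maxIterations, runs, initializationMode)
// 3. Compute the cost: the sum of squared distances from each point to its nearest center
val cost = model.computeCost(examples)
println(s"Total cost = $cost.")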
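computeCost returns the sum of the squared distances from each point to its nearest center. Suppose the two centers for the six sample points end up at the cluster means (0.1, 0.1, 0.1) and (9.1, 9.1, 9.1). The cost can then be checked by hand: the four points that are not at a center each contribute 3 × 0.1² = 0.03, for a total of 0.12. The plain-Scala sketch below (hypothetical `CostSketch`, not Spark API) computes the same quantity for given centers. The actual value Spark prints depends on the centers it finds.

```scala
object CostSketch {
  def sqDist(a: Array[Double], b: Array[Double]): Double =
    a.indices.map(i => (a(i) - b(i)) * (a(i) - b(i))).sum

  // Sum over all points of the squared distance to the nearest center,
  // the same quantity KMeansModel.computeCost reports.
  def cost(points: Seq[Array[Double]], centers: Seq[Array[Double]]): Double =
    points.map(p => centers.map(c => sqDist(p, c)).min).sum
}
```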
Reference:
Spark MLlib KMeans Clustering Algorithm
Author: sunbow0