SparkML模型选择(超参数调整)与调优
Spark ML模型选择与调优
本文主要讲解如何使用Spark MLlib的工具去调优ML算法和Pipelines。内置的交叉验证和其他工具允许用户优化算法和管道中的超参数。
模型选择(又称为超参数调整)
ML中的一个重要任务是模型选择,或者使用数据来找出给定任务的最佳模型或参数。这也被称为调优。可以针对单个独立的Estimator进行调优,例如LogisticRegression,也可以针对整个Pipeline进行调优。用户可以一次针对整个pipeline进行调优,而不是单独调优pipeline内部的元素。
Mllib支持模型选择,可以使用工具CrossValidator 和TrainValidationSplit,这些工具支持下面的条目:
Estimator:需要调优的算法或者pipeline。
ParamMaps的集合:可供选择的参数,有时称为用来搜索“参数网格”
Evaluator:度量标准来衡量一个拟合Model在测试数据上的表现
在高层面上,这些模型选择工具的作用如下:
- 他们将输入数据分成单独的训练和测试数据集
- 对每个(训练,测试)对,他们迭代遍历ParamMaps集合:对于每一个ParamMap,他们使用这些参数调用Estimator的fit,得到拟合Model,并使用Evaluator评估Model的性能。
- 他们选择由产生的最佳性能参数生成的模型。
Evaluator可以是RegressionEvaluator 用于回归问题中,BinaryClassificationEvaluator 对于二分类,或MulticlassClassificationEvaluator 为多类问题。用于选择最佳值ParamMap的默认度量指标可以被evaluators的setMetricName方法覆盖。
Cross-Validation-交叉验证
CrossValidator开始的时候会将数据分割成很多测试集和训练集对儿。例如,k=3folds,crossValidator将会产生三组(training,test)数据集对儿,没对都是2/3用来训练,1/3用来测试。为了评估出一个组特殊的paramMap,crossValidator 会计算通过Estimator在三组不同数据集上调用fit产生的3个模型的平均评估指标。
确定最佳ParamMap后,CrossValidator最后使用最佳ParamMap和整个数据集重新拟合Estimator。
例子
以下示例演示如何使用CrossValidator从参数网格中进行选择。
请注意,参数网格上的交叉验证非常耗性能的。例如,在下面的例子中,参数网格中hashingTF.numFeatures有三个值,并且lr.regParam两个值,CrossValidator使用了2folds。将会倍增到(3×2)×2=12模型需要训练。在现实的设置中,尝试更多的参数并且使用更多的folds(k=3,k=10是非常常见的)。换句话说使用交叉验证代价是非常大的。然而,它也是一个比较合理的方法,用于选择比启发式手调整更具统计稳健性的参数。
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.Row
//准备训练数据,格式(id,text,label)
val training = spark.createDataFrame(Seq(
(0L, "a b c d e spark", 1.0),
(1L, "b d", 0.0),
(2L, "spark f g h", 1.0),
(3L, "hadoop mapreduce", 0.0),
(4L, "b spark who", 1.0),
(5L, "g d a y", 0.0),
(6L, "spark fly", 1.0),
(7L, "was mapreduce", 0.0),
(8L, "e spark program", 1.0),
(9L, "a e c l", 0.0),
(10L, "spark compile", 1.0),
(11L, "hadoop software", 0.0)
)).toDF("id", "text", "label")
//配置一个ML pipeline,总共有三个stages:tokenizer, hashingTF, and lr
val tokenizer = new Tokenizer()
.setInputCol("text")
.setOutputCol("words")
val hashingTF = new HashingTF()
.setInputCol(tokenizer.getOutputCol)
.setOutputCol("features")
val lr = new LogisticRegression()
.setMaxIter(10)//输入label,features,prediction均可采用默认值名称。
val pipeline = new Pipeline()
.setStages(Array(tokenizer, hashingTF, lr))
//用ParamGridBuilder构建一个查询用的参数网格
//hashingTF.numFeatures有三个值,lr.regParam有两个值,
//该网格将会有3*2=6组参数被CrossValidator使用
val paramGrid = new ParamGridBuilder()
.addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
.addGrid(lr.regParam, Array(0.1, 0.01))
.build()
//这里对将整个PipeLine视为一个Estimator
//这种方式允许我们联合选择这个Pipeline stages参数
//一个CrossValidator需要一个Estimator,一组Estimator ParamMaps,一个Evaluator。
//这个Evaluator是一个BinaryClassificationEvaluator,它默认度量是areaUnderROC
val cv = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(new BinaryClassificationEvaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(2) // 生产中使用3+
// 运行交叉验证,选择最佳参数
val cvModel = cv.fit(training)
//准备测试文档,这些文档是未打标签的
val test = spark.createDataFrame(Seq(
(4L, "spark i j k"),
(5L, "l m n"),
(6L, "mapreduce spark"),
(7L, "apache hadoop")
)).toDF("id", "text")
//使用训练好的最佳模型,去对测试集进行预测。
cvModel.transform(test)
.select("id", "text", "probability", "prediction")
.collect()
.foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
println(s"($id, $text) --> prob=$prob, prediction=$prediction")
}
查看预测结果
TrainValidationSplit
除了CrossValidator,spark还提供了TrainValidationSplit用于超参数的调整。TrainValidationSplit只对一次参数的每个组合进行一次评估,与CrossValidator的k词调整相对。真就意味着代价相对少了一些,当训练集不是很大的时候,将不会产生一个可靠的结果。
不像CrossValidator,TrainValidationSplit产生一个(training,test)数据集对。通过使用trainRatio参数将数据集分割成两个部分。例如,trainRatio=0.75, TrainValidationSplit将会产生一个训练集和一个测试集,其中75%数据用来训练,25%数据用来验证。
和CrossValidator一样, TrainValidationSplit在最后会使用最佳的参数和整个数据集对Estimator进行拟合。
例子
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
// 准测试数据
val data = spark.read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt")
val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)
val lr = new LinearRegression()
.setMaxIter(10)
// We use a ParamGridBuilder to construct a grid of parameters to search over.
// TrainValidationSplit will try all combinations of values and determine best model using
// the evaluator.
//使用ParamGridBuilder构建一个parameters网格,用来存储查询参数
//TrainValidationSplit会尝试所有值的组合使用evaluator来产生一个最佳模型
val paramGrid = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.1, 0.01))
.addGrid(lr.fitIntercept)
.addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
.build()
// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
//在这个例子中,Estimator选用简单的线性回归模型
val trainValidationSplit = new TrainValidationSplit()
.setEstimator(lr)
.setEvaluator(new RegressionEvaluator)
.setEstimatorParamMaps(paramGrid)
//80%数据用来训练,20%用来验证
.setTrainRatio(0.8)
//运行TrainValidationSplit,选出最佳参数
val model = trainValidationSplit.fit(training)
//对测试数据进行预测。参数就是刚刚训练的最佳参数。
model.transform(test)
.select("features", "label", "prediction")
.show()
- 菜单常用:复位全部并设置某个项的样式
- Mysql更换MyISAM存储引擎为Innodb的操作记录
- 比特币分叉倒计时,糖果福利又来了
- 执行git push出现"Everything up-to-date"
- linux下EOF写法梳理
- 用AngularJS来实现异步数据的购物车功能设计
- span不如div的地方
- 分布式监控系统Zabbix--完整安装记录(7)-使用percona监控MySQL
- 10x Python开发者必读:本月Python文章TOP 10
- Linux下更换默认yum源为网易yum源的操作记录
- yum源使用的几个报错小总结
- JQuery笔记(一)
- Haproxy和Nginx负载均衡测试效果对比记录
- JQuery笔记(三) jquery的用途
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 使用ES6的fetch API读取数据时要注意的一个和cookie相关的坑
- 跟牛老师一起学WEBGIS——WEBGIS基础(地图切片)
- Go语言 | 并发设计中的同步锁与waitgroup用法
- LeetCode 99 | 如何不用递归遍历二叉搜索树?MT方法给你答案
- 以攻击者角度学习某风控设备指纹产品
- 高并发系统三大利器之缓存
- 前端测试题:(解析)js中关于类(class)的继承的说法,下面错误的是?
- 程序员深夜惨遭老婆鄙视,原因竟是CAS原理太简单?| 每一张图都力求精美
- MySQL数据延迟跳动的问题分析
- Python GUI项目实战(八)修改密码功能的实现
- Prometheus监控神器-Alertmanager篇(3)
- Prometheus监控神器-Alertmanager篇(4)
- 71-STM32+ESP8266+AIR202基本控制篇-移植使用-移植微信小程序MQTT底层包到自己的工程项目
- 目标检测 | Anchor free之CornerNet网络深度解析
- 手把手教你 3 分钟搞定个人网站 http 免费升级到 https