MLlib中的随机森林和提升方法
本帖是与来自于Origami Logic 的Manish Amd共同撰写的。
Apache Spark 1.2将随机森林和梯度提升树(GBT)引入到MLlib中。这两个算法适用于分类和回归,是最成功的且被广泛部署的机器学习方法之一。随机森林和GBT是两类集成学习算法,它们结合了多个决策树,以生成更强大的模型。在这篇文章中,我们将描述这些模型和它们在MLlib中的分布式实现。我们还展示了一些简单的例子,并提供了一些我们该如何开始学习的建议。
集成方法
简而言之,集成学习算法通过组合不同的模型,是建立在其他机器学习方法之上的算法。这种组合可以比任意的单个模型更加强大且准确。
在MLlib 1.2中,我们使用决策树作为基础模型。我们提供了两种集成方法:随机森林和梯度提升树(GBT)。这两种算法的主要区别在于集成模型中每个树部件的训练顺序。
随机森林使用数据的随机样本独立地训练每棵树。这种随机性有助于使模型比单个决策树更健壮,而且不太可能会在训练数据上过拟合。
GBT(梯度提升树)每次只训练一棵树,每棵新树帮助纠正先前训练过的树所产生的错误。随着每一棵新树的加入,模型变得更加具有表现力。
最后,这两种方法都会产生一个决策树的加权集合。集成模型通过结合所有单个树的结果进行预测。下图显示了一个采用三棵树进行集成的简单例子。
在上面的集成回归的例子中,每棵树都预测了一个实值。然后将这三个预测结合起来获得集成模型的最终预测。在这里,我们使用均值来将结合不同的预测值(但具体的算法设计时,需要根据预测任务的特点来使用不同的技术)。
分布式集成学习
在MLlib中,随机森林和GBT(梯度提升树)通过实例(行)来对数据进行划分。该实现建立在最初的决策树代码之上,该代码实现了单个决策树的学习(在较早的博客文章中进行了描述)。我们的许多优化都基于Google的PLANET项目,这是发表过的、在分布式环境下进行决策树集成学习的主要作品之一。
随机森林:由于随机森林中的每棵树都是独立训练的,所以可以并行地训练多棵树(作为并行化训练单颗树的补充)。MLlib正是这样做的:并行地训练可变数目的子树,这里的子树的数目根据内存约束在每次迭代中都进行优化。
GBT:由于GBT(梯度提升树)必须一次训练一棵树,所以训练只在单颗树的水平上进行并行化。
我们想强调在MLlib中使用的两个关键优化:
- 内存:随机森林使用不同的数据子样本来训练每棵树。我们不使用显式复制数据,而是使用TreePoint结构来保存内存信息,该结构存储每个子样本中每个实例的副本数量。
- 通信:在决策树中的每个决策节点,决策树通常是通过从所有特征中选择部分特征来进行训练的,随机森林经常在每个节点将特征的选择限制在某个随机子集上。MLlib的实现利用了这种二次采样的优点来减少通信开销:例如,如果在每个节点只使用1/3的特征,那么我们可以将通信减少到原来的1/3。
更多的详细信息,请参见“MLlib编程指南”中的“集成”部分。
使用MLlib集成
我们演示如何使用MLlib来学习集成模型。以下Scala示例展示了如何读取数据集、将数据拆分为训练集和测试集、学习模型、打印模型和测试其精度。有关Java和Python中的示例,请参阅MLlib编程指南。请注意,GBT(梯度提升树)还没有Python API,但我们预计它将在Spark 1.3的发行版中出现(通过Github PR 3951)。
随机森林示例
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.util.MLUtils
// 加载并解析数据文件。
val data =
MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// 将数据拆分为训练/测试集
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// 训练随机森林模型。
val treeStrategy = Strategy.defaultStrategy("Classification")
val numTrees = 3 // 在实际中使用更多的numTrees
val featureSubsetStrategy = "auto" // 让算法进行选择。
val model = RandomForest.trainClassifier(trainingData,
treeStrategy, numTrees, featureSubsetStrategy, seed = 12345)
// 在测试实例上评估模型并计算测试错误
val testErr = testData.map { point =>
val prediction = model.predict(point.features)
if (point.label == prediction) 1.0 else 0.0
}.mean()
println("Test Error = " + testErr)
println("Learned Random Forest:n" + model.toDebugString)
梯度提升树示例
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.util.MLUtils
// 加载并解析数据文件。
val data =
MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// 将数据拆分为训练/测试集
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))
// 训练梯度提升树模型。
val boostingStrategy =
BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 3 // 注意: 在实际中使用更多的numIterations
val model =
GradientBoostedTrees.train(trainingData, boostingStrategy)
// 在测试实例上评估模型并计算测试错误
val testErr = testData.map { point =>
val prediction = model.predict(point.features)
if (point.label == prediction) 1.0 else 0.0
}.mean()
println("Test Error = " + testErr)
println("Learned GBT model:n" + model.toDebugString)
可扩展性
我们利用一些关于二元分类问题的实证结果展示了MLlib集成学习的可扩展性。下面的每张图比较了梯度增强树("GBT")和随机森林("RF"),这些图中的树被构建到不同的最大深度。
这些测试是在一个根据音频特征来预测歌曲发行日期的回归任务上进行的(特征来自UCI(加州大学尔湾分校)的ML(机器学习)库的YearPredictionMSD数据集)。我们使用EC2 r3.2xlarge机器。除另有说明外,算法参数保持为默认值。
扩展模型大小:训练时间和测试错误
下面的两幅图显示了增加集成模型中树的数量时的效果。对于两者而言,增加树的个数需要更长的时间来学习(第一张图),但在测试时的均方误差(MSE)上却获得了更好的结果(第二张图)。
这两种方法相比较,随机森林训练速度更快,但是他们通常比GBT(梯度提升树)需要训练更深的树来达到相同的误差。GBT(梯度提升树)可以进一步减少每次迭代的误差,但是经过多次迭代后,他们可能开始过拟合(即增加了测试的误差)。随机森林不容易过拟合,但他们的测试错误趋于平稳,无法进一步降低。
为了解MSE均方误差的基础,以下请注意,最左边的点显示了使用单个决策树时的错误率(深度分别为2、5或10)。
详情:463715个训练实例,16个工作节点。
扩展训练数据集大小:训练时间和测试错误
接下来的两张图片显示了使用更大的训练数据集时的效果。在有更多的数据时,这两种方法都需要更长时间的训练,但取得了更好的测试结果。
详细信息:16个工作节点。
强大的扩展:利用更多的工作节点完成更快的训练
最后这张图显示了使用更大的计算集群来解决同一个问题时的效果。使用更多的工作节点时,这两种方法都会变快很多。例如,利用深度为2的树进行GBT(梯度提升树)集成训练时,在16个工作节点上训练的速度比在2个工作节点上快4.7倍,较大的数据集能够产生更大倍数的加速。
详情:有463715个训练实例。
下一步有什么?
GBT将很快包含有一个Python API。未来发展的另一个重点是可插拔性:集成方法几乎可以应用在任何分类或回归算法上,而不仅仅是决策树。由Spark 1.2中实验性spark.ml包引入的管道 API 将使我们能够将集成学习方法拓展为真正可插拔的算法。
要开始自己使用决策树,请下载Spark 1.2!
进一步阅读
致谢
MLlib集成学习算法是由本文的作者李奇平(阿里巴巴)、宋钟(Alpine数据实验室)和Davies·刘(Databricks)合作开发的。我们也感谢Lee Yang,Andrew Feng和Hirakendu Das(雅虎)在设计和测试方面的帮助。我们也欢迎您的贡献!
- Jrebel6.3.3破解,配置图文教程
- Spring Cloud(十一)高可用的分布式配置中心 Spring Cloud Bus 消息总线集成(RabbitMQ)
- Keras中带LSTM的多变量时间序列预测
- Spring Cloud(十)高可用的分布式配置中心 Spring Cloud Config 中使用 Refresh
- Hibernate 的性能优化的时候碰到了"抓取策略",有四种
- 基于 Spring Cloud 完整的微服务架构实战
- maven build时报错Failed to execute goal org.apache.maven.plugins:maven-surefire-plugin:2.12.4:test
- Spring Cloud(九)高可用的分布式配置中心 Spring Cloud Config 集成 Eureka 服务
- Spring Cloud(八)高可用的分布式配置中心 Spring Cloud Config
- 用Raspberry Pi Zero打造「即插即用」的Web服务器
- Spring Cloud(七)服务网关 Zuul Filter 使用
- 基于Metronic的Bootstrap开发框架经验总结(1)-框架总览及菜单模块的处理
- Spring Cloud(六)服务网关 zuul 快速入门
- Docker Registry Server 搭建,配置免费HTTPS证书,及拥有权限认证、TLS 的私有仓库
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- CentOS8更换yum源后出现同步仓库缓存失败的问题
- log4j配置方式
- 基于MHA搭建MySQL Replication集群高可用架构
- PyQt5 技巧篇-解决相对路径无法加载图片问题,styleSheet通过"相对"路径加载图片,python获取当前运行文件的绝对路径。
- 基于MMM搭建MySQL Replication集群高可用架构
- Python 技术篇-按任意格式灵活获取日期、时间、年月日、时分秒。日期格式化。
- 当删库时如何避免跑路
- Python 句法错误:"SyntaxError: invalid character in identifier",原因及解决方法
- Python3 多线程问题:ModuleNotFoundError: No module named 'thread',原因及解决办法。
- 文件传输和秒传
- 关于数据库的各种备份与还原姿势详解
- Python 技术篇-多线程的2种创建方法,多线程的简单用法,快速上手。
- Python 技术篇-调用浏览器访问指定网页,一行代码实现。非Selenium。
- 数据库热备份神器 - XtraBackup
- Python 技术篇-读取文件,将内容保存dict字典中。去掉字符串中的指定字符方法。dict字典的遍历。