Apache Spark 2.0预览:机器学习模型持久性
介绍
机器学习(ML)的应用场景:
- 数据科学家生成一个ML模型,并让工程团队将其部署在生产环境中。
- 每个数据引擎集成一个Python模型训练集和一个Java模型服务集。
- 数据科学家创任务去训练各种ML模型,然后将它们保存并进行评估。
以上所有应用场景在模型持久性、保存和加载模型的能力方面都更为容易。随着Apache Spark 2.0即将发布,Spark的机器学习库MLlib将在DataFrame-based的API中对ML提供长期的近乎完整的支持。本博客给出了关于它的早期概述、代码示例以及MLlib的持久性API的一些细节。
ML持久性的关键特征包括:
- 支持所有Spark API中使用的语言:Scala,Java,Python&R
- 支持几乎所有的DataFrame-based的API中的ML算法
- 支持单个模型和完整的Pipelines,包括非适应(a recipe)和适应(a result)
- 使用可交换格式的分布式存储
感谢所有帮助MLlib实现飞跃的社区贡献者!参阅JIRA获取Scala / Java,Python和R贡献者的完整名单。
学习API
在Apache Spark 2.0中,MLlib的DataFrame-based的API在Spark上占据了ML的重要地位(请参阅曾经的博客文章获取针对此API的介绍以及它所介绍的“Pipelines”概念)。此MLlib的DataFrame-based的API提供了用于保存和加载模拟相似的Spark Data Source API模型的功能。
我们将用多种编程语言演示保存和加载模型,使用流行的MNIST数据集进行手写数字识别(LeCun et al., 1998; 可从LibSVM数据集页面获得)。该数据集包含手写数字0-9,以及地面实况标签。几个例子:
我们的目标是通过拍摄手写的数字然后识别图像中的数字。点击笔记获取完整的加载数据、填充模型、保存和加载它们的完整示例代码。
保存和加载单个模型
我们首先给出如何保存和加载单个模型以在语言之间共享。我们使用Python语言填充Random Forest Classifier并保存,然后使用Scala语言加载这个模型。
training = sqlContext.read… # data: features, label
rf = RandomForestClassifier(numTrees=20)
model = rf.fit(training)
我们可以调用save
方法来轻松地保存这个模型,调用load
方法来加载模型:
model.save("myModelPath")
sameModel = RandomForestClassificationModel.load("myModelPath")
我们还可以加载模型(之前使用Python语言保存的)到一个Scala应用或者一个Java应用中:
// Load the model in Scala
val sameModel = RandomForestClassificationModel.load("myModelPath")
这种用法适用于小型的局部模型,例如K-Means模型(用于聚类),也适用于大型分布式模型,如ALS模型(推荐使用的场景)。因为加载到的模型具有相同的参数和数据,所以即使模型部署在完全不同的Spark上也会返回相同的预测结果。
保存和加载完整的Pipelines
我们目前只讨论了保存和加载单个ML模型。在实际应用中,ML工作流程包括许多阶段,从特征提取及转换到模型的拟合和调整。MLlib提供Pipelines来帮助用户构建这些工作流程。(点击笔记获取使用ML Pipelines分析共享自行车数据集的教程。)
MLlib允许用户保存和加载整个Pipelines。我们来看一个在Pipeline上完成这些步骤的例子:
- 特征提取:二进制转换器将图像转换为黑白图像
- 模型拟合:Random Forest Classifier拍摄图像并预测数字0-9
- 调整:交叉验证以调整森林中树木的深度
这是我们的笔记中生成这个管道的一个部分代码:
// Construct the Pipeline: Binarizer + Random Forest
val pipeline = new Pipeline().setStages(Array(binarizer, rf))
// Wrap the Pipeline in CrossValidator to do model tuning.
val cv = new CrossValidator().setEstimator(pipeline) …
在我们填充这个Pipeline之前,我们将展示我们可以保存整个工作流程(在填充之前)。这个工作流程稍后可以加载到另一个在Spark集群上运行的数据集。
cv.save("myCVPath")
val sameCV = CrossValidator.load("myCVPath")
最后,我们填充Pipeline并保存,然后把它加载回来。这节省了特征提取步骤、交叉验证调整后的Random Forest模型的步骤,模型调整过程中的统计步骤。
val cvModel = cv.fit(training)
cvModel.save("myCVModelPath")
val sameCVModel = CrossValidatorModel.load("myCVModelPath")
了解详细信息
Python调整
Spark 2.0中缺少Python的调整部分。Python目前还不支持保存和加载用于调整模型超参数的CrossValidator和TrainValidationSplit, 这个问题将在Spark 2.1(SPARK-13786)中进行考虑。尽管如此,我们仍可以保存Python中的CrossValidator和TrainValidationSplit的结果。例如我们使用交叉验证来调整Random Forest,然后调整过程中找到的最佳模型并保存。
Define the workflow
rf = RandomForestClassifier()
cv = CrossValidator(estimator=rf, …)
Fit the model, running Cross-Validation
cvModel = cv.fit(trainingData)
Extract the results, i.e., the best Random Forest model
bestModel = cvModel.bestModel
Save the RandomForest model
bestModel.save("rfModelPath")
点击笔记查看完整代码。
可交换的存储格式
在内部,我们将模型元数据和参数保存为JSON和Parquet格式。这些存储格式是可交换的并且可以使用其他库进行读取。我们能够使用Parquet 存储小模型(如朴素贝叶斯分类)和大型分布式模型(如推荐的ALS)。存储路径可以是任何URI支持的可以进行保存和加载的Dataset / DataFrame,还包括S3、本地存储等路径。
语言交叉兼容性
模型可以在Scala、Java和Python中轻松地进行保存和加载。R语言有两个限制,首先,R并非支持全部的MLlib模型,所以并不是所有使用其他语言训练过的模型都可以使用R语言加载。第二,R语言模型的格式还存储了额外数据,所以用其他语言加载使用R语言训练和保存后的模型有些困难(供参考的笔记本)。在不久的将来R语言将会有更好的跨语言支持。
总结
随着即将到来的2.0版本的发布,DataFrame-based的MLlib API将为持久化模型和Pipelines提供近乎全面的覆盖。持久性对于在团队之间共享模型、创建多语言ML工作流以及将模型转移到生产环境至关重要。准备将DataFrame-based的MLlib API变成Apache Spark中的机器学习的主要API是这项功能的最后一部分。
接下来?
高优先级的项目包括完整的持久性覆盖,包括Python模型调整算法以及R和其他语言API之间的兼容性改进。
从使用Scala和Python的教程笔记开始。您也可以只更新您当前的MLlib工作流程以使用保存和加载功能。
实验性功能:使用在Apache Spark2.0的分支(Databricks Community Edition中的测试代码)预览版中的API。加入beta版的等待名单。
阅读更多
- 阅读本博客中所有引用的代码笔记。
- 了解DataFrame-based API for MLlib & ML Pipelines:
- 介绍ML Pipelines的笔记:分析自行车共享数据集的教程
- ML Pipelines上的原始博客文章
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Flutter 项目.gitignore配置
- js和object的常见操作,持续更新中...
- 常见编程模式之快慢指针
- python pywifi模块——暴力破解wifi
- 面试题系列第3篇:Integer等号判断的内幕,你可能不知道?
- Go by Example 中文:工作池
- 推荐一款万能抓包神器:Fiddler Everywhere
- 猿实战04——el-upload结合nginx之通用图片处理
- 30 多个有内味道且笑死的人代码注释
- Logstash-input-jdbc 同步 mysql 准实时数据至 ElasticSearch 搜索引擎
- 总结一些,我在书写 CSS 的时候,经常犯的错误!
- 通俗理解 set,dict 背后的哈希表
- K8S 生态周报| Google 选择 Cilium 作为 GKE 下一代数据面
- [Introduction]万字手撕Go http源码server.go
- Python 3.9 值得关注的更新点