浅谈机器学习模型推理性能优化
在机器学习领域,清晰明了的数据预处理和表现优异的模型往往是数据科学家关注的重点,而实际生产中如何让模型落地、工程化也同样值得关注,工程化机器学习模型避不开的一个难点就是模型的推理(Inference / Serving)性能优化。
可能许多数据科学家都对模型的推理性能比较陌生,我举几个对推理性能有强要求的场景例子:
- 在公共安全领域中,视频监控中实时的人脸识别需要有实时的展示能力方便执法人员快速定位跟踪人员。
- 在互联网应用领域中,电商网站、内容应用实时的个性化推荐要求能够快速响应,推荐的卡顿感将直接影响购物或者内容获取的体验。
- 在银行领域中,电子支付中异常交易的实时识别也至关重要,任何异常的交易需要被快速识别并拦截,而正常的交易则不能被影响。
- 在金融领域中,量化模型毫秒级的交易判断输出能帮助华尔街的交易员们套取巨额利润。
从上面的例子不难发现,其实在不同的领域的场景下,推理的性能都是模型表现之外最关注的点,在某些极端的场景,数据科学家和机器学习工程师甚至愿意牺牲一部分的模型表现来换取更高的推理性能。
在统计学和传统机器学习算法时代,推理的性能往往都能达到人们的预期,毕竟一个模型算法的计算量很有限;随着多媒体的发展和计算机性能的提升,集成学习和深度学习的模型运用越来越广泛,而因为这些模型往往由成百上千个基模型构成,所以推理的性能大幅下降,在摩尔定律已经失效的今天,这慢慢变成了许多数据科学家和机器学习工程师的眼中钉。最近在项目中刚好完成了相关的需求, 所以抛砖引玉给大家分享一下做推理优化的roadmap。
计算图优化
“提高硬件性能是优化的最后一步,而不应该是第一步。”
上面这句话是我们项目PO说的话,其在Spark性能优化上有非常丰富的经验,我非常赞同这种论点。据我观察,在遇到算法模型的训练和推理性能瓶颈的时候,大部分机器学习工程师都希望能获得更高的硬件性能来突破瓶颈,却忽略了计算逻辑本身的优化。更高性能的硬件为模型推理带来的性能提升并不是线性的,而花费的硬件成本却是指数级上升的,所以一定要记得,不到万不得已,千万不要指望硬件带来的性能提升。
基本上数据处理和算法模型都可以被抽象为计算图,而计算逻辑的优化往往在领域内被称为图优化(这里的图优化并不是指图模型的表现优化哦 :D)。
每个计算图中都包含许多计算节,图优化的目标很简单,就是简化计算图中计算节点的计算量。常用的方式分为以下几种:
- 减少节点的数量
- 用高效替换低效的节点
- 用高效子图替换低效子图
- 用并行化分支代替单分支
减少节点的数量
在构造机器学习模型的时候,我们往往会无意中对数据做了多余或者反复的操作,这类操作就像写工程代码中的code smell一样,在模型构造完成之后一定要对这种操作多加注意。拿矩阵的转置(transpose)做例子(实际上多余、反复转置是非常常见的):
def func(a, b): a_T = a.transpose(1, 0) c = a_T + b return c.transpose(1, 0) def func_better(a, b): return a + b.transpose(1, 0)
可以清楚的看到,第一种实现多了一个转置的操作,而一个多余的转置多多少少会对性能产生影响。
In [10]: a = torch.randn(10000, 50000)
In [11]: b = torch.randn(50000, 10000)
In [12]: %timeit func(a, b)
9.67 s ± 827 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [13]: %timeit func_better(a, b)
8.44 s ± 233 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
用高效的节点替换低效的节点
相同的计算节点往往有多种实现方式,而这些实现方式中往往都有各自的优劣势,有一些是牺牲了空间换取时间,而有些是牺牲了时间换取了空间,如果是考虑推理的响应性能,那么我们往往会用时间有优势的实现方式来替代时间没有优势的节点。
例如,机器学习模型往往都可以看成是向量化数据运算,所以工程化的时候时刻记得使用向量化的运算,而不是使用朴素的loop。下面我就用计算矩阵对角线元素之和作为一个例子:
def func(a): result = 0 for i in range(a.shape[0]): result += a[i][i] return result
def func_better(a): eye = torch.eye(a.shape[0], dtype=torch.bool) return a[eye].sum()
In [17]: a = torch.randn(1000, 1000)
In [23]: %timeit func(a)6.75 ms ± 72 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [24]: %timeit func_better(a)2.48 ms ± 40.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each
用高效子图替换低效子图
这个优化点和上一个优化点很相似,只不过我们关注的点不在于某一个小的计算节点,而是从宏观上去关注节点组的优化,即子图的优化。下面这个例子是对变长向量组的求和计算,其中a表示把一堆向量拍平后的一维数组,而row_idxs是表示a数组中对应的值是属于第几个向量(即第几行),通过把a_original的表示方式换成a加一个描述每行向量的长度的数组row_lens可以减少a[row_idxs==i]这一步的判读,达到一样的效果。
# Example:# a_original -> [[1,2,3], [4,5,6], [7,8]]# a -> [1, 2, 3, 4, 5, 6, 7, 8]# row_idx -> [0, 0, 0, 1, 1, 1, 2, 2]# row_lens -> [3, 3, 2]
def func(a, row_idxs): result = [] for i in range(row_idxs.max() + 1): result.append(a[row_idxs==i].sum()) return torch.tensor(result, dtype=torch.float32)
def func_better(a, row_lens): result = [] for each in a.split(row_lens.tolist()): result.append(each.sum()) return torch.tensor(result, dtype=torch.float32)
In [66]: a = torch.randn(100000)
In [67]: row_idxs = torch.tensor([0] * 30000 + [1] * 50000 + [2] * 20000, dtype=torch.int32)
In [68]: row_lens = torch.tensor([30000, 50000, 20000], dtype=torch.int32)
In [69]: %timeit func(a, row_idxs)1.75 ms ± 9.35 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [70]: %timeit func_better(a, row_lens)43.4 µs ± 412 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each
用并行化分支代替单分支
在许多图计算框架里面,并行加速分为两种,一种是算子内部的并行化(intra)、一种是图分支的并行化(inter);例如在ONNX中,一个for循环算子是无法得到并行优化的,因为其维护了一个状态变量i,而往往我们并不会使用到这个i,我们只是想让某个计算逻辑执行n遍。这个时候就可以将这个for循环算子,拆分成n个计算分支,这样在使用图分支并行化计算的时候,就可以充分利用硬件资源提高计算效率了。
Backend优化
计算图优化是第一步也是最重要的一步,那么在计算图优化和硬件优化之间,难道就没有其他优化方式了吗?答案是有的,这个方式就是计算图引擎Backend的优化。计算图只是一个计算逻辑的抽象表示,而真正执行计算图的引擎也会有不同的实现,而每种实现往往带来的都是不同的计算效率。比如在PyTorch模型的Inference的可选项里面,有以下几种计算图引擎后端可供选择:
- 原生PyTorch API
- TorchScript Python API
- LibTorch
- ONNX Runtime
原生PyTorch API
原生的PyTorch API其实不用过多的描述,就是执行PyTorch模型中的forward函数,直接得到推理结果。可以说这个API是最简单、最原生的方式,可以作为推理性能表现的的一个BaseLine。这个API推理过程和训练保持一致,可以保障结果的正确性,可以作为其他backend正确性检验的一个对照。
TorchScript Python API
TorchScript是一种PyTorch模型的表示格式,相当于PyTorch Python API的子集构建的子语言,其能够被TorchScript编译器实时编译成C++的模型代码并执行。这种格式有三个主要的设计初衷:
- 构建一种跨环境序列化模型的方式
- 基于Torch基本算子,并可扩展的算子集
- 可以在C++程序中实时执行
通过torch.jit.script的API,可以将一个Python模型转换为TorchScript模型,并通过torch.jit.save保存为.pt格式的TorchScript模型。注意:转换成TorchScript的PyTorch模型要求用一定的规范编写,可以参考官方的文档:https://pytorch.org/tutorials/beginner/IntrotoTorchScript_tutorial.html?highlight=torchscript
对于TorchScript,PyTorch是有Python的API支持的,通过torch.jit.load可以读取一个.pt的模型文件,并执行forward函数即可进行推理服务。然而,实际测试发现,这种推理性能与PyTorch原生的API的性能是较为接近的,仅仅是在稳定性有小幅度的领先。
LibTorch
既然TorchScript Python的API那么弱,那我们就来试试C++的API吧!TorchScript的C++API是LibTorch, 使用LibTorch编译后的推理性能无论是速度和稳定性能有显著的提升。官方提供了一个简单的例子进行参考https://pytorch.org/tutorials/advanced/cpp_export.html。
有个小疑问:实际在客户现场的Linux服务器上,LibTorch的表现稳定性相当差,而在我自己的MacBook上是很稳定的,不清楚是什么原因。我怀疑和编译器或者基础库有关
ONNX Runtime
有人会问了,既然那么好,为啥不能成为唯一的选项?答案是LibTorch是C++的库,对编译和环境的依赖比较严重,并且对C++编程水平的要求会比较高。所以接下来我要讲讲我们最后选用的方案———ONNX Runtime。
ONNX(Open Neural Network Exchange)是一种针对机器学习所设计的开放式的文件格式,用于存储训练好的模型。它使得不同的人工智能框架(如Pytorch, MXNet)可以采用相同格式存储模型数据并交互。ONNX的规范及代码主要由微软,亚马逊 ,Facebook 和 IBM 等公司共同开发,以开放源代码的方式托管在Github上。目前官方支持加载ONNX模型并进行推理的深度学习框架有:Caffe2, PyTorch, MXNet,ML.NET,TensorRT 和 Microsoft CNTK,并且 TensorFlow 也非官方的支持ONNX。——Wikipedia
针对这种通用的交换格式的模型,微软牵头发起了ONNX Runtime的项目 这个项目旨在直接运行ONNX Runtime,相当于纯的模型推理Backend,其设计的理念就是为了解决训练和推理的性能问题,并且支持各种硬件加速库加速(如:MKL、CUDA、TensorRT等等)。除此之外,ONNX Runtime还有Python、C++、JAVA等多种接口;甚至提供了直接用于Serving的程序,暴露了HTTP2.0和GRPC的接口,用起来非常方便。
PyTorch提供了模型转换为ONNX模型的接口torch.onnx.export,通过这个接口我们就可以将模型转换为ONNX模型在Runtime中进行推理了。通过测试,我们发现ONNX Runtime在推理的速度和稳定性上都是相当优秀的。略有些遗憾的是,PyTorch中有些比较酷炫的算子ONNX并不支持,不过ONNX才刚刚兴起,相信之后一定会加入更多好用的算子的。
性能对比
- 机器:MacBook Pro (Retina, 15-inch, Mid 2015)
- CPU:2.2 GHz Intel Core i7
- 内存:16GB 1600 MHz DDR3
- device:CPU
- 模型:SqueezeNet
- 数据:(5, 3, 64, 64)
- Loop: 1024
最后
除了上述优化方式以外,其实还有很多我没有了解到的优化方法等待大家去探索和探讨。实际我们在项目中还有一部分计算图是脱离了计算引擎手撸出来的,因为涉及到信息安全问题就不在这里展开讲了。这篇博客只是抛砖引玉,希望能帮助到以后有内容的项目。我对这方面也了解不够,如果哪里说的不对欢迎大家指正~(ps. 这是我时隔一年发的第一篇博客,文笔写的不是很好大家不要介意哈哈!)
- Linux下部署SSH登录时的二次身份验证环境记录(利用Google Authenticator)
- Linux下DNS简单部署(主从域名服务器)
- 本地yum源部署记录
- silverlight4:摄像头占用状态检测以及二种截屏方法
- Flash/Flex学习笔记(16):如何做自定义Loading加载其它swf
- 台胞也能发红包喽!小编手把手教你搞定微信支付!
- 获取可视区域高度赋值给div(解决document.body.clientHeight的返回值为0的问题)
- Docker管理工具-Swarm部署记录
- 聊一聊大数据的问题和缺陷
- Flash/Flex学习笔记(13):对象拖动(startDrag/stopDrag)
- 原来Silverlight 4中是可以玩UDP的!
- Flash/Flex学习笔记(12):FMS 3.5之如何做视频实时直播
- Flash/Flex学习笔记(11):如何检测摄像头是否被占用
- Flash/Flex学习笔记(10):FMS 3.5之Hello World!
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 搭建K8S集群之node节点部署
- ent orm笔记2---schema使用(上)
- ent orm笔记4---Code Generation
- 什么?明明是2020年12月30日显示2021年12月30日?
- JDK1.8HashMap源码学习-数据结构
- JDK1.8HashMap源码学习-初始化
- JDK1.8HashMap源码学习-put操作以及扩容(一)
- 数据科学家极力推荐核心计算工具-Numpy的前世今生(上)
- 什么是运维眼中可部署的软件架构
- 2020-09-03:裸写算法:回形矩阵遍历。
- Java并发编程系列34 | 深入理解线程池(下)
- MySQL 8.0新特性 — 密码管理
- 聊聊claudb的NotificationManager
- windows下安装Postman
- 【Pytorch 】笔记七:优化器源码解析和学习率调整策略