tensorflow 模型浮点数计算量和参数量估计
TensorFlow 模型浮点数计算量和参数量统计
2018-08-28
本博文整理了如何对一个 TensorFlow 模型的浮点数计算量(FLOPs)和参数量进行统计。
stats_graph.py
import tensorflow as tf
def stats_graph(graph):
flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
print('FLOPs: {}; Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))
利用高斯分布对变量进行初始化会耗费一定的 FLOP
C[25,9]=A[25,16]B[16,9] FLOPs=(16+15)×(25×9)=6975FLOPs(inTFstyle)=(16+16)×(25×9)=7200total_parameters=25×16+16×9=544
with tf.Graph().as_default() as graph:
A = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(25, 16), name='A')
B = tf.get_variable(initializer=tf.random_normal_initializer(dtype=tf.float32), shape=(16, 9), name='B')
C = tf.matmul(A, B, name='ouput')
stats_graph(graph)
输出为:
FLOPs: 8288; Trainable params: 544
利用常量初始化器对变量进行初始化不会耗费 FLOP
with tf.Graph().as_default() as graph:
A = tf.get_variable(initializer=tf.constant_initializer(value=1, dtype=tf.float32), shape=(25, 16), name='A')
B = tf.get_variable(initializer=tf.zeros_initializer(dtype=tf.float32), shape=(16, 9), name='B')
C = tf.matmul(A, B, name='ouput')
stats_graph(graph)
输出为:
FLOPs: 7200; Trainable params: 544
Frozen graph
通常我们对耗费在初始化上的 FLOPs 并不感兴趣,因为它是发生在训练过程之前且是一次性的,我们感兴趣的是模型部署之后在生产环境下的 FLOPs。我们可以通过 Freeze 计算图的方式得到除去初始化 FLOPs 的、模型部署后推断过程中耗费的 FLOPs。
from tensorflow.python.framework import graph_util
def load_pb(pb):
with tf.gfile.GFile(pb, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name='')
return graph
with tf.Graph().as_default() as graph:
# ***** (1) Create Graph *****
A = tf.Variable(initial_value=tf.random_normal([25, 16]))
B = tf.Variable(initial_value=tf.random_normal([16, 9]))
C = tf.matmul(A, B, name='output')
print('stats before freezing')
stats_graph(graph)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# ***** (2) freeze graph *****
output_graph = graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['output'])
with tf.gfile.GFile('graph.pb', "wb") as f:
f.write(output_graph.SerializeToString())
# ***** (3) Load frozen graph *****
graph = load_pb('./graph.pb')
print('stats after freezing')
stats_graph(graph)
输出为:
stats before freezing
FLOPs: 8288; Trainable params: 544
INFO:tensorflow:Froze 2 variables.
INFO:tensorflow:Converted 2 variables to const ops.
stats after freezing
FLOPs: 7200; Trainable params: 0
与 Keras 的结合
from keras import backend as K
from keras.layers import Dense
from keras.models import Sequential
from keras.initializers import Constant
model = Sequential()
model.add(Dense(32, input_dim=4, bias_initializer=Constant(value=0), kernel_initializer=Constant(value=1)))
sess = K.get_session()
graph = sess.graph
stats_graph(graph)
输出为:
FLOPs: 0; Trainable params: 160
Using TensorFlow backend.
2 ops no flops stats due to incomplete shapes.
2 ops no flops stats due to incomplete shapes.
model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 32) 160
=================================================================
Total params: 160
Trainable params: 160
Non-trainable params: 0
_________________________________________________________________
DL
About
This is Robert Lexis (FengCun Li). To see the world, things dangerous to come to, to see behind walls, to draw closer, to find each other and to feel. That is the purpose of LIFE.
Recent Posts
Static variable in inline
Iterator invalidation rul
Emplace back
Perfect forward
原文地址:https://www.cnblogs.com/o-v-o/p/11042066.html
- 鼠标滚轮事件介绍
- Understanding delete
- objC与js通信实现--WebViewJavascriptBridge
- 简单易学的机器学习算法——岭回归(Ridge Regression)
- QQ空间(日志、说说、个人信息)python爬虫源码(一天可抓取 400 万条数据)
- 文本分类实战: 机器学习vs深度学习算法对比(附代码)
- ReactJS分析之入口函数render
- 简单易学的机器学习算法——SVD奇异值分解
- AngularJS源码分析之依赖注入$injector
- 使用yield进行异步流程控制
- 【Java提高十七】Set接口集合详解
- 如何科学地蹭热点:用python爬虫获取热门微博评论并进行情感分析
- 使用ETag进行session的降级
- 关于oracle中的反连接(r3笔记第95天)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Java|屏幕截图
- VBA解压缩ZIP文件07——length和distance扩展
- Excel VBA常用功能加载宏——打开活动工作簿所在文件夹
- 常用功能加载宏——拆分工作表
- MyVBA加载宏——添加自定义菜单04——功能实现
- CS学习笔记 | 14、powerup提权的方法
- VBA解压缩ZIP文件05——Huffman树
- JavaScript|jQuery基础语法
- VBA解压缩ZIP文件03——解压准备工作
- VBA解压缩ZIP文件04——解析ZIP文件结构
- 开发|Springboot简单实现文件上传
- VBA解压缩ZIP文件01——实现的功能
- MyVBA加载宏——添加自定义菜单03——功能分析
- MyVBA加载宏——添加自定义菜单02——给按钮添加单击事件
- 科研猫小课堂:敲黑板!竞争风险模型应该如何分析?