Tensorflow:模型变量保存
时间:2022-07-23
本文章向大家介绍Tensorflow:模型变量保存,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
参考文献Tensorflow 实战 Google 深度学习框架[1]实验平台: Tensorflow1.4.0 python3.5.0
Tensorflow 常用保存模型方法
import tensorflow as tf
saver = tf.train.Saver() # 创建保存器
with tf.Session() as sess:
saver.save(sess,"/path/model.ckpt") #保存模型到相应ckpt文件
saver.restore(sess,"/path/model.ckpt") #从相应ckpt文件中恢复模型变量
- 使用 tf.train.Saver 会保存运行 Tensorflow 程序所需要的全部信息,然而有时并不需要某些信息。比如在测试或离线预测时,只需要知道如何从神经网络的输入层经过前向传播计算得到输出层即可,而不需要类似的变量初始化,模型保存等辅助节点的信息。Tensorflow 提供了 convert_varibales_to_constants 函数,通过这个函数可以将计算图中的变量及其取值通过常量的方式保存,这样整个 Tensorflow 计算图可以统一存放在一个文件中。
将变量取值保存为 pb 文件
# pb文件保存方法
import tensorflow as tf
from tensorflow.python.framework import graph_util
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op) # 初始化所有变量
# 导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程
graph_def = tf.get_default_graph().as_graph_def()
# 将需要保存的add节点名称传入参数中,表示将所需的变量转化为常量保存下来。
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
# 将导出的模型存入文件中
with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
f.write(output_graph_def.SerializeToString())
# 2. 加载pb文件。
from tensorflow.python.platform import gfile
with tf.Session() as sess:
model_filename = "Saved_model/combined_model.pb"
# 读取保存的模型文件,并将其解析成对应的GraphDef Protocol Buffer
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# 将graph_def中保存的图加载到当前图中,其中保存的时候保存的是计算节点的名称,为add
# 但是读取时使用的是张量的名称所以是add:0
result = tf.import_graph_def(graph_def, return_elements=["add:0"])
print(sess.run(result))
# Converted 2 variables to const ops.
# [array([3.], dtype=float32)]
参考资料
[1]Tensorflow实战Google深度学习框架: https://github.com/caicloud/tensorflow-tutorial/tree/master/Deep_Learning_with_TensorFlow/1.4.0
- two Pass方法连通域检测
- 【Java入门提高篇】Day14 Java中的泛型初探
- 使用shell脚本快速得到主备关系(r9笔记第93天)
- 【Java入门提高篇】Day13 Java中的反射机制
- 仿腾讯课堂固定滚动列表ReactNative组件
- Golang通过socket与java通讯
- Java基础-day09-基础题-对象;类;封装
- 通过shell脚本得到数据库的基本信息(一)(r9笔记第89天)
- iOS设备唯一标识的前世今生
- python 生成内嵌式字典(dict)-案例从python提取内嵌json写入mongodb
- Golang语言打印九九乘法表
- AVFoundation 框架初探究(四)
- Data Guard跳归档恢复的实践(r9笔记第92天)
- AVFoundation 框架初探究(三)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- PHP实现时间日期友好显示实现代码
- AndroidStudio插件GsonFormat之Json快速转换JavaBean教程
- android studio错误: 常量字符串过长的解决方式
- Android Shader着色器/渲染器的用法解析
- PHP实现的文件浏览器功能简单示例
- Laravel中正确地返回HTTP状态码方法示例
- Android 实现抖音头像底部弹框效果的实例代码
- Android Studio修改Log信息颜色的实现
- Android 之BottomsheetDialogFragment仿抖音评论底部弹出对话框效果(实例代码)
- Yii框架的路由配置方法分析
- Android 购物车加减功能的实现代码
- Yii框架函数简单用法分析
- PHP读取XML文件的方法实例总结【DOMDocument及simplexml方法】
- 浅析PHP7 的垃圾回收机制
- android BottomSheetDialog新控件解析实现知乎评论列表效果(实例代码)