tf.get_collection()
时间:2022-07-26
本文章向大家介绍tf.get_collection(),主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
交流、咨询,有疑问欢迎添加QQ 2125364717,一起交流、一起发现问题、一起进步啊,哈哈哈哈哈
此函数有两个参数,key和scope。
Args:
- 1.key: The key for the collection. For example, the GraphKeys class contains many standard names for collections.
- 2.scope: (Optional.) If supplied, the resulting list is filtered to include only items whose name attribute matches using re.match. Items without a name attribute are never returned if a scope is supplied and the choice or re.match means that a scope without special tokens filters by prefix.
举个例子:
# 在'My-TensorFlow-tutorials-master/02 CIFAR10/cifar10.py'代码中
variables = tf.get_collection(tf.GraphKeys.VARIABLES)
for i in variables:
print(i)
>>> <tf.Variable 'conv1/weights:0' shape=(3, 3, 3, 96) dtype=float32_ref>
<tf.Variable 'conv1/biases:0' shape=(96,) dtype=float32_ref>
<tf.Variable 'conv2/weights:0' shape=(3, 3, 96, 64) dtype=float32_ref>
<tf.Variable 'conv2/biases:0' shape=(64,) dtype=float32_ref>
<tf.Variable 'local3/weights:0' shape=(16384, 384) dtype=float32_ref>
<tf.Variable 'local3/biases:0' shape=(384,) dtype=float32_ref>
<tf.Variable 'local4/weights:0' shape=(384, 192) dtype=float32_ref>
<tf.Variable 'local4/biases:0' shape=(192,) dtype=float32_ref>
<tf.Variable 'softmax_linear/softmax_linear:0' shape=(192, 10) dtype=float32_ref>
<tf.Variable 'softmax_linear/biases:0' shape=(10,) dtype=float32_ref>
tf.get_collection会列出key里所有的值。
进一步地:
tf.GraphKeys的点后可以跟很多类, 比如VARIABLES类(包含所有variables), 比如REGULARIZATION_LOSSES。
具体tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)的使用:
def easier_network(x, reg):
""" A network based on tf.contrib.learn, with input `x`. """
with tf.variable_scope('EasyNet'):
out = layers.flatten(x)
out = layers.fully_connected(out,
num_outputs=200,
weights_initializer = layers.xavier_initializer(uniform=True),
weights_regularizer = layers.l2_regularizer(scale=reg),
activation_fn = tf.nn.tanh)
out = layers.fully_connected(out,
num_outputs=200,
weights_initializer = layers.xavier_initializer(uniform=True),
weights_regularizer = layers.l2_regularizer(scale=reg),
activation_fn = tf.nn.tanh)
out = layers.fully_connected(out,
num_outputs=10, # Because there are ten digits!
weights_initializer = layers.xavier_initializer(uniform=True),
weights_regularizer = layers.l2_regularizer(scale=reg),
activation_fn = None)
return out
def main(_):
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
# Make a network with regularization
y_conv = easier_network(x, FLAGS.regu)
weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'EasyNet')
print("")
for w in weights:
shp = w.get_shape().as_list()
print("- {} shape:{} size:{}".format(w.name, shp, np.prod(shp)))
print("")
reg_ws = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, 'EasyNet')
for w in reg_ws:
shp = w.get_shape().as_list()
print("- {} shape:{} size:{}".format(w.name, shp, np.prod(shp)))
print("")
# Make the loss function `loss_fn` with regularization.
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
loss_fn = cross_entropy + tf.reduce_sum(reg_ws)
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss_fn)
main()
>>> - EasyNet/fully_connected/weights:0 shape:[784, 200] size:156800
- EasyNet/fully_connected/biases:0 shape:[200] size:200
- EasyNet/fully_connected_1/weights:0 shape:[200, 200] size:40000
- EasyNet/fully_connected_1/biases:0 shape:[200] size:200
- EasyNet/fully_connected_2/weights:0 shape:[200, 10] size:2000
- EasyNet/fully_connected_2/biases:0 shape:[10] size:10
- EasyNet/fully_connected/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0
- EasyNet/fully_connected_1/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0
- EasyNet/fully_connected_2/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0
据:
for w in reg_ws:
shp = ....
这段代码的输出可知, 在图上的所有regularization都会集中保存到tf.GraphKeys.REGULARIZATION_LOSSES去。
关于collection的详情请参见: http://blog.csdn.net/shenxiaolu1984/article/details/52815641
关于tf.GraphKeys.REGULARIZATION_LOSSES的详情参见: https://gxnotes.com/article/178205.html
各位看官老爷,如果觉得对您有用麻烦赏个子,创作不易,0.1元就行了。下面是微信乞讨码:
添加描述
添加描述
- 现在 tensorflow 和 mxnet 很火,是否还有必要学习 scikit-learn 等框架?
- 数据的标准化与中心化以及R语言中的scale详解
- Java基础19(02)总结IO流,异常try…catch,throws,File类
- HTML5 — header
- 两条报警信息的分析(第二篇)(r6笔记第71天)
- 两条报警信息的分析(第一篇) (r6笔记第70天)
- R-求y=sin(X) 0-PI 面积代码
- Facebook 发布 wav2letter 工具包,用于端到端自动语音识别
- Java企业面试——Java基础
- 从Java的类型转换看MySQL和Oracle中的隐式转换(二)(r6笔记第68天)
- R包—iGraph
- 深度学习中 GPU 和显存分析
- 数据库SQL优化大总结1之- 百万级数据库优化方案
- Golang语言社区--LollipopGO开源项目搭建商城路由分发
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 浅析Linux中使用nohup及screen运行后台任务的示例和区别
- 快速搭建简易、高效、多线程http服务器
- Linux解压文件到指定目录的方法
- Linux系统中CPU占用率较高问题排查思路与解决方法
- linux中ftp服务搭建需要注意的地方
- CentOS下使用LibreOffice实现文档格式的转换方式
- 详解在Linux中怎么使用cron计划任务
- Linux系统删除文件夹和文件的命令
- 详解Linux防火墙iptables禁IP与解封IP常用命令
- 在Ubuntu 16.04 Server上安装Zabbix的方法
- Centos7.3安装部署最新版Zabbix3.4的方法(图文)
- Linux系统下移植busybox中mkfs.vfat命令
- Linux服务器配置ip白名单防止远程登录以及端口暴露的问题
- Ubuntu上释放空间的5种简单方法
- Linux下Redis允许远程连接的实现方法