【收藏】简单易用 TensorFlow 代码集,GAN通用框架、函数
时间:2022-06-18
本文章向大家介绍【收藏】简单易用 TensorFlow 代码集,GAN通用框架、函数,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
新智元报道
来源:GitHub 作者:Junho Kim
【新智元导读】今天为大家推荐一个实用的GitHub项目:TensorFlow-Cookbook。 这是一个易用的TensorFlow代码集,包含了对GAN有用的一些通用架构和函数。
今天为大家推荐一个实用的GitHub项目:TensorFlow-Cookbook。
这是一个易用的TensorFlow代码集,作者是来自韩国的AI研究科学家Junho Kim,内容涵盖了谱归一化卷积、部分卷积、pixel shuffle、几种归一化函数、 tf-datasetAPI,等等。
作者表示,这个repo包含了对GAN有用的一些通用架构和函数。
项目正在进行中,作者将持续为其他领域添加有用的代码,目前正在添加的是 tf-Eager mode的代码。欢迎提交pull requests和issues。
Github地址 :
https://github.com/taki0112/Tensorflow-Cookbook
如何使用
Import
-
ops.py
- operations
- from ops import *
-
utils.py
- image processing
- from utils import *
Network template
def network(x, is_training=True, reuse=False, scope="network"): with tf.variable_scope(scope, reuse=reuse):
x = conv(...)
...
return logit
使用DatasetAPI向网络插入数据
Image_Data_Class = ImageData(img_size, img_ch, augment_flag)
trainA = trainA.map(Image_Data_Class.image_processing, num_parallel_calls=16)
trainA = trainA.shuffle(buffer_size=10000).prefetch(buffer_size=batch_size).batch(batch_size).repeat()
trainA_iterator = trainA.make_one_shot_iterator()
data_A = trainA_iterator.get_next()
logit = network(data_A)
- 了解更多,请阅读: https://github.com/taki0112/Tensorflow-DatasetAPI
Option
-
padding='SAME'
- pad = ceil[ (kernel - stride) / 2 ]
-
pad_type
- 'zero' or 'reflect'
-
sn
- use spectral_normalization or not
-
Ra
- use relativistic gan or not
-
loss_func
- gan
- lsgan
- hinge
- wgan
- wgan-gp
- dragan
注意
- 如果你不想共享变量,请以不同的方式设置所有作用域名称。
权重(Weight)
weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001)
weight_regularizer_fully = tf.contrib.layers.l2_regularizer(0.0001)
初始化(Initialization)
-
Xavier
: tf.contrib.layers.xavier_initializer() -
He
: tf.contrib.layers.variance_scaling_initializer() -
Normal
: tf.random_normal_initializer(mean=0.0, stddev=0.02) -
Truncated_normal
: tf.truncated_normal_initializer(mean=0.0, stddev=0.02) -
Orthogonal
: tf.orthogonal_initializer(1.0) / # if relu = sqrt(2), the others = 1.0
正则化(Regularization)
-
l2_decay
: tf.contrib.layers.l2_regularizer(0.0001) -
orthogonal_regularizer
: orthogonal_regularizer(0.0001) & orthogonal_regularizer_fully(0.0001)
卷积(Convolution)
basic conv
x = conv(x, channels=64, kernel=3, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=True, scope='conv')
partial conv (NVIDIA Partial Convolution)
x = partial_conv(x, channels=64, kernel=3, stride=2, use_bias=True, padding='SAME', sn=True, scope='partial_conv')
dilated conv
x = dilate_conv(x, channels=64, kernel=3, rate=2, use_bias=True, padding='SAME', sn=True, scope='dilate_conv')
Deconvolution
basic deconv
x = deconv(x, channels=64, kernel=3, stride=2, padding='SAME', use_bias=True, sn=True, scope='deconv')
Fully-connected
x = fully_conneted(x, units=64, use_bias=True, sn=True, scope='fully_connected')
Pixel shuffle
x = conv_pixel_shuffle_down(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_down')
x = conv_pixel_shuffle_up(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_up')
-
down
===> [height, width] -> [height // scale_factor, width // scale_factor] -
up
===> [height, width] -> [height * scale_factor, width * scale_factor]
Block
residual block
x = resblock(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block')
x = resblock_down(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_down')
x = resblock_up(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_up')
-
down
===> [height, width] -> [height // 2, width // 2] -
up
===> [height, width] -> [height * 2, width * 2]
attention block
x = self_attention(x, channels=64, use_bias=True, sn=True, scope='self_attention')
x = self_attention_with_pooling(x, channels=64, use_bias=True, sn=True, scope='self_attention_version_2')
x = squeeze_excitation(x, channels=64, ratio=16, use_bias=True, sn=True, scope='squeeze_excitation')
x = convolution_block_attention(x, channels=64, ratio=16, use_bias=True, sn=True, scope='convolution_block_attention')
Normalization
Normalization
x = batch_norm(x, is_training=is_training, scope='batch_norm')
x = instance_norm(x, scope='instance_norm')
x = layer_norm(x, scope='layer_norm')
x = group_norm(x, groups=32, scope='group_norm')
x = pixel_norm(x)
x = batch_instance_norm(x, scope='batch_instance_norm')
x = condition_batch_norm(x, z, is_training=is_training, scope='condition_batch_norm'):
x = adaptive_instance_norm(x, gamma, beta):
- 如何使用 condition_batch_norm,请参考:
https://github.com/taki0112/BigGAN-Tensorflow
- 如何使用
adaptive_instance_norm
,请参考:
https://github.com/taki0112/MUNIT-Tensorflow
Activation
x = relu(x)
x = lrelu(x, alpha=0.2)
x = tanh(x)
x = sigmoid(x)
x = swish(x)
Pooling & Resize
x = up_sample(x, scale_factor=2)
x = max_pooling(x, pool_size=2)
x = avg_pooling(x, pool_size=2)
x = global_max_pooling(x)
x = global_avg_pooling(x)
x = flatten(x)
x = hw_flatten(x)
Loss
classification loss
loss, accuracy = classification_loss(logit, label)
pixel loss
loss = L1_loss(x, y)
loss = L2_loss(x, y)
loss = huber_loss(x, y)
loss = histogram_loss(x, y)
-
histogram_loss
表示图像像素值在颜色分布上的差异。
gan loss
d_loss = discriminator_loss(Ra=True, loss_func='wgan-gp', real=real_logit, fake=fake_logit)
g_loss = generator_loss(Ra=True, loss_func='wgan_gp', real=real_logit, fake=fake_logit)
- 如何使用
gradient_penalty,
请参考:
https://github.com/taki0112/BigGAN-Tensorflow/blob/master/BigGAN_512.py#L180
kl-divergence (z ~ N(0, 1))
loss = kl_loss(mean, logvar)
Author
Junho Kim
Github地址 :
https://github.com/taki0112/Tensorflow-Cookbook
- 【翻译】在Visual Studio中使用Asp.Net Core MVC创建你的第一个Web API应用(一)
- 基于JQuery EasyUI的WebMVC控件封装(含源码)
- Android系统源码分析-JNI
- EntityFrameWork实现部分字段获取和修改(含源码)
- 基于Ado.Net的日志组件
- Do you kown Asp.Net Core -- 配置Kestrel端口
- 【翻译】在Visual Studio中使用Asp.Net Core MVC创建第一个Web Api应用(二)
- 微信快速开发框架(一)-- 对微信公众平台开发的消息处理
- 微信快速开发框架(二) -- 快速开发微信公众平台框架---简介
- LayoutInflater 布局渲染工具原理分析
- 使用Keras在训练深度学习模型时监控性能指标
- 微信快速开发框架(四)-- 体验微信公众平台快速开发框架
- AsyncTask源码解析
- 微信快速开发框架(五)-- 利用快速开发框架,快速搭建微信浏览博客园首页文章
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- spring框架应用系列二:component-scan自动扫描注册装配
- 手把手教你,嘴对嘴传达------Apache日志管理日志(rotatelogs分割工具、AWStats日志分析)
- 配合JAVA的AJAX使用
- 手把手教你,嘴对嘴传达------Apache网页优化
- jQuery通过Ajax实现请求后台接口数据
- Git常规操作
- 手把手教你,嘴对嘴传达 ----源码编译安装部署LAMP平台(LAMP平台与编译安装详解,Apache,MySQL与PHP源码编译安装,LAMP平台搭建论坛)
- Vue点击切换样式
- ElementUI引入到vue项目开发
- 手把手教你,嘴对嘴传达------Apache(安全优化防盗链、隐藏版本信息)
- spring框架应用系列三:切面编程(带参数)
- 排障集锦:九九八十一难之第六难!(98)Address already in use: AH00072: make_sock: could not bind to address ::80
- Vue页面中引用自定义组件
- Vue如何引用Vant组件
- js表单验证工具包