GAN 的 keras 实现
本文结构:
- 什么是 GAN?
- 优点?
- keras 例子?
什么是 GAN?
GAN,全称为 Generative Adversarial Nets,直译为生成式对抗网络,是一种非监督式模型。
一种应用是生成在原始数据集中不存在的但是却比较合理的数据,还可以拓展一张图片,生成下一帧影像,由简单几笔生成一幅画:
模型:
主要有两部分:
The Generative Model:通过输入任意随机数据,尝试生成一些真实的东西(曲线,图像,声音,文本,...)
The Discriminative Model:试图判定哪些是虚假的数据,来减小对真实数据的误报。
优点:
Markov chains are never needed 避免了计算复杂度特别高的过程,直接进行采样和推断,应用效率相应提高。
a wide variety of functions can be incorporated into the model 针对不同的任务就可以设计不同类型的损失函数。
can represent very sharp, even degenerate distributions
Keras 例子:
任务:生成 sin 曲线。
%matplotlib inline
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm_notebook as tqdm
from keras.models import Model
from keras.layers import Input, Reshape
from keras.layers.core import Dense, Activation, Dropout, Flatten
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling1D, Conv1D
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam, SGD
from keras.callbacks import TensorBoard
1. Generative model:
输入:noise data 输出:尝试生成真实的 sin 数据
def get_generative(G_in, dense_dim=200, out_dim=50, lr=1e-3):
x = Dense(dense_dim)(G_in)
x = Activation('tanh')(x)
G_out = Dense(out_dim, activation='tanh')(x)
G = Model(G_in, G_out)
opt = SGD(lr=lr)
G.compile(loss='binary_crossentropy', optimizer=opt)
return G, G_out
2. Discriminative model:
输出:识别此数据是真实的,还是由 Generative model 生成的
def get_discriminative(D_in, lr=1e-3, drate=.25, n_channels=50, conv_sz=5, leak=.2):
x = Reshape((-1, 1))(D_in)
x = Conv1D(n_channels, conv_sz, activation='relu')(x)
x = Dropout(drate)(x)
x = Flatten()(x)
x = Dense(n_channels)(x)
D_out = Dense(2, activation='sigmoid')(x)
D = Model(D_in, D_out)
dopt = Adam(lr=lr)
D.compile(loss='binary_crossentropy', optimizer=dopt)
return D, D_out
3. chain the two models into a GAN:
set_trainability 的作用是每次训练 generator 时要冻住 discriminator。
def set_trainability(model, trainable=False):
model.trainable = trainable
for layer in model.layers:
layer.trainable = trainable
def make_gan(GAN_in, G, D):
set_trainability(D, False)
x = G(GAN_in)
GAN_out = D(x)
GAN = Model(GAN_in, GAN_out)
GAN.compile(loss='binary_crossentropy', optimizer=G.optimizer)
return GAN, GAN_out
4. Training:
交替训练 discriminator 和 chained GAN,在训练 chained GAN 时要冻住 discriminator 的参数:
def sample_noise(G, noise_dim=10, n_samples=10000):
X = np.random.uniform(0, 1, size=[n_samples, noise_dim])
y = np.zeros((n_samples, 2))
y[:, 1] = 1
return X, y
def train(GAN, G, D, epochs=500, n_samples=10000, noise_dim=10, batch_size=32, verbose=False, v_freq=50):
d_loss = []
g_loss = []
e_range = range(epochs)
if verbose:
e_range = tqdm(e_range)
for epoch in e_range:
X, y = sample_data_and_gen(G, n_samples=n_samples, noise_dim=noise_dim)
set_trainability(D, True)
d_loss.append(D.train_on_batch(X, y))
X, y = sample_noise(G, n_samples=n_samples, noise_dim=noise_dim)
set_trainability(D, False)
g_loss.append(GAN.train_on_batch(X, y))
if verbose and (epoch + 1) % v_freq == 0:
print("Epoch #{}: Generative Loss: {}, Discriminative Loss: {}".format(epoch + 1, g_loss[-1], d_loss[-1]))
return d_loss, g_loss
d_loss, g_loss = train(GAN, G, D, verbose=True)
5. Results:
N_VIEWED_SAMPLES = 2
data_and_gen, _ = sample_data_and_gen(G, n_samples=N_VIEWED_SAMPLES)
pd.DataFrame(np.transpose(data_and_gen[N_VIEWED_SAMPLES:])).rolling(5).mean()[5:].plot()
学习资料: https://arxiv.org/pdf/1406.2661.pdf http://www.rricard.me/machine/learning/generative/adversarial/networks/2017/04/05/gans-part1.html http://www.rricard.me/machine/learning/generative/adversarial/networks/keras/tensorflow/2017/04/05/gans-part2.html
- 用awk写递归
- bc计算A股上市新股依次涨停股价
- ASP.NET Core应用的错误处理[1]:三种呈现错误页面的方式
- python访问http的GET/POST
- 用openssl库RSA加密解密
- Kobject浅析
- ASP.NET Core应用的错误处理[2]:DeveloperExceptionPageMiddleware中间件如何呈现“开发者异常页面”
- RSA简介(二)——模幂算法
- 为虚拟机vCPU绑定物理CPU
- RSA简介(三)——寻找质数
- RSA简介(四)——求逆算法
- 平方根的C语言实现(三) ——最终程序实现
- ASP.NET Core应用的错误处理[3]:ExceptionHandlerMiddleware中间件如何呈现“定制化错误页面”
- 【视频】Es6新特性-Symbol
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 再不迁移到Material Design Components 就out啦
- hbase 学习
- 再谈Fragment
- java线程池(四):ForkJoinPool的使用及基本原理
- 算法书中算法
- Robo3T 与 NaviCat 的安装
- 牛客网2017年校招真题-1
- 实例分割新思路之SOLO v1&v2深度解析
- 牛客网剑指offer-3
- java8新特性总结备忘
- 商业数据分析从入门到入职(6)Python程序结构和函数
- 数据科学家极力推荐核心计算工具-Numpy的前世今生(下)
- Android 重构 | 持续优化统一管理 Gradle
- 快速学习-XXL-JOB调度中心/执行器 RESTful API
- 快速学习-XXL-JOB快速入门