【干货】动手实践:理解和优化GAN(附代码)
【导读】本文是机器学习研究员Mirantha Jayathilaka撰写的一篇技术博文,主要讲解了生成对抗网络(GAN)。本文分别从理论和代码实践两方面来介绍GAN,首先介绍了生成器和判别器的概念及其工作原理,然后分别从构建生成器模型,构建判别器模型,选择损失并训练,选择不同的优化算法进行训练等方面讲解代码,文末附有作者的完整代码和数据集代码,感兴趣的读者可以学习一下。
Understanding and optimizing GANs
(Going back to first principles)
理解和优化GAN
自从Ian Goodfellow首次引入架构以来,生成对抗网络(GAN)的研究一直在不断增加,许多相关的进步和应用变得越来越引人入胜。但对于任何想要开始使用GAN的人来说,如何开始是非常棘手的。这篇文章将引导你如何开始使用。
与许多事情一样,充分理解它的概念的最好方法就是溯源。对于GAN,这里是原论文 - >(https://arxiv.org/abs/1406.2661)。现在理解这类论文可以有两种方法,理论和实践。我通常喜欢后者,但是如果你想深入挖掘数学理解,这是一篇很棒的文章。同时,这篇文章将以Keras的最纯粹的形式介绍GAN的一种简单的算法实现。让我们开始吧。
数学理解GAN:
https://medium.com/@samramasinghe/generative-adversarial-networks-a-theoretical-walk-through-5889d5a8f2bb
在GAN的基础设置中,有两个模型,即生成器和判别器,其中生成器不断与判别器竞争,判别器区分模型分布(例如生成的假图像)和数据的分布(例如真实图像)的差别。这个概念可以通过著名的伪造者与警察场景来形象化,其中生成模型被认为是伪造者生产假币,判别器模型作为试图找出假币的警察。这个想法是,由于彼此之间不断的竞争,造假者和警察都提升了自己的业务水平,但最终造假者实现了生产假币和真币一样的水平。原理很简短,现在让我们把它放到代码中。
本文提供的示例脚本用于生成伪造的脸部图像。图1显示了我们试图用算法实现的最终结果。
▌构建生成模型
生成模型应该会吸收一些噪音并输出令人满意的外观图像。在这里,我们使用Keras Sequential模型以及Dense(全连接)和Batch Normalization层。使用的Activation(激活函数)是Leaky Relu。请参阅下面的代码片段。生成模型可以分成几个区块。一个块由Dense层 - >激活 - >Batch Normalization组成。添加了三个这样的块,最后一个块将像素转换为我们期望的图像的期望形状作为输出。模型的输入将是形状(100,)的噪声矢量,并在最后输出模型。注意每个Dense层中的节点随着模型的进展而增加。
def build_generator(self):
noise_shape = (100,)
model = Sequential()
model.add(Dense(256, input_shape=noise_shape))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
model.summary()
noise = Input(shape=noise_shape)
img = model(noise)
return Model(noise, img)
▌建立判别器模型
判别器接收图像的输入,将其平滑并通过两个Dense- >Activation块,最终输出1和0之间的标量。输出1应表示输入图像是真实的,否则为0。 请参阅下面的代码。
def build_discriminator(self):
img_shape = (self.img_rows, self.img_cols, self.channels)
model = Sequential()
model.add(Flatten(input_shape=img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
model.summary()
img = Input(shape=img_shape)
validity = model(img)
return Model(img, validity)
注意 您可以修改这些模型,以获得更多的块,更多的Batch Normalization层,不同的激活函数等。按照这个例子,这些模型足以理解GAN背后的概念。
▌找出损失并训练
我们计算三中损失,在这个例子中全部使用二分类交叉熵来训练这两个模型。
首先是判别器。 如下面的代码所示,它训练了两种方式。 首先为真实图像输出1(数组'img'),然后为生成的图像输出0(数组'gen_img')。 随着训练的进展,辨别器在此任务中得到改进。 但是我们的最终目标是在鉴别器对两种输入类型输出0.5的理论点(无法判断真假)。
d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
接下来是训练生成模型,这是棘手的一点。 要做到这一点,我们首先将生成器模型和判别器模型组合起来,用判别器的输出处理生成器模型的输出。 记得! 理想情况下,我们希望这是1,这意味着鉴别器将假造图像识别为真实图像。请参阅下面的代码。
z = Input(shape=(100,))
img = self.generator(z)
valid = self.discriminator(img)
self.combined = Model(z, valid)
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
g_loss = self.combined.train_on_batch(noise, valid_y)
▌现在让我们来调教代码吧
这是代码的要点,以简单理解GAN的运作。
完整的代码可以在GitHub上找到。
https://github.com/miranthajayatilake/GANwKeras
您可以参考所有用于导入RGB图像的附加代码,初始化模型并将结果记录在代码中。请注意,在训练期间,为了能够在CPU上运行,将Mini batches设置为Hi32映像。 此外,本例中使用的真实图像是来自CelebA数据集的5000张图像。这是一个开源数据集,我已经将它上传到Floydhub,以便下载,您可以在这里找到。
https://www.floydhub.com/mirantha/datasets/celeba
有很多方法可以优化代码以获得更好的结果,并且这样可以帮助你了解算法的不同组件如何影响模型。在调整优化器,激活函数,归一化,损失损失函数,超参数等不同组件的同时观察结果是增强对算法理解的最佳方法。这里 我选择改变优化器。
因此,用32 batches训练5000 epochs,我使用三种优化算法进行了测试。使用Keras这个过程就像导入和替换优化器函数的名称一样简单。 Keras内置的所有优化器都可以在这里找到。
此外,在每个实例中绘制的损失用于理解模型的行为。
1. 使用SGD(随机梯度下降优化器)。输出和损失变化分别如图2和3所示。
注意:虽然收敛是不平稳的,但我们可以在这里看到,生成器损失在epochs时期减少,这意味着鉴别器倾向于将假图像检测为真实。
2.使用RMSProp优化器。 输出和损失变化分别如图4和5所示。
损失:
注意:在这里,我们也看到生成模型损失在减少,这是一件好事。 令人惊讶的是,真实图像上的判别器损失增加,这非常有趣。
3. 使用Adam优化器。 输出和损耗变化分别如图6和图7所示。
注意: adam优化器产生迄今为止最好的结果。 请注意,假图像上的鉴别器损失保留较大的值,这意味着鉴别器倾向于将假图像检测为真实。
完整代码:
https://github.com/miranthajayatilake/GANwKeras
图像数据:
https://www.floydhub.com/mirantha/datasets/celeba
原文链接:
https://towardsdatascience.com/understanding-and-optimizing-gans-going-back-to-first-principles-e5df8835ae18
- 数据挖掘工程师:如何通过百度地图API抓取建筑物周边位置、房价信息
- crontab导致CPU异常的问题分析及处理(r3笔记第100天)
- 短信接口被恶意调用(二)肉搏战-阻止恶意请求
- 关于首屏时间采集自动化的解决方案
- javax.net.ssl.SSLHandshakeException: No appropriate protocol (protocol is disabled or cipher suites
- 一次数据库无法登陆的问题及排查 (r3笔记第99天)
- 用深度学习keras的cnn做图像识别分类,准确率达97%
- 短信发送接口被恶意访问的网络攻击事件(三)定位恶意IP的日志分析脚本
- job处理缓慢的性能问题排查与分析(r4笔记第18天)
- 京东商品评论情感分析:数据采集与词向量构造方法
- springboot开启access_log日志输出
- 完美的执行计划导致的性能问题(r4笔记第17天)
- 解决Docker容器时区及时间不同步的问题
- 移动端测试方案--sptt
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 一天一大 lee(复原IP地址)难度:中等-Day20200809
- 一天一大 lee(计数二进制子串)难度:简单-Day20200810
- 一天一大 lee(打家劫舍 III)难度:中等-Day20200805
- 一天一大 lee(课程表)难度:中等-Day20200804
- 一天一大 leet(二叉树展开为链表)难度:中等-Day20200802
- 一天一大 leet(字符串相加)难度:简单-Day20200803
- 一天一大 lee(恢复二叉搜索树)难度:困难-Day20200808
- 一天一大 leet(最小区间)难度:困难-Day20200801
- 一天一大 lee(全排列 II)难度:中等-Day20200918
- 一天一大 lee(冗余连接 II)难度:困难-Day20200917
- 一天一大 lee(左叶子之和)难度:简单-Day20200919
- 【一天一大 lee】 把二叉搜索树转换为累加树 (难度:简单)-Day20200921
- 【一天一大 lee】子集 (难度:中等)-Day20200920
- 用了这个jupyter插件,我已经半个月没打开过excel了
- Webpack学习笔记