轻松学Pytorch – 构建生成对抗网络
时间:2022-07-24
本文章向大家介绍轻松学Pytorch – 构建生成对抗网络,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
引言
又好久没有继续写了,这个是我写的第21篇文章,我还在继续坚持写下去,虽然经常各种拖延症,但是我还记得,一直没有敢忘记!今天给大家分享一下Pytorch生成对抗网络代码实现。
01
什么是生成对抗网络
Ian J. Goodfellow在2014年提出生成对抗网络,从此打开了深度学习中另外一个重要分支,让生成对抗网络(GAN)成为与卷积神经网络(CNN)、循环神经网络(RNN/LSTM)可以并驾齐驱的分支领域。今天GAN仍然是计算机视觉领域研究热点之一,每年还有大量相关的论文产生,GAN已经被用在视觉任务的很多方面,主要包括:
- 图像合成与数据增广
- 图像翻译与变换
- 缺陷检测
- 图像去噪与重建
- 图像分割
但是GAN最基本的核心思想还是2014年Ian J. Goodfellow在论文中提到的两个基本的模型分别是:生成器与判别器
生成器(G):
根据输入噪声Z生成输出样本G(z)
目标:通过生成样本与目标样本分布一致,成功欺骗鉴别器
判别器(D):
根据输入样本数据来分辨真实样本概率
从数据中学习样本数据的差异性
从a到d,可以看到输入噪声的生成分布越来越接近真实分布X,最终达到一种平衡状态,这种稳定的平衡状态叫纳什均衡,还有一部电影跟这个有关系叫《美丽心灵》。
02
GAN代码实现
下面的代码实现了基于Mnist数据集实现判别器与生成器,最终通过生成器可以自动生成手写数字识别的图像,输入的z=100是随机噪声,输出的是784个数据表示28x28大小的手写数字样本,损失主要来自两个部分,生成器生成损失,判别器分别判别真实与虚构样本概率,基于反向传播训练两个网络,设置epoch=100,得到最终的生成器生成结果如下:
生成器与判别器代码实现如下
HARR特征级联分类器人脸检测来自VJ的2004论文中提出,其主要思想可以通过下面一张图像解释:
transform = tv.transforms.Compose([tv.transforms.ToTensor(),
tv.transforms.Normalize((0.5,), (0.5,))])
train_ts = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_ts = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_dl = DataLoader(train_ts, batch_size=128, shuffle=True, drop_last=False)
test_dl = DataLoader(test_ts, batch_size=128, shuffle=True, drop_last=False)
class Generator(t.nn.Module):
def __init__(self, g_input_dim, g_output_dim):
super(Generator, self).__init__()
self.fc1 = t.nn.Linear(g_input_dim, 256)
self.fc2 = t.nn.Linear(self.fc1.out_features, self.fc1.out_features * 2)
self.fc3 = t.nn.Linear(self.fc2.out_features, self.fc2.out_features * 2)
self.fc4 = t.nn.Linear(self.fc3.out_features, g_output_dim)
# forward method
def forward(self, x):
x = F.leaky_relu(self.fc1(x), 0.2)
x = F.leaky_relu(self.fc2(x), 0.2)
x = F.leaky_relu(self.fc3(x), 0.2)
return t.tanh(self.fc4(x))
class Discriminator(t.nn.Module):
def __init__(self, d_input_dim):
super(Discriminator, self).__init__()
self.fc1 = t.nn.Linear(d_input_dim, 1024)
self.fc2 = t.nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
self.fc3 = t.nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
self.fc4 = t.nn.Linear(self.fc3.out_features, 1)
# forward method
def forward(self, x):
x = F.leaky_relu(self.fc1(x), 0.2)
x = F.dropout(x, 0.3)
x = F.leaky_relu(self.fc2(x), 0.2)
x = F.dropout(x, 0.3)
x = F.leaky_relu(self.fc3(x), 0.2)
x = F.dropout(x, 0.3)
return t.sigmoid(self.fc4(x))
损失与训练代码如下
分别定义生成网络训练与鉴别网络的训练方法,然后开始训练即可,代码实现如下:
# 生成者与判别者
bs = 128
z_dim = 100
mnist_dim = 784
# loss
criterion = t.nn.BCELoss()
# optimizer
device = "cuda"
gnet = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
dnet = Discriminator(mnist_dim).to(device)
lr = 0.0002
G_optimizer = t.optim.Adam(gnet.parameters(), lr=lr)
D_optimizer = t.optim.Adam(dnet.parameters(), lr=lr)
def D_train(x):
# =======================Train the discriminator=======================#
dnet.zero_grad()
# train discriminator on real
x_real, y_real = x.view(-1, mnist_dim), t.ones(bs, 1)
x_real, y_real = Variable(x_real.to(device)), Variable(y_real.to(device))
D_output = dnet(x_real)
D_real_loss = criterion(D_output, y_real)
# train discriminator on facke
z = Variable(t.randn(bs, z_dim).to(device))
x_fake, y_fake = gnet(z), Variable(t.zeros(bs, 1).to(device))
D_output = dnet(x_fake)
D_fake_loss = criterion(D_output, y_fake)
# gradient backprop & optimize ONLY D's parameters
D_loss = D_real_loss + D_fake_loss
D_loss.backward()
D_optimizer.step()
return D_loss.data.item()
def G_train(x):
# =======================Train the generator=======================#
gnet.zero_grad()
z = Variable(t.randn(bs, z_dim).to(device))
y = Variable(t.ones(bs, 1).to(device))
G_output = gnet(z)
D_output = dnet(G_output)
G_loss = criterion(D_output, y)
# gradient backprop & optimize ONLY G's parameters
G_loss.backward()
G_optimizer.step()
return G_loss.data.item()
n_epoch = 100
for epoch in range(1, n_epoch+1):
D_losses, G_losses = [], []
for batch_idx, (x, _) in enumerate(train_dl):
bs_, _,_,_ = x.size()
bs = bs_
D_losses.append(D_train(x))
G_losses.append(G_train(x))
print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
(epoch), n_epoch, t.mean(t.FloatTensor(D_losses)), t.mean(t.FloatTensor(G_losses))))
- Leetcode 290. Word Pattern
- 【深度学习】使用tensorflow实现VGG19网络
- Leetcode 289. Game of Life
- Leetcode 287. Find the Duplicate Number
- Leetcode 284. Peeking Iterator
- Leetcode 283. Move Zeroes
- Leetcode 282. Expression Add Operators
- Leetcode 279. Perfect Squares
- Leetcode 278. First Bad Version
- Leetcode 275. H-Index II
- Leetcode 274. H-Index
- 值得 .NET 开发者了解的15个特性
- Angular和Vue.js 深度对比
- 前端开发者常用的9个JavaScript图表库
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 使用XtraBackup备份MySQL 8.0 Part 2 XtraBackup权限及配置
- 使用XtraBackup备份MySQL 8.0 Part 6 对数据库进行增量备份
- ArrayList源码阅读笔记
- 3分钟短文 | Laravel表单验证没规则可用?你试试自定义,真香!
- 【C#】DataGridView 数据绑定的一些细节
- 3分钟短文 | Laravel 查询结果检查是不是空,5个方法你别用错!
- 使用XtraBackup备份MySQL 8.0 Part 7 对增量备份进行恢复
- 3分钟短文 | Laravel 日志全程记录 SQL 查询语句,要改写底层?
- MySQL InnoDB表空间加密
- 微信小程序自动化测试最佳实践(附 Python 源码)
- 3分钟短文 | MySQL在分组时,把多列合并为一个字段!
- Redis Linux系统参数最佳配置
- 实现Promise其它API
- 使用sysbench进行压测 Part1 sysbench安装
- Java并发编程(07):Fork/Join框架机制详解