【动手学深度学习笔记】之构造MLP模型的几种方法
时间:2022-07-23
本文章向大家介绍【动手学深度学习笔记】之构造MLP模型的几种方法,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
构造一个MLP模型的几种方法
本篇文章以构造一个MLP模型为例,介绍几种构造模型的常见方法。
1. 直接继承Module类
Module类是nn模块里提供的一个模型构造类,通过继承Module实现MLP的程序如下
class MLP(nn.Module):
def __init__(self):
super(MLP,self).__init__()
self.hidden = nn.Linear(784,256)
self.act = nn.ReLU()
self.output = nn.Linear(256,10)
def forward(self,x):
a = self.act(self.hidden(x))
return self.output(a)
在构造这个模型的过程中,仅仅定义了正向传播。系统将通过自动求梯度而自动生成反向传播所需要的backward函数。下面将MLP类实例化。
X = torch.rand(2,784)
#随即胜场两组样本
net = MLP()
print(net)
net(X)
神经网络结构和正向传播一次的结果:
2. 使用Module的子类
Module的子类包括:Sequential类、ModuleList类和ModuleDict类
2.1 Sequential类
当模型的正向传播为简单串联各个层的计算时,Sequential类可以通过更简单的方式定义模型。Sequential类可以接收一系列子模块作为参数来逐一添加Module的实例。模型的正向传播就是将这些实例按添加顺序逐一计算。
使用Sequential类实现MLP模型,并使用随机生成的样本做一次前向计算。
#构造
net = nn.Sequential(
nn.Linear(784,256),
nn.ReLU(),
nn.Linear(256,10),
)
#输出结果和进行一次前向计算
X = torch.rand(2,784)
print(net)
net(X)
神经网络结构和正向传播一次的结果:
2.2 ModuleList类
ModuleList接收一个子模块的列表作为输入。并且可以像list那样进行索引、append和extend操作。但不同于一般List的地方是加入到ModuleList里面的所有模块的参数会被自动添加到整个网络。
下面以实例说明如何使用ModuleList类构造MLP模型。
#构造
net = nn.ModuleList([nn.Linear(784,256),nn.ReLU()])
net.append(nn.Linear(256,10))
print(net[-1])
#输出net中最后一个实例
print(net)
神经网络结构:
2.3 ModuleDict类
ModuleDict类接受一个子模块的字典作为输入,也可以类似字典那样进行添加访问操作。
net = nn.ModuleDict({
'linear':nn.Linear(784,256),
'act':nn.ReLU(),
})
#类似字典的访问方法
net['output']=nn.Linear(256,10)
print(net['linear'])
print(net.output)
#输出结构
print(net)
神经网络结构:
3. 构造复杂模型
上述两种方法各有利弊。下面我们综合使用这两种方法,构造一个复杂的神经网络FancyMLP。在这个神经网络中,我们需要创建常数参数(训练中不被迭代的参数),在前向计算中,还需要使用Tensor的函数和Python控制流并多次调用相同的层。
class FancyMLP(nn.Module):
def __init__(self):
super(FancyMLP, self).__init__()
self.rand_weight = torch.rand((20, 20), requires_grad=False) #常数参数
self.linear = nn.Linear(20, 20)
def forward(self, x):
x = self.linear(x)
# 使用创建的常数参数
x = nn.functional.relu(torch.mm(x, self.rand_weight.data) + 1)
# 复用全连接层。等价于两个全连接层共享参数
x = self.linear(x)
# 控制流,这里我们需要调用item函数来返回标量进行比较
while x.norm().item() > 1:
x /= 2
if x.norm().item() < 0.8:
x *= 10
return x.sum()
class NestMLP(nn.Module):
def __init__(self):
super(NestMLP, self).__init__()
self.net = nn.Sequential(nn.Linear(40, 30), nn.ReLU())
def forward(self, x):
return self.net(x)
net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())
#嵌套调用FancyMLP和Sequential
X = torch.rand(2, 40)
print(net)
net(X)
神经网络结构:
- WCF技术剖析之十四:泛型数据契约和集合数据契约(上篇)
- WCF技术剖析之十四:泛型数据契约和集合数据契约(下篇)
- WCF技术剖析(卷1)之前言
- WCF技术剖析(卷1)之目录
- WCF技术剖析(卷1)之推荐序
- 谈谈基于SQL Server 的Exception Handling[上篇]
- 谈谈WCF中的Data Contract(4):WCF Data Contract Versioning
- 如何在silverlihgt中使用右键
- WCF技术剖析之十二:数据契约(Data Contract)和数据契约序列化器(DataContractSerializer)
- silverlight向服务器post数据类
- WCF技术剖析之十三:序列化过程中的已知类型(Known Type)
- 44 Amazing Silverlight 2.0 Screencasts
- CaseStudy(showcase)类库篇-用agTweener来实现动画效果
- CaseStudy(showcase)数据篇-Loading的制作
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- android 上传aar到私有maven服务器的示例
- Android Studio开发环境搭建教程详解
- android事件总线EventBus3.0使用方法详解
- Android仿淘口令复制弹出框功能(简答版)
- Android实现简单断点续传和下载到本地功能
- Android用MVP实现一个简单的类淘宝订单页面的示例
- Android Bitmap的截取及状态栏的隐藏和显示功能
- 详解Android沉浸式实现兼容解决办法
- AndroidStudio项目打包成jar的简单方法
- 浅谈React Native打包apk的坑
- Android 设置主题实现点击波纹效果的示例
- 更新Android Studio 3.0碰到的问题小结
- android实现一个图片验证码倒计时功能
- Android添加glide库报错Error:Failed to resolve:com.android.support:26.0.2的解决
- Android多线程下载示例详解