PyTorch5:torch.nn总览&torch.nn.Module
1. torch.nn 总览
PyTorch 把与深度学习模型搭建相关的全部类全部在 torch.nn
这个子模块中。
根据类的功能分类,常用的有如下十几个部分:
-
Containers:容器类,如
torch.nn.Module
; -
Convolution Layers:卷积层,如
torch.nn.Conv2d
; -
Pooling Layers:池化层,如
torch.nn.MaxPool2d
; -
Non-linear activations:非线性激活层,如
torch.nn.ReLU
; -
Normalization layers:归一化层,如
torch.nn.BatchNorm2d
; -
Recurrent layers:循环神经层,如
torch.nn.LSTM
; -
Transformer layers:transformer 层,如
torch.nn.TransformerEncoder
; -
Linear layers:线性连接层,如
torch.nn.Linear
; -
Dropout layers:dropout 层,如
torch.nn.Dropout
; -
Sparse layers:稀疏层,如
torch.nn.Embedding
; -
Vision layers:vision 层,如
torch.nn.Upsample
; -
DataParallel layers:平行计算层,如
torch.nn.DataParallel
; -
Utilities:其它功能,如
torch.nn.utils.clip_grad_value_
。
而在 torch.nn
下面还有一个子模块 torch.nn.functional
,基本上是 torch.nn
里对应类的函数,比如 torch.nn.ReLU
的对应函数是 torch.nn.functional.relu
。
为什么要这么做呢?
你可能会疑惑为什么需要这两个功能如此相近的模块,其实这么设计是有其原因的。如果我们只保留 nn.functional 下的函数的话,在训练或者使用时,我们就要手动去维护 weight,bias,stride 这些中间量的值,这显然是给用户带来了不便。而如果我们只保留 nn 下的类的话,其实就牺牲了一部分灵活性,因为做一些简单的计算都需要创造一个类,这也与 PyTorch 的风格不符。(知乎回答)
torch.nn
可以被 nn.Module
识别,并成为网络组成的一部分;torch.nn.functional
则不行。
比较以下两个模型:
>>> class Simple(nn.Module):
... def __init__(self):
... super(Simple, self).__init__()
... self.fc = nn.Linear(10, 1)
... self.dropout = nn.Dropout(0.5) # 使用 nn.Dropout 类
... def forward(self, x):
... x = self.fc(x)
... x = self.dropout(x)
... return x
>>> simple = Simple()
>>> print(simple)
Simple(
(fc): Linear(in_features=10, out_features=1, bias=True)
(dropout): Dropout(p=0.5, inplace=False) #可以被识别成一层
)
>>> class Simple2(nn.Module):
... def __init__(self):
... super(Simple2, self).__init__()
... self.fc = nn.Linear(10, 1)
... def forward(self, x):
... x = F.dropout(self.fc(x)) # 使用 nn.functional.dropout,不能被识别
... return x
>>> simple2 = Simple2()
>>> print(simple2)
Simple2(
(fc): Linear(in_features=10, out_features=1, bias=True)
)
什么时候调用 torch.nn
,什么时候调用 torch.nn.functional
呢?很多人的经验是:不需要存储权重的时候使用 torch.nn.functional
,需要存储权重的时候使用 torch.nn
:
- 层使用
torch.nn
; - dropout 使用
torch.nn
; - 激活函数使用
torch.nn.functional
;
这里要额外说一下 dropout 层。理论上 dropout 没有权重,可以使用 torch.nn.functional.dropout
,然而 dropout 有train
和 eval
模式,使用 torch.nn.Dropout
可以方便地使用 model.train()
或 model.eval()
对模式进行控制,而 torch.nn.functional.dropout
函数就不行。所以为了方便,推荐使用 torch.nn.Dropout
。
以后若没有特殊说明,均在引入模块时省略 torch
模块名称。
创造一个模型分两步:构建模型和权值初始化。而构建模型又有“定义单独的网络层”和“把它们拼在一起”两步。
2. torch.nn.Module
torch.nn.Module
是所有 torch.nn
中的类的父类。我们来看一个非常简单的神经网络:
class SimpleNet(nn.Module):
def __init__(self, x):
super(SimpleNet,self).__init__()
self.fc = nn.Linear(x.shape[0], 1)
def forward(self, x):
x = self.fc(x)
return x
我们随便喂给它一个张量,打印它的网络:
>>> simpleNet = SimpleNet(torch.tensor((10, 2)))
>>> print(simpleNet)
SimpleNet(
(fc): Linear(in_features=2, out_features=1, bias=True)
)
所有自定义的神经网络都要继承 torch.nn.Module
。定义单独的网络层在 __init__
函数中实现,把定义好的网络层拼接在一起在 forward
函数中实现。网络类有两个重要的函数:parameters
存储了模型的权重;modules
存储了模型的结构。
>>> list(simpleNet.modules())
[SimpleNet(
(fc): Linear(in_features=2, out_features=1, bias=True)
),
Linear(in_features=2, out_features=1, bias=True)]
>>> list(simpleNet.parameters())
[Parameter containing:
tensor([[ 0.1533, -0.2574]], requires_grad=True),
Parameter containing:
tensor([-0.1589], requires_grad=True)]
3. torch.nn.Sequential
这是一个序列容器,既可以放在模型外面单独构建一个模型,也可以放在模型里面成为模型的一部分。
# 单独成为一个模型
model1 = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
# 成为模型的一部分
class LeNetSequential(nn.Module):
def __init__(self, classes):
super(LeNetSequential, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 6, 5),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, 5),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),)
self.classifier = nn.Sequential(
nn.Linear(16*5*5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, classes),)
def forward(self, x):
x = self.features(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)
return x
放在模型里面的话,模型还是需要 __init__
和 forward
函数。
这样构建出来的模型的层没有名字:
>>> model2 = nn.Sequential(
... nn.Conv2d(1,20,5),
... nn.ReLU(),
... nn.Conv2d(20,64,5),
... nn.ReLU()
... )
>>> model2
Sequential(
(0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
(1): ReLU()
(2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
(3): ReLU()
)
为了方便区分不同的层,我们可以使用 collections
里的 OrderedDict
函数:
>>> from collections import OrderedDict
>>> model3 = nn.Sequential(OrderedDict([
... ('conv1', nn.Conv2d(1,20,5)),
... ('relu1', nn.ReLU()),
... ('conv2', nn.Conv2d(20,64,5)),
... ('relu2', nn.ReLU())
... ]))
>>> model3
Sequential(
(conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
(relu1): ReLU()
(conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
(relu2): ReLU()
)
4. torch.nn.ModuleList
将网络层存储进一个列表,可以使用列表生成式快速生成网络,生成的网络层可以被索引,也拥有列表的方法 append
,extend
或 insert
。
>>> class MyModule(nn.Module):
... def __init__(self):
... super(MyModule, self).__init__()
... self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
... self.linears.append(nn.Linear(10, 1)) # append
... def forward(self, x):
... for i, l in enumerate(self.linears):
... x = self.linears[i // 2](x) + l(x)
... return x
>>> myModeul = MyModule()
>>> myModeul
MyModule(
(linears): ModuleList(
(0): Linear(in_features=10, out_features=10, bias=True)
(1): Linear(in_features=10, out_features=10, bias=True)
(2): Linear(in_features=10, out_features=10, bias=True)
(3): Linear(in_features=10, out_features=10, bias=True)
(4): Linear(in_features=10, out_features=10, bias=True)
(5): Linear(in_features=10, out_features=10, bias=True)
(6): Linear(in_features=10, out_features=10, bias=True)
(7): Linear(in_features=10, out_features=10, bias=True)
(8): Linear(in_features=10, out_features=10, bias=True)
(9): Linear(in_features=10, out_features=10, bias=True)
(10): Linear(in_features=10, out_features=1, bias=True) # append 进的层
)
)
5. torch.nn.ModuleDict
这个函数与上面的 torch.nn.Sequential(OrderedDict(...))
的行为非常类似,并且拥有 keys
,values
,items
,pop
,update
等词典的方法:
>>> class MyDictDense(nn.Module):
... def __init__(self):
... super(MyDictDense, self).__init__()
... self.params = nn.ModuleDict({
... 'linear1': nn.Linear(512, 128),
... 'linear2': nn.Linear(128, 32)
... })
... self.params.update({'linear3': nn.Linear(32, 10)}) # 添加层
... def forward(self, x, choice='linear1'):
... return torch.mm(x, self.params[choice])
>>> net = MyDictDense()
>>> print(net)
MyDictDense(
(params): ModuleDict(
(linear1): Linear(in_features=512, out_features=128, bias=True)
(linear2): Linear(in_features=128, out_features=32, bias=True)
(linear3): Linear(in_features=32, out_features=10, bias=True)
)
)
>>> print(net.params.keys())
odict_keys(['linear1', 'linear2', 'linear3'])
>>> print(net.params.items())
odict_items([('linear1', Linear(in_features=512, out_features=128, bias=True)), ('linear2', Linear(in_features=128, out_features=32, bias=True)), ('linear3', Linear(in_features=32, out_features=10, bias=True))])
- 复合事件处理(Complex Event Processing)介绍
- Quartz.net官方开发指南 第三课:更多关于Jobs和JobDetails
- 为你的WordPress 主题添加结构化数据/丰富文本摘要,高亮搜索结果(下)
- Quartz.net官方开发指南 第四课:关于Triggers更多内容
- 数据分析:寻找Python最优计算性能
- 事件流处理框架NEsper for .NET
- Quartz.net官方开发指南 第五课: SimpleTrigger
- SQL Server Performance Dashboard Reports
- 添加WordPress评论输入邮箱实时显示Gravatar头像功能
- Quartz.net官方开发指南 第六课 : CronTrigger
- WordPress 中禁止某个用户在线编辑主题
- Quartz.net官方开发指南 第七课 : TriggerListeners和JobListeners
- Quartz.net官方开发指南 第八课:SchedulerListeners
- 为WordPress 后台编辑器文本模式(HTML模式)添加按钮
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Kubernetes节点的驱逐与预留
- 使用reveal.js制作精美的网页版PPT
- Ceph快照爱你不容易系列 03:快照数据一致性浅析
- 没想到竟是因为它!让我的服务器变成了别人的挖矿工具
- 从零到一,Serverless 平台在滴滴内部落地
- React 使用 Proxy 代理(create-react-app)
- .Net Core + EF + mysql 从数据库生成实体
- Git 常用命令
- Nodejs 一些细节 (持续更新)
- Jenkins 凭据使用
- React源码解读【一】API复习与基础
- choco 安装 和 mkcert 本地https
- js 函数柯里化(Currying)
- GPS数据Python解析及地图可视化
- 文稿:Ant Design从无到有,带你体悟大厂前端开发范式