【动手学深度学习笔记】之自定义层
时间:2022-07-23
本文章向大家介绍【动手学深度学习笔记】之自定义层,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
1.自定义层
神经网络中存在着全连接层、卷积层、池化层和循环层等各种各样的层。虽然PyTorch提供了大量常用的层,但有时还是需要我们自定义层。本篇文章介绍如何使用Module
类来自定义层。
1.1 不含模型参数的自定义层
下面以实例介绍一下通过继承Module
类定义不含模型参数的自定义层。
class layer(nn.Module):
def __init__(self,**keywargs):
#直接继承Module的__init__()
super(layer,self).__init__(**keywargs)
def forward(self,x):
#定义前向传导
return x-x.mean()
定义的这个层并没有模型参数。实例化例子如下
layer = layer()
layer(torch.tensor([1,2,3,4,5],dtype=torch.float))
Out[1]:tensor([-2., -1., 0., 1., 2.])
同样可以使用Sequential
类将这个层添加到网络。
net = nn.Sequential(net.Linear(8,128),layer())
1.2 含模型参数的自定义层
为自定义层添加模型参数有以下三种方式。
使用Parameter类
上一篇文章介绍过,当一个Tensor
类型为Parameters
时,它将会被自动添加到参数列表中。
class net1(nn.Module):
def __init__(self):
super(net1,self).__init__()
self.weight = nn.Parameter(torch.rand(4,4))
self.bais = nn.Parameter(torch.rand(4,1))
def forward(self,x):
for i in range(len(self.params)):
x = torch.mm(x, self.params[i])
return x
net1 = net1()
for name,param in net1.named_parameters():
print(name,param)
Out[1]:
weight Parameter containing:
tensor([[0.3217, 0.8082, 0.2425, 0.3970],
[0.6009, 0.2262, 0.7150, 0.6720],
[0.4062, 0.6335, 0.6234, 0.2680],
[0.1824, 0.0825, 0.8183, 0.2564]], requires_grad=True)
bais Parameter containing:
tensor([[0.3952],
[0.4866],
[0.9082],
[0.1949]], requires_grad=True)
使用ParameterList类
ParameterList
类接收Parameters
实例的列表作为输入然后得到一个参数列表,与List
类似,可以使用索引访问,append
添加。
class MyDense(nn.Module):
def __init__(self):
super(MyDense, self).__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])
self.params.append(nn.Parameter(torch.randn(4, 1)))
def forward(self, x):
for i in range(len(self.params)):
x = torch.mm(x, self.params[i])
return x
net = MyDense()
print(net)
Out[1]:
MyDense(
(params): ParameterList(
(0): Parameter containing: [torch.FloatTensor of size 4x4]
(1): Parameter containing: [torch.FloatTensor of size 4x4]
(2): Parameter containing: [torch.FloatTensor of size 4x4]
(3): Parameter containing: [torch.FloatTensor of size 4x1]
)
)
使用ParameterDict类
ParameterDict
类接收一个Parameter
实例的字典作为输入,返回一个参数字典,同样可以使用updata()
添加参数,使用key()
返回所有键值,使用item()
返回所有键值对等字典操作。
class MyDictDense(nn.Module):
def __init__(self):
super(MyDictDense, self).__init__()
self.params = nn.ParameterDict({
'linear1': nn.Parameter(torch.randn(4, 4)),
'linear2': nn.Parameter(torch.randn(4, 1))
})
self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增
def forward(self, x, choice='linear1'):
return torch.mm(x, self.params[choice])
net = MyDictDense()
print(net)
Out[1]:
MyDictDense(
(params): ParameterDict(
(linear1): Parameter containing: [torch.FloatTensor of size 4x4]
(linear2): Parameter containing: [torch.FloatTensor of size 4x1]
(linear3): Parameter containing: [torch.FloatTensor of size 4x2]
)
)
使用ParameterDict
类,可以通过选择不同的键,来进行不同的正向传播。
x = torch.ones(1, 4)
print(net(x, 'linear1'))
print(net(x, 'linear2'))
print(net(x, 'linear3'))
Out[1]:
tensor([[1.5082, 1.5574, 2.1651, 1.2409]], grad_fn=<MmBackward>)
tensor([[-0.8783]], grad_fn=<MmBackward>)
tensor([[ 2.2193, -1.6539]], grad_fn=<MmBackward>)
上述这些中方法创建的层,都可以像PyTorch
中其他层一样,通过Sequential
类、ModuleList
类和ModuleDict
类等方法构造模型。
- 使用shell定制addm脚本(r3笔记第88天)
- 【专业技术第十三讲】指针和内存泄露
- 【Java案例】余弦函数
- MySQL数据类型(r3笔记第87天)
- NLP真实项目:利用这个模型能够通过商品评论去预测一个商品的销量
- python + selenium + PhantomJS 获取腾讯应用宝APP评论
- 简单实用的sql小技巧(第二篇)(r3笔记第86天)
- Java代码效率优化【面试+提高】
- 利用逻辑回归模型判断用户提问意图
- 关于reset sequence(r3笔记第85天)
- 【编程基础第十二讲】web开发编程基础--回调函数
- typeof的一些兼容性问题
- 类型转换的判定方式
- 【Java案例】打印杨辉三角
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 隧道构建:端口转发的原理和实现
- SAP Spartacus注入自定义的CurrentProductService
- Redis系列(十一)redis命令全集
- Jinkens+gitlab针对k8s集群实现CI/CD
- Vue 踩过的坑
- Java TCP/UDP/HttpClient简例
- 让你设计实现一个签到功能,到底用MySQL还是Redis?
- 如何防止MySQL重复插入数据,这篇文章会告诉你
- Spring AOP注解开发
- 快速学习-Jenkins CLI凭据
- 快速学习-Jenkins CLI任务
- 珍惜数据,远离钓鱼
- Android Pie限制非 SDK 接口的调用
- 多线程基础(十一):interrupt深度分析
- [记录点滴]授人以渔,从Tensorflow找不到dll扩展到如何排查问题