pytorch查看模型weight与grad方式
时间:2022-07-27
本文章向大家介绍pytorch查看模型weight与grad方式,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
在用pdb debug的时候,有时候需要看一下特定layer的权重以及相应的梯度信息,如何查看呢?
1. 首先把你的模型打印出来,像这样
2. 然后观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了,在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad就可以查看梯度信息
补充知识:查看Pytorch网络的各层输出(feature map)、权重(weight)、偏置(bias)
BatchNorm2d参数量
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# 卷积层中卷积核的数量C
num_features – C from an expected input of size (N, C, H, W)
import torch
m = torch.nn.BatchNorm2d(100)
m.weight.shape
torch.Size([100])
m.numel()
AttributeError: 'BatchNorm2d' object has no attribute 'numel'
m.weight.numel()
100
m.parameters().numel()
Traceback (most recent call last):
File "<stdin ", line 1, in <module
AttributeError: 'generator' object has no attribute 'numel'
[p.numel() for p in m.parameters()]
[100, 100]
linear层
import torch
m1 = torch.nn.Linear(100,10)
# 参数数量= (输入神经元+1)*输出神经元
m1.weight.shape
torch.Size([10, 100])
m1.bias.shape
torch.Size([10])
m1.bias.numel()
10
m1.weight.numel()
1000
m11 = list(m1.parameters())
m11[0].shape
# weight
torch.Size([10, 100])
m11[1].shape
# bias
torch.Size([10])
weight and bias
# Method 1 查看Parameters的方式多样化,直接访问即可
model = alexnet(pretrained=True).to(device)
conv1_weight = model.features[0].weight# Method 2
# 这种方式还适合你想自己参考一个预训练模型写一个网络,各层的参数不变,但网络结构上表述有所不同
# 这样你就可以把param迭代出来,赋给你的网络对应层,避免直接load不能匹配的问题!
for layer,param in model.state_dict().items(): # param is weight or bias(Tensor)
print layer,param
feature map
由于pytorch是动态网络,不存储计算数据,查看各层输出的特征图并不是很方便!分下面两种情况讨论:
1、你想查看的层是独立的,那么你在forward时用变量接收并返回即可!!
class Net(nn.Module):
def __init__(self):
self.conv1 = nn.Conv2d(1, 1, 3)
self.conv2 = nn.Conv2d(1, 1, 3)
self.conv3 = nn.Conv2d(1, 1, 3) def forward(self, x):
out1 = F.relu(self.conv1(x))
out2 = F.relu(self.conv2(out1))
out3 = F.relu(self.conv3(out2))
return out1, out2, out3
2、你的想看的层在nn.Sequential()顺序容器中,这个麻烦些,主要有以下几种思路:
# Method 1 巧用nn.Module.children()
# 在模型实例化之后,利用nn.Module.children()删除你查看的那层的后面层
import torch
import torch.nn as nn
from torchvision import modelsmodel = models.alexnet(pretrained=True)# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
# Third convolutional layer
new_features = nn.Sequential(*list(model.features.children())[:5])
model.features = new_features
# Method 2 巧用hook,推荐使用这种方式,不用改变原有模型
# torch.nn.Module.register_forward_hook(hook)
# hook(module, input, output) - Nonemodel = models.alexnet(pretrained=True)
# 定义
def hook (module,input,output):
print output.size()
# 注册
handle = model.features[0].register_forward_hook(hook)
# 删除句柄
handle.remove()# torch.nn.Module.register_backward_hook(hook)
# hook(module, grad_input, grad_output) - Tensor or None
model = alexnet(pretrained=True).to(device)
outputs = []
def hook (module,input,output):
outputs.append(output)
print len(outputs)handle = model.features[0].register_backward_hook(hook)
注:还可以通过定义一个提取特征的类,甚至是重构成各层独立相同模型将问题转化成第一种
计算模型参数数量
def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad)
以上这篇pytorch查看模型weight与grad方式就是小编分享给大家的全部内容了,希望能给大家一个参考。
- 细说log4j
- SEVERE: Error configuring application listener of class org.springframework.web.context.ContextLoade
- TCP/IP(一)之开启计算机网络之路
- JSON入门指南--客户端处理JSON
- mysql5.7 ERROR 1045 (28000): Access denied for user 'root'@'localhost' (using password: NO)
- TCP/IP中你不得不知的十大秘密
- Java Web开发学习之路2012版
- TortoiseSVN客户端使用的2个配置问题
- JavaWeb(二)会话管理之细说cookie与session
- 概率论09 期望
- Javascript中数组的sort()和reverse()方法
- CentOS6.5开放端口,配置防火墙
- JavaWeb(一)Servlet中乱码解决与转发和重定向的区别
- Java魔法堂:四种引用类型、ReferenceQueue和WeakHashMap
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- python爬虫----(3. scrapy框架,简单应用)
- python爬虫----(2. scrapy框架)
- python爬虫----(1. 基本模块)
- 七日Python之路--第十二天(Django Web 开发指南)
- 三日php之路 -- 第一天(php语言参考)
- 三日php之路 -- 第一天(初识php)
- NoSQL数据库 -- MongoDB
- 数据抓取练习
- python基础 -- 简单实现HTTP协议
- RabbitMQ 学习
- asp连接access,增删改查
- Spring 中的如何自定义事件处理(Custom Event)
- python基础 -- 自定义排序
- nginx(安装)
- Spring 中基于 AOP 的 XML操作方式