pytorch快速搭建神经网络_Sequential操作
时间:2022-07-27
本文章向大家介绍pytorch快速搭建神经网络_Sequential操作,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
之前用Class类来搭建神经网络
class Neuro_net(torch.nn.Module):
"""神经网络"""
def __init__(self, n_feature, n_hidden_layer, n_output):
super(Neuro_net, self).__init__()
self.hidden_layer = torch.nn.Linear(n_feature, n_hidden_layer)
self.output_layer = torch.nn.Linear(n_hidden_layer, n_output)
def forward(self, input):
hidden_out = torch.relu(self.hidden_layer(input))
out = self.output_layer(hidden_out)
return out
net = Neuro_net(2, 10, 2)
print(net)
class类图结构:
使用torch.nn.Sequential() 快速搭建神经网络
net = torch.nn.Sequential(
torch.nn.Linear(2, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 2)
)
print(net)
Sequential图结构
总结:
我们可以发现,使用torch.nn.Sequential会自动加入激励函数, 但是 class类net 中, 激励函数实际上是在 forward() 功能中才被调用的
使用class类中的torch.nn.Module,我们可以根据自己的需求改变传播过程
如果你需要快速构建或者不需要过多的过程,直接使用torch.nn.Sequential吧
补充知识:【PyTorch神经网络】使用Moudle和Sequential搭建神经网络
Module:
init中定义每个神经层的神经元个数,和神经元层数;
forward是继承nn.Moudle中函数,来实现前向反馈(加上激励函数)
# -*- coding: utf-8 -*-
# @Time : 2019/11/5 10:43
# @Author : Chen
# @File : neural_network_impl.py
# @Software: PyCharm
import torch
import torch.nn.functional as F
#data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())
#第一种搭建方法:Module
# 其中,init中定义每个神经层的神经元个数,和神经元层数;
# forward是继承nn.Moudle中函数,来实现前向反馈(加上激励函数)
class Net(torch.nn.Module):
def __init__(self):
#继承__init__函数
super(Net, self).__init__()
#定义每层的形式
#隐藏层线性输出feature- hidden
self.hidden = torch.nn.Linear(1, 10)
#输出层线性输出hidden- output
self.predict = torch.nn.Linear(10, 1)
#实现所有层的连接关系。正向传播输入值,神经网络分析输出值
def forward(self, x):
#x首先在隐藏层经过激励函数的计算
x = F.relu(self.hidden(x))
#到输出层给出预测值
x = self.predict(x)
return x
net = Net()
print(net)
print('nn')
#快速搭建:Sequential
#模板:net2 = torch.nn.Sequential()
net2 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
print(net2)
以上这篇pytorch快速搭建神经网络_Sequential操作就是小编分享给大家的全部内容了,希望能给大家一个参考。
- 基于MVC理解React+Redux
- JavaScript的IIFE(即时执行方法)
- 从机器学习学python(三) ——数组冒号取值与extend
- 从机器学习学python(四) ——numpy矩阵基础
- 从map函数引发的讨论
- AngularJs中,如何在render完成之后,执行Js脚本
- PHP取得上周一、上周日,下周一
- 代码诊所
- 《编程之美》读书笔记(一)——中国象棋将帅有效位置
- 有趣的Code Poster
- div 自适应高度 自动填充剩余高度
- PHP开发人员常犯的10个MysqL错误
- android AutoCompleteTextView 自定义BaseAdapter
- Scala中的偏函数
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- java线程池(五):ForkJoinPool源码分析之一(外部提交及worker执行过程)
- JavaScript中的匿名函数、闭包和BOM
- 【Vue.js】Vue.js中的事件处理、过滤器、过渡和动画、组件的生命周期及组件之间的通信
- 树莓派基础实验18:声音传感器实验
- 树莓派基础实验19:光敏传感器实验
- 逻辑式编程还有用吗?--“三维度”逻辑编程语言的设计(2)
- git 报错解决Validate branches Cannot Create: This merge request already exists
- 树莓派基础实验20:火焰报警传感器实验
- (译)SDL编程入门(8)几何图形渲染
- Java8 dubbo 调用 Collectors.toMap代码片发生的异常(IllegalStateException: Duplicate key)
- 树莓派基础实验21:烟雾报警传感器实验
- 树莓派基础实验22:红外遥控传感器实验
- Spring的BeanUtil的copyProperties方法 慎用!!
- (译)SDL编程入门(9)视口
- (译)SDL编程入门(7)纹理加载和渲染