Pytorch笔记 (2) 初识Pytorch
一、人工神经网络库
Pytorch ———— 让计算机 确定神经网络的结构 + 实现人工神经元 + 搭建人工神经网络 + 选择合适的权重
(1)确定人工神经网络的 结构:
只需要告诉Pytorch 神经网络 中的神经元个数 每个神经元是怎么样的【比如 输入 输出 非线性函数】 各神经元的连接方式
(2)确定人工神经元的权重值:
只需要告诉 pytorch 什么样的权重值比较好
(3)处理 输入和输出:
pytorch 可以和其他库合作,协助处理神经网络的 输入和输出
二、利用Pytorch 实现 迷你AlphaGo
可以把X[0] X[1] X[2] 三个输入看作 当前局势,把y看作下一步要下的棋,把g看作胜率函数,以找到 最优的 下棋策略
我们不需要知道 从X到 y的 关系的形式,只需要搭建神经网络
不需要告诉神经元的权重都是多少,pytorch 可以帮助找到 神经元的权重
步骤:
只需要把下方 四段代码,前后连接,即可
(1)定义神经网络
from torch.nn import Linear,ReLU,Sequential net = Sequential( Linear(3,8), #第一层 8 个神经元 ReLU(),# 第一层神经元的 非线性函数是max(·,0) Linear(8,8), #第二层 8个神经元 ReLU(),#非线性函数是max(·,0) Linear(8,1), #第三层 1 个神经元 )
这个序列中 有三个Linear 类实例 ————> 说明这个 神经网络 有3层
第一个Linear 类实例 用参数 3 8 来构造,这两个参数 说明每个神经元都有 3个输入,一共有8 个神经元
这个序列中有两个ReLU 类实例,也就是说,其中两个层的神经元的非线性函数都是 max(·,0)
这个神经网络最后一层没有使用非线性函数 max(·,0) ————原因: 我们希望将要制作的 应用既能输出≥0 的结果,也能输出<0 的结果
(2)测试函数g()
def g(x,y): x0,x1,x2 = x[:,0] ** 0,x[:,1] ** 1,x[:,2] ** 2 y0 = y[:,0] return (x0 + x1 + x2) * y0 - y0 * y0 - x0 * x1 * x2
(3)寻找合适的神经元的权重
import torch from torch.optim import Adam optimizer = Adam(net.parameters()) for step in range(1000): optimizer.zero_grad() x = torch.randn(1000,3) y = net(x) outputs = g(x,y) loss = -torch.sum(outputs) loss.backward() optimizer.step() if step % 100 == 0: print('第{}次迭代损失 = {}'.format(step,loss))
第0次迭代损失 = -533.194091796875 第100次迭代损失 = -1128.9976806640625 第200次迭代损失 = -1480.289794921875 第300次迭代损失 = -1731.8543701171875 第400次迭代损失 = -1867.0120849609375 第500次迭代损失 = -1623.46728515625 第600次迭代损失 = -1827.7152099609375 第700次迭代损失 = -1860.97216796875 第800次迭代损失 = -1743.3468017578125 第900次迭代损失 = -1622.2218017578125
代码在第三行构造了优化器 optimizer,这个优化器每次可以改良所有权重值,但是这个改良不是一步到位的
需要让优化器反复循环很多次【后面缩进的语句都是要循环的内容】 ———— 每次需要告诉优化器 每次改良的依据是什么
通过 optimizer.step() 完成权重的改良
完成后,就训练好了神经网络
(4)测试神经网络的性能
#生成测试数据 x_test = torch.randn(2,3) print('测试输入:{}'.format(x_test)) # 查看神经网络的计算结果 y_test = net(x_test) print ('人工神经网络计算结果: {}'.format(y_test)) print('g的值:{}'.format(g(x_test,y_test))) #根据理论,计算参考答案 def argmax_g(x): x0,x1,x2 = x[:,0] ** 0,x[:,1] ** 1,x[:,2] ** 2 return 0.5 * (x0 + x1 + x2)[:, None] yref_test = argmax_g(x_test) print('理论最优值:{}'.format(yref_test)) print('g的值:{}'.format(g(x_test,yref_test)))
测试输入:tensor([[ 0.1865, 1.4210, 1.1290], [-0.2137, 0.1621, 0.9952]]) 人工神经网络计算结果: tensor([[1.9692], [1.0804]], grad_fn=<AddmmBackward>) g的值:tensor([1.5885, 0.9977], grad_fn=<SubBackward0>) 理论最优值:tensor([[1.8479], [1.0762]]) g的值:tensor([1.6032, 0.9977])
可以断定,我们的神经网络 已经正确地 输出了最优结果
由于 验证代码的输入是随机确定的。所以每次运行的输入和输出都不一样
原文地址:https://www.cnblogs.com/expedition/p/11369239.html
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 定义一个函数,在该函数中可以实现任意两个整数的加法。java实现
- JDBC连接ORACLE的三种URL格式
- Centreon+Nagios实战第八篇——Nagios+Centreon添加监控服务
- Centreon+Nagios实战第九篇——利用nrpe插件监控本机
- 第四篇 CentOs7下安装Zabbix
- 第十三篇 zabbix创建Item
- 第十四篇 zabbix创建自定义Item
- 【性能】688- 前端性能优化——从 10 多秒到 1.05 秒
- 第十六篇 zabbix创建Trigger
- 如何通过程序(java代码)提高你的博客访问量
- zabbix_get [12429]: Check access restrictions in Zabbix agent configuration
- 【设计模式】689- TypeScript 设计模式之观察者模式
- Found a swap file by the name ".jsidInspector.py.swp"
- CentOs7下部署tomcat文件服务器
- 【拓展】未来的JavaScript记录与元组