MNIST手写数字集的多分类问题(Linear Layer)
时间:2021-09-06
本文章向大家介绍MNIST手写数字集的多分类问题(Linear Layer),主要包括MNIST手写数字集的多分类问题(Linear Layer)使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
import torch # 引入模块PyTorch
from torchvision import transforms # 从torch视觉中引入转换函数
from torchvision import datasets # 导入数据库
from torch.utils.data import DataLoader # 导入数据加载器
import torch.nn.functional as F # 导入激活函数
import torch.optim as optim # 导入优化器
# 使用的数字分类数据集,只有一个灰度通道0-255
batch_size = 64
transform = transforms.Compose([
transforms.ToTensor(), # 将图像的通道从[0,255]映射到[0,1]
transforms.Normalize((0.1307,), (0.3081,)) # 根据手写数字集的平均值和标准差来实现归一化
]) # 定义mnist数据集转化管道
train_dataset = datasets.MNIST(root='./dataset/mnist/', # 设置数据集存放的根目录
train=True, # 选择是否是训练集
download=True, # 是否从网上下载
transform=transform) # 使用可选参数trannsform来使用前面定义的转换
train_loader = DataLoader(dataset=train_dataset, # 使用数据加载器来读入训练集数据
shuffle=True, # 使用可选参数shuffle=True来打乱训练集
batch_size=batch_size) # 使用可选参数batch_size来确定每批的数据数量
test_dataset = datasets.MNIST(root='./dataset/mnist/',
train=False, # train=False 在这里代表选用了测试集
download=True,
transform=transform)
test_loader = DataLoader(dataset=test_dataset,
shuffle=False, # 测试机数据不用打乱
batch_size=batch_size) # 每批数据数量同前面一样
class Net(torch.nn.Module): # 创建网络类
def __init__(self): # 初始化类
super(Net, self).__init__() # 继承父类
self.l1 = torch.nn.Linear(784, 512) # 实现线性层
self.l2 = torch.nn.Linear(512, 256)
self.l3 = torch.nn.Linear(256, 128)
self.l4 = torch.nn.Linear(128, 64)
self.l5 = torch.nn.Linear(64, 10)
def forward(self, x):
x = x.view(-1, 784) # 将输入的shape=(28,28)的张量展平为(batch_size,28*28)
x = F.gelu(self.l1(x)) # 将输入通过线性层后再经由激活函数gelu进入到下一线性层
x = F.gelu(self.l2(x))
x = F.gelu(self.l3(x))
x = F.gelu(self.l4(x))
return self.l5(x)
model = Net() # 将神经网络模型实例化给model
criterion = torch.nn.CrossEntropyLoss() # 定义多分类损失函数
# 使用SGD(Stochastic Gradient for Tensor Decomposition)随机梯度下降优化器来修正模型的参数
# 并定义了学习率和动量来加速网络的收敛
optimizer = torch.optim.SGD(model.parameters(), lr=0.01,
momentum=.5)
def train(epoch): # 定义训练函数
running_loss = 0 # 初始化损失为0
for batch_idx, (x, y) in enumerate(train_loader): # 从可迭代对象train_loader中获取获取批数,还有输入和输出
inputs, target = x, y # 将x,y分别给输入和目标值
optimizer.zero_grad() # 初始化将梯度置0
outputs = model(inputs) # 输入经由model的正向传播得到输入
loss = criterion(outputs, target) # 通过criterion函数来计算损失
loss.backward() # 将损失函数进行反向传播计算梯度
optimizer.step() # 根据梯度来更新参数
running_loss += loss # 计算每300批的损失值
if batch_idx % 300 == 299:
print('[%d,%5d] loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 300)) # 输出300批损失值的平均值
running_loss = 0 # 将损失值置0
def test(): # 定义测试集
corrent = 0 # 定义正确率
total = 0 # 定义总数
with torch.no_grad(): # 因为是测试集,不用计算跟踪梯度
for data in test_loader:
images, labels = data
outputs = model(images)
_, prediction = torch.max(outputs, dim=1) # torch.max()函数按维度dim返回最大值的那个元素和索引
total += labels.size(0) # labels的size(0)就是每批数据的数目
corrent += (prediction == labels).sum().item() # 用预测正确的数目的和除以总数目来获得正确率
print('Accuracy on test set %d %%' % (100 * corrent / total))
if __name__ == '__main__':
for epoch in range(10): # 训练十轮
train(epoch)
test()
运行结果:
原文地址:https://www.cnblogs.com/Reion/p/15235991.html
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- laravel实现一个上传图片的接口,并建立软链接,访问图片的方法
- Laravel中validation验证 返回中文提示 全局设置的方法
- laravel5表单唯一验证的实例代码
- 实现laravel 插入操作日志到数据库的方法
- laravel validate 设置为中文的例子(验证提示为中文)
- Laravel 使用查询构造器配合原生sql语句查询的例子
- php面试实现反射注入的详细方法
- laravel框架 api自定义全局异常处理方法
- laravel实现于语言包的完美切换方法
- PHP校验15位和18位身份证号的类封装
- 用Echarts打造一个轮播图!
- Laravel5.5 实现后台管理登录的方法(自定义用户表登录)
- PHP 获取客户端 IP 地址的办法实例代码
- laravel http 自定义公共验证和响应的方法
- Windows服务器中PHP如何安装redis扩展