pytorch实现MNIST手写体识别(全连接神经网络)
时间:2019-08-14
本文章向大家介绍pytorch实现MNIST手写体识别(全连接神经网络),主要包括pytorch实现MNIST手写体识别(全连接神经网络)使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
环境: pytorch1.1 cuda9.0 ubuntu16.04
该网络有3层,第一层input layer,有784个神经元(MNIST数据集是28*28的单通道图片,故有784个神经元)。第二层为hidden_layer,设置为500个神经元。最后一层是输出层,有10个神经元(10分类任务)。在第二层之后还有个ReLU函数,进行非线性变换。
#!/usr/bin/env python # encoding: utf-8 ''' @author: liualex @contact: liualex1109@163.com @software: pycharm @file: main.py @time: 2019/8/14 21:20 @desc: ''' import torch import torchvision import torchvision.transforms as transforms import torch.utils.data.dataloader as dataloader import torch.nn as nn import torch.optim as optim import os os.environ["CUDA_VISIBLE_DEVICES"] = "3" train_set = torchvision.datasets.MNIST( root="./data", train=True, transform=transforms.ToTensor(), download=True ) train_loader = dataloader.DataLoader( dataset=train_set, batch_size=100, shuffle=False, ) test_set = torchvision.datasets.MNIST( root="./data", train=False, transform=transforms.ToTensor(), download=True ) test_loader = dataloader.DataLoader( dataset=train_set, batch_size=100, shuffle=False, ) class NeuralNet(nn.Module): def __init__(self, input_num, hidden_num, output_num): super(NeuralNet, self).__init__() self.fc1 = nn.Linear(input_num, hidden_num) self.fc2 = nn.Linear(hidden_num, output_num) self.relu = nn.ReLU() def forward(self,x): x = self.fc1(x) x = self.relu(x) y = self.fc2(x) return y epoches = 20 lr = 0.001 input_num = 784 hidden_num = 500 output_num = 10 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = NeuralNet(input_num, hidden_num, output_num) model.to(device) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=lr) for epoch in range(epoches): for i, data in enumerate(train_loader): (images, labels) = data images = images.reshape(-1, 28*28).to(device) labels = labels.to(device) output = model(images) loss = criterion(output, labels) optimizer.zero_grad() loss.backward() optimizer.step() if (i+1) % 100 == 0: print('Epoch [{}/{}], Loss: {:.4f}' .format(epoch + 1, epoches, loss.item())) with torch.no_grad(): correct = 0 total = 0 for images, labels in test_loader: images = images.reshape(-1, 28*28).to(device) labels = labels.to(device) output = model(images) _, predicted = torch.max(output, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))
结果:
原文地址:https://www.cnblogs.com/liualexsone/p/11355217.html
- .NET Core采用的全新配置系统[3]: “Options模式”下的配置是如何绑定为Options对象
- 游戏用户中心开发
- .NET Core采用的全新配置系统[4]: “Options模式”下各种类型的Options对象是如何绑定的?
- js运算符优先级笔记
- 通过协同绘制用GAN合成高分辨率无尽道路
- ASP.NET MVC的Model元数据与Model模板:预定义模板
- 为您的组织选择正确的企业云解决方案
- 搞定这些疑难杂症,向css3动画说yes
- 前十一个网络游戏业务收入1341亿 同比增22.1%
- ASP.NET MVC Model元数据及其定制:一个重要的接口IMetadataAware
- 使用Docker 1.12.x构建多容器Web应用程序
- 基于 vue2 + vuex 构建一个具有 45 个页面的大型单页面应用
- 深度解剖dubbo源码
- .NET Core采用的全新配置系统[6]: 深入了解三种针对文件(JSON、XML与INI)的配置源
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法