基于PyTorch的CIFAR-10分类
作者:如缕清风
本文为博主原创,未经允许,请勿转载:https://www.cnblogs.com/warren2123/articles/11823690.html
一、前言
本文基于Facebook的PyTorch框架,通过对VGGNet模型实现,对CIFAR-10数据集进行分类。
CIFAR-10数据集包含60000张 32x32的彩色图片,共分为10种类别,每种类别6000张。其中训练集包含50000张图片,测试机包含10000张图片。CIFAR-10的样本图如下所示。
二、基于PyTorch构建VGGNet模型
PyTorch与TensorFlow最大的不同是运用动态图计算,并采用自动autograph的方法,大大方便了模型的构建。本文模型构建分为四个部分:数据读取及预处理、构建VGGNet模型、定义模型超参数以及评估方法、参数优化。
1、数据读取及预处理
本文采用GPU对PyTorch进行速度提升,如果不存在GPU,则会自动选择CPU进行运算。
import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torchvision import torchvision.transforms as transforms device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
数据集的读取利用PyTorch的torchvision库,可以选择将download参数改为True进行下载,由于本文已经下载好,所以定位为False。数据集提前采用正则化的方式进行预处理,分为训练集和测试集,并采用生成器的方式加载数据,便于更好的处理大批量数据。classes为CIFAR-10数据集的10个标签类别。
transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) trainset = torchvision.datasets.CIFAR10('./data', train=True, download=False, transform=transform) testset = torchvision.datasets.CIFAR10('./data', train=False, download=False, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=2) classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
2、构建VGGNet模型
VGGNet是通过对AlexNet的改进,进一步加深了卷积神经网络的深度,采用堆叠3 x 3的卷积层和2 x 2的降采样层,实现11到19层的网络深度。VGG的结构图如下所示。
VGGNet模型总的来说,分为VGG16和VGG19两类,区别在于模型的层数不同,以下'M'参数代表池化层,数据代表各层滤波器的数量。
cfg = { 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'] }
本文中定义VGGNet模型的为全连接层,卷积层中都运用批量归一化的方法,提升模型的训练速度与收敛效率,并且可以一定的代替dropout的作用,有利于模型的泛化效果。
class VGG(nn.Module): def __init__(self, vgg_name): super(VGG, self).__init__() self.features = self._make_layers(cfg[vgg_name]) self.classifier = nn.Linear(512, 10) def forward(self, x): out = self.features(x) out = out.view(out.size(0), -1) out = self.classifier(out) return out def _make_layers(self, cfg): layers = [] in_channels = 3 for x in cfg: if x == 'M': layers += [nn.MaxPool2d(kernel_size=2, stride=2)] else: layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), nn.BatchNorm2d(x), nn.ReLU(inplace=True)] in_channels = x layers += [nn.AvgPool2d(kernel_size=1, stride=1)] return nn.Sequential(*layers)
3、定义模型超参数以及评估方法
模型的学习率、训练次数、批次大小通过超参数的方式设定,优化函数采用Adam,损失函数采用交叉熵进行计算。
LR = 0.001 EPOCHES = 20 BATCHSIZE = 100 net4 = VGG('VGG16') mlps = [net4.to(device)] optimizer = torch.optim.Adam([{"params": mlp.parameters()} for mlp in mlps], lr=LR) loss_function = nn.CrossEntropyLoss()
4、参数优化
以下是通过定义的训练次数进行模型的参数优化过程,每一次训练输出模型的测试正确率。
for ep in range(EPOCHES): for img, label in trainloader: img, label = img.to(device), label.to(device) optimizer.zero_grad() for mlp in mlps: mlp.train() out = mlp(img) loss = loss_function(out, label) loss.backward() optimizer.step() pre = [] vote_correct = 0 mlps_correct = [0 for i in range(len(mlps))] for img, label in testloader: img, label = img.to(device), label.to(device) for i, mlp in enumerate(mlps): mlp.eval() out = mlp(img) _, prediction = torch.max(out, 1) pre_num = prediction.cpu().numpy() mlps_correct[i] += (pre_num == label.cpu().numpy()).sum() pre.append(pre_num) arr = np.array(pre) pre.clear() result = [Counter(arr[:, i]).most_common(1)[0][0] for i in range(BATCHSIZE)] vote_correct += (result == label.cpu().numpy()).sum() for idx, correct in enumerate(mlps_correct): print("Epoch:" + str(ep) + "VGG的正确率为:" + str(correct/len(testloader)))
训练输出如下所示:
Epoch:0 VGG的正确率为:57.67 Epoch:1 VGG的正确率为:67.13 Epoch:2 VGG的正确率为:74.84 Epoch:3 VGG的正确率为:79.59 Epoch:4 VGG的正确率为:79.93 Epoch:5 VGG的正确率为:82.61 Epoch:6 VGG的正确率为:82.96 Epoch:7 VGG的正确率为:84.31 Epoch:8 VGG的正确率为:82.43 Epoch:9 VGG的正确率为:85.12 Epoch:10 VGG的正确率为:84.33 Epoch:11 VGG的正确率为:83.66 Epoch:12 VGG的正确率为:82.02 Epoch:13 VGG的正确率为:85.44 Epoch:14 VGG的正确率为:84.08 Epoch:15 VGG的正确率为:85.67 Epoch:16 VGG的正确率为:84.87 Epoch:17 VGG的正确率为:85.21 Epoch:18 VGG的正确率为:84.62 Epoch:19 VGG的正确率为:85.88 Epoch:20 VGG的正确率为:83.46 Epoch:21 VGG的正确率为:86.63 Epoch:22 VGG的正确率为:85.75 Epoch:23 VGG的正确率为:86.29 Epoch:24 VGG的正确率为:83.33 Epoch:25 VGG的正确率为:86.48 Epoch:26 VGG的正确率为:85.6 Epoch:27 VGG的正确率为:86.66 Epoch:28 VGG的正确率为:85.45 Epoch:29 VGG的正确率为:85.65 Epoch:30 VGG的正确率为:86.36 Epoch:31 VGG的正确率为:86.27 Epoch:32 VGG的正确率为:85.09 Epoch:33 VGG的正确率为:85.6 Epoch:34 VGG的正确率为:86.82 Epoch:35 VGG的正确率为:85.76 Epoch:36 VGG的正确率为:86.59 Epoch:37 VGG的正确率为:85.56 Epoch:38 VGG的正确率为:85.71 Epoch:39 VGG的正确率为:86.07 Epoch:40 VGG的正确率为:84.87 Epoch:41 VGG的正确率为:85.91 Epoch:42 VGG的正确率为:86.8 Epoch:43 VGG的正确率为:87.43 Epoch:44 VGG的正确率为:85.99 Epoch:45 VGG的正确率为:86.32 Epoch:46 VGG的正确率为:86.72 Epoch:47 VGG的正确率为:86.39 Epoch:48 VGG的正确率为:86.08 Epoch:49 VGG的正确率为:86.97
三、总结
本文基于PyTorch构建的VGG模型,在CIFAR-10中分类效果达到86.97%,最高达到87.43%的分类准确率,当然后续可以进一步调整超参数优化模型,也可以运用多模型架构。通过细分各类别的准确率,可以看出模型在dog类别准确率较低,在truck类别准确率较高。
Accuracy of airplane : 90 % Accuracy of automobile : 90 % Accuracy of bird : 90 % Accuracy of cat : 82 % Accuracy of deer : 88 % Accuracy of dog : 71 % Accuracy of frog : 93 % Accuracy of horse : 85 % Accuracy of ship : 81 % Accuracy of truck : 95 %
基于PyTorch的构建,能够从中体会到Python之禅的哲学,简洁、方便等。相信这将是深度学习的一大助力,当然这也是因人而异。
原文地址:https://www.cnblogs.com/warren2123/p/15009431.html
- win7怎么去除快捷方式的小箭头
- 零基础学编程015:画些有趣的图案
- Spring boot with Thymeleaf
- 零基础学编程014:小海龟做画
- Springboot @RequestBody 传递 List
- 零基础学编程013:import让你飞起来
- 【教程】利用Tensorflow目标检测API确定图像中目标的位置
- 零基础学编程012:画出复利曲线图
- OpenAI发布高度优化的GPU计算内核—块稀疏GPU内核
- SQL SERVER 原来还可以这样玩 FOR XML PATH
- 零基础学编程011:复利数据表问题(总结)
- 一个小程序引发的思考
- 深入内核:DUMP Block的数据读取与脏数据写入影响
- 零基础学编程010:最终可以输出完整的复利数据表了
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 在Ubuntu中实现人脸识别登录的完整步骤
- Linux下如何寻找相同文件的方法
- CentOS 7中Nginx日志定时拆分实现过程详解
- 浅谈linux模拟多线程崩溃和多进程崩溃
- Linux下MongoDB的安装和配置教程
- Linux配置实现免密钥登录过程解析
- 可以提高效率的十个Linux命令别名汇总
- 基于linux命令提取文件夹内特定文件路径
- Ubuntu20.04修改ip地址的方法示例
- Linux 逻辑卷管理(LVM)使用方法总结
- Linux 下载安装VSCode 使用编程输出当前时间的方法
- 详解Linux获取线程的PID(TID、LWP)的几种方式
- Linux文件基本属性知识点总结
- Linux MySQL忘记root密码解决方案
- 如何使用iostat查看linux硬盘IO性能