基于PyTorch的CIFAR-10分类

时间:2021-07-14
本文章向大家介绍基于PyTorch的CIFAR-10分类,主要包括基于PyTorch的CIFAR-10分类使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

作者:如缕清风

本文为博主原创,未经允许,请勿转载:https://www.cnblogs.com/warren2123/articles/11823690.html


一、前言

        本文基于Facebook的PyTorch框架,通过对VGGNet模型实现,对CIFAR-10数据集进行分类。

        CIFAR-10数据集包含60000张 32x32的彩色图片,共分为10种类别,每种类别6000张。其中训练集包含50000张图片,测试机包含10000张图片。CIFAR-10的样本图如下所示。

 

 

二、基于PyTorch构建VGGNet模型

        PyTorch与TensorFlow最大的不同是运用动态图计算,并采用自动autograph的方法,大大方便了模型的构建。本文模型构建分为四个部分:数据读取及预处理、构建VGGNet模型、定义模型超参数以及评估方法、参数优化。

1、数据读取及预处理

        本文采用GPU对PyTorch进行速度提升,如果不存在GPU,则会自动选择CPU进行运算。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        数据集的读取利用PyTorch的torchvision库,可以选择将download参数改为True进行下载,由于本文已经下载好,所以定位为False。数据集提前采用正则化的方式进行预处理,分为训练集和测试集,并采用生成器的方式加载数据,便于更好的处理大批量数据。classes为CIFAR-10数据集的10个标签类别。

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10('./data', train=True, download=False, transform=transform)
testset = torchvision.datasets.CIFAR10('./data', train=False, download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=2)

classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

2、构建VGGNet模型

        VGGNet是通过对AlexNet的改进,进一步加深了卷积神经网络的深度,采用堆叠3 x 3的卷积层和2 x 2的降采样层,实现11到19层的网络深度。VGG的结构图如下所示。

        VGGNet模型总的来说,分为VGG16和VGG19两类,区别在于模型的层数不同,以下'M'参数代表池化层,数据代表各层滤波器的数量。

cfg = {
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}

        本文中定义VGGNet模型的为全连接层,卷积层中都运用批量归一化的方法,提升模型的训练速度与收敛效率,并且可以一定的代替dropout的作用,有利于模型的泛化效果。

class VGG(nn.Module):
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 10)
    
    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out
    
    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

3、定义模型超参数以及评估方法

        模型的学习率、训练次数、批次大小通过超参数的方式设定,优化函数采用Adam,损失函数采用交叉熵进行计算。

LR = 0.001
EPOCHES = 20
BATCHSIZE = 100

net4 = VGG('VGG16')
mlps = [net4.to(device)]
optimizer = torch.optim.Adam([{"params": mlp.parameters()} for mlp in mlps], lr=LR)
loss_function = nn.CrossEntropyLoss()

4、参数优化

        以下是通过定义的训练次数进行模型的参数优化过程,每一次训练输出模型的测试正确率。

for ep in range(EPOCHES):
    for img, label in trainloader:
        img, label = img.to(device), label.to(device)
        optimizer.zero_grad()
        for mlp in mlps:
            mlp.train()
            out = mlp(img)
            loss = loss_function(out, label)
            loss.backward()
        optimizer.step()
    
    pre = []
    vote_correct = 0
    mlps_correct = [0 for i in range(len(mlps))]
    for img, label in testloader:
        img, label = img.to(device), label.to(device)
        for i, mlp in enumerate(mlps):
            mlp.eval()
            out = mlp(img)
            _, prediction = torch.max(out, 1)
            pre_num = prediction.cpu().numpy()
            mlps_correct[i] += (pre_num == label.cpu().numpy()).sum()
            pre.append(pre_num)
        arr = np.array(pre)
        pre.clear()
        result = [Counter(arr[:, i]).most_common(1)[0][0] for i in range(BATCHSIZE)]
        vote_correct += (result == label.cpu().numpy()).sum()
    
    for idx, correct in enumerate(mlps_correct):
        print("Epoch:" + str(ep) + "VGG的正确率为:" + str(correct/len(testloader)))

        训练输出如下所示:

Epoch:0 VGG的正确率为:57.67
Epoch:1 VGG的正确率为:67.13
Epoch:2 VGG的正确率为:74.84
Epoch:3 VGG的正确率为:79.59
Epoch:4 VGG的正确率为:79.93
Epoch:5 VGG的正确率为:82.61
Epoch:6 VGG的正确率为:82.96
Epoch:7 VGG的正确率为:84.31
Epoch:8 VGG的正确率为:82.43
Epoch:9 VGG的正确率为:85.12
Epoch:10 VGG的正确率为:84.33
Epoch:11 VGG的正确率为:83.66
Epoch:12 VGG的正确率为:82.02
Epoch:13 VGG的正确率为:85.44
Epoch:14 VGG的正确率为:84.08
Epoch:15 VGG的正确率为:85.67
Epoch:16 VGG的正确率为:84.87
Epoch:17 VGG的正确率为:85.21
Epoch:18 VGG的正确率为:84.62
Epoch:19 VGG的正确率为:85.88
Epoch:20 VGG的正确率为:83.46
Epoch:21 VGG的正确率为:86.63
Epoch:22 VGG的正确率为:85.75
Epoch:23 VGG的正确率为:86.29
Epoch:24 VGG的正确率为:83.33
Epoch:25 VGG的正确率为:86.48
Epoch:26 VGG的正确率为:85.6
Epoch:27 VGG的正确率为:86.66
Epoch:28 VGG的正确率为:85.45
Epoch:29 VGG的正确率为:85.65
Epoch:30 VGG的正确率为:86.36
Epoch:31 VGG的正确率为:86.27
Epoch:32 VGG的正确率为:85.09
Epoch:33 VGG的正确率为:85.6
Epoch:34 VGG的正确率为:86.82
Epoch:35 VGG的正确率为:85.76
Epoch:36 VGG的正确率为:86.59
Epoch:37 VGG的正确率为:85.56
Epoch:38 VGG的正确率为:85.71
Epoch:39 VGG的正确率为:86.07
Epoch:40 VGG的正确率为:84.87
Epoch:41 VGG的正确率为:85.91
Epoch:42 VGG的正确率为:86.8
Epoch:43 VGG的正确率为:87.43
Epoch:44 VGG的正确率为:85.99
Epoch:45 VGG的正确率为:86.32
Epoch:46 VGG的正确率为:86.72
Epoch:47 VGG的正确率为:86.39
Epoch:48 VGG的正确率为:86.08
Epoch:49 VGG的正确率为:86.97

 

三、总结

        本文基于PyTorch构建的VGG模型,在CIFAR-10中分类效果达到86.97%,最高达到87.43%的分类准确率,当然后续可以进一步调整超参数优化模型,也可以运用多模型架构。通过细分各类别的准确率,可以看出模型在dog类别准确率较低,在truck类别准确率较高。

Accuracy of   airplane : 90 %
Accuracy of automobile : 90 %
Accuracy of       bird : 90 %
Accuracy of        cat : 82 %
Accuracy of       deer : 88 %
Accuracy of        dog : 71 %
Accuracy of       frog : 93 %
Accuracy of      horse : 85 %
Accuracy of       ship : 81 %
Accuracy of      truck : 95 %

        基于PyTorch的构建,能够从中体会到Python之禅的哲学,简洁、方便等。相信这将是深度学习的一大助力,当然这也是因人而异。

原文地址:https://www.cnblogs.com/warren2123/p/15009431.html