使用VGG模型自定义图像分类任务

时间:2022-06-22
本文章向大家介绍使用VGG模型自定义图像分类任务,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

前言

网上关于VGG模型的文章有很多,有介绍算法本身的,也有代码实现,但是很多代码只给出了模型的结构实现,并不包含数据准备的部分,这让人很难愉快的将代码迁移自己的任务中。为此,这篇博客接下来围绕着如何使用VGG实现自己的图像分类任务,从数据准备到实验验证。代码基于Python与TensorFlow实现,模型结构采用VGG-16,并且将很少的出现算法和理论相关的东西。

数据准备

下载数据和转换代码

大多数人自己的训练数据,一般都是传统的图片形式,如.jpg,.png等等,而图像分类任务的话,这些图片的天然组织形式就是一个类别放在一个文件夹里,那么有啥大众化的数据集是这样的组织形式呢?TensorFlow的FlowersData,它下载下来是这个样子:

一共有五类,每一类中都有几百张图,我们把这些数据组织成TFrecord形式,对应的博客在这里,源码的github在这里,FlowersData数据集在这里。 有上面这三个东西之后,就可以生成TFrecord文件了。

组织图片数据

首先将FlowersData文件夹下的数据分成两个部分,训练数据和测试数据,我把原文件五个类别中都拿出大概100张图左右,数据的构成和路径如下:

生成训练TFrecord

#图片路径
cwd = 'F:\flowersdata\trainimages\'
#文件路径
filepath = 'F:\flowersdata\tfrecord\train\'
classes=['daisy',
         'dandelion',
         'roses',
         'sunflowers',
         'tulips']
#tfrecords格式文件名
ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)
#tfrecords格式文件名
ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum)

生成效果:

生成预测TFrecord

#图片路径
cwd = 'F:\flowersdata\testimages\'
#文件路径
filepath = 'F:\flowersdata\tfrecord\test\'
classes=['daisy',
         'dandelion',
         'roses',
         'sunflowers',
         'tulips']
#tfrecords格式文件名
ftrecordfilename = ("testdata.tfrecords-%.3d" % recordfilenum)
#tfrecords格式文件名
ftrecordfilename = ("testdata.tfrecords-%.3d" % recordfilenum)

生成效果:

训练模型

初始权重与源码下载

VGG-16的初始权重我上传到了百度云,在这里下载; VGG-16源码我上传到了github,在这里下载;

在源码中: train_and_val.py文件是最终要执行的文件,它定了训练和预测的过程; input_data.py是将上一步中生成的TFRecord文件组织成batch的过程; VGG.py定义了VGG-16的网络结构; tool.py是最底层,定义了一些卷积池化等操作。

训练模型

train_and_val.py文件修改:

if __name__=="__main__":
    train()
    #evaluate()

根据自己的路径修改:

#初始权重路径
pre_trained_weights = 'vgg16_pretrain/vgg16.npy'
#训练数据路径
train_data_dir = 'F:\flowersdata\tfrecord\train\traindata.tfrecords*'
    test_data_dir = 
#预测数据路径
'F:\flowersdata\tfrecord\test\testdata.tfrecords*'
#训练生成文件路径
train_log_dir = 'logs/train/'
#预测生成文件路径
val_log_dir = 'logs/val/'

根据自己的显存容量修改:

IMG_W = 224
IMG_H = 224
BATCH_SIZE = 8

训练过程每50个step打印loss; 每200个step计算一个batch中的准确率; 每1000个step保存一次权重。

预测

train_and_val.py文件修改:

if __name__=="__main__":
    #train()
    evaluate()
#训练过程中生成的权重
log_dir = 'logs/train/'
#预测数据集路径
test_data_dir = 'F:\flowersdata\tfrecord\test\testdata.tfrecords*'
#用于生成tf文件的图片数量
n_test = 502

打印测试样本总数; 打印正确预测的样本总数; 打印top_1。