【教程】利用Tensorflow目标检测API确定图像中目标的位置

时间:2022-05-03
本文章向大家介绍【教程】利用Tensorflow目标检测API确定图像中目标的位置,主要内容包括准备数据集、准备模型、训练、测试、结语、基本概念、基础应用、原理机制和需要注意的事项等,并结合实例形式分析了其使用技巧,希望通过本文能帮助到大家理解应用这部分内容。

深度学习提供了另一种解决“Wally在哪儿”(美国漫画)问题的方法。与传统的图像处理计算机视觉方法不同的是,它只使用了少量的标记出Wally位置的示例。

在我的Github repo上发布了具有评估图像和检测脚本的最终训练模型。

  • Github repo地址:https://github.com/tadejmagajna/HereIsWally

这篇文章描述了使用Tensorflow目标检测API来训练神经网络的过程,并使用围绕它构建的Python脚本来寻找Wally。它由以下步骤组成:

  • 通过创建一组标记训练图像来准备数据集,其中标签代表图像中Wally的xy位置;
  • 读取和配置模型以使用Tensorflow目标检测API;
  • 在我们的数据集上训练模型
  • 使用导出的图形对评估图像的模型进行测试

开始之前,请确保按照说明安装Tensorflow目标检测API。

准备数据集

神经网络是深度学习的过程中最值得注意的过程,但遗憾的是,科学家们花费大量时间的准备和格式化训练数据。

最简单的机器学习问题的目标值通常是标量(比如数字检测器)或分类字符串。Tensorflow目标检测API训练数据使用两者的结合。它包括一组图像,并附有特定目标的标签和它们在图像中出现的位置。位置用两点(二维空间)定义,两点足够画一个物体周围的包围盒。

因此,为了创建训练集,我们需要提出一组Wally出现地点的图片。

虽然我可以用LabelImg这样的注释工具,花费数周的时间通过手工标记图像来解决问题,但我发现了一个已经解决了Where’s Wally这个问题的训练集。

Wally训练数据集,最后四列描述了Wally出现在图像中的位置

准备数据集的最后一步是将我们的标签(保存为文本文件)和图像(.jpeg)打包成一个二进制.tfrecord文件(该过程的解释代码地址见段末),但可以找到训练和重新运算求出Wally位置的参数内容。 .tfecord文件在我的Github repo上。

  • Github repo地址:https://github.com/tadejmagajna/HereIsWally
  • 解释地址:http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/

准备模型

Tensorflow目标检测API提供了一组经过多次公开数据集训练的具有不同性能(通常为速度 – 精度折衷)的预训练模型。

虽然模型可以从头开始随机初始化网络权值,但这个过程可能需要几周的时间。我们使用一种称为转移学习的方法来替换该过程。

转移学习包含采用通常训练的模型解决一些一般问题并且重新训练模型以解决我们的问题。转移学习的工作原理是,通过使用在预先训练的模型中获得的知识并将其转移到新的模型中,来代替从头开始训练模型这些无用的重复工作。这为我们节省了大量的时间,将花费在训练上的时间用于获得针对我们问题的知识。

我们使用带有经过COCO数据集训练的Inception v2模型的RCNN,以及它的管道配置文件。该模型包含一个检查点.ckpt文件,我们可以使用该文件开始训练。

  • RCNN地址: http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2017_11_08.tar.gz
  • 管道配置文件地址: https://github.com/tensorflow/models/blob/master/research/object_detection/samples/configs/ssd_inception_v2_coco.config

下载配置文件后,请确保用指向检查点文件、训练以及评估.tfrecord文件与标签映射文件的路径代替“PATH_TO_BE_CONFIGURED”字段。

需要配置的最终文件是labels.txt映射文件,其中包含所有不同目标的标签。由于我们只是在寻找一种类型的目标,我们的标签文件看起来像这样:

item {
  id: 1
  name: 'waldo'
}

最后,我们最终应该:

  • 具有.ckpt检查点文件的预训练模型;
  • 训练和评估.tfrecord数据集;
  • 标记映射文件;
  • 指向以上文件的管道配置文件。

现在,我们准备开始训练。

训练

Tensorflow目标检测API提供了一个简单易用的Python脚本来重新训练我们的模型。它位于models / research / object_detection中,可以利用下列路径运行:

python train.py –logtostderr –pipeline_config_path= PATH_TO_PIPELINE_CONFIG –train_dir=PATH_TO_TRAIN_DIR

其中PATH_TO_PIPELINE_CONFIG是到管道配置文件的路径,PATH_TO_TRAIN_DIR是一个新创建的目录,我们的新检查点和模型将被存储在该目录中。

train.py的输出应该如下所示:

用最重要的信息来查找损失。这是在训练或验证集中每个示例错误的总和。当然,你希望它尽可能低,这意味着,缓慢下降表示你的模型正在学习(或过度拟合你的训练数据)。你还可以使用Tensorboard来更详细地显示训练数据。

该脚本将在一定数量的步骤后自动存储检查点文件,以便你随时恢复保存的检查点,以防计算机在学习过程中崩溃。

这意味着当你想结束模型的训练时,你可以终止脚本。

但是什么时候停止学习?关于何时停止训练,原则上是当评估集的损失减少或非常低时(在我们的例子中低于0.01)。

测试

现在我们可以通过在一些示例图像上进行测试来实际使用我们的模型。

首先,我们需要使用models/research/object_detection脚本中存储的检查点(位于我们的训练目录中)导出推理图:

python export_inference_graph.py — pipeline_config_path PATH_TO_PIPELINE_CONFIG --trained_checkpoint_prefix PATH_TO_CHECPOINT --output_directory OUTPUT_PATH

我们的Python脚本可以用导出的推理图来查找Wally的位置。

我写了一些简单的Python脚本(基于Tensorflow 目标检测API),你可以在模型上使用它们执行目标检测,并在检测到的目标周围绘制框或将其暴露。

find_wally.py和find_wally_pretty.py都可以在我的Github仓库中找到,可以简单地运行:

  • find_wally.py地址: https://github.com/tadejmagajna/HereIsWally/blob/master/find_wally.py
  • find_wally_pretty.py地址: https://github.com/tadejmagajna/HereIsWally/blob/master/find_wally_pretty.py
  • Github repo 地址: https://github.com/tadejmagajna/HereIsWally
python find_wally.py

或者

python find_wally_pretty.py

在自己的模型或自己的评估图像上使用脚本时,请确保修改model_path和image_path变量。

结语

在我的Github repo 上发布的模型表现非常出色。

模型设法在评估图像中找到Wally,并且对网络上的一些额外的随机例子处理得很好。它未能找到很大的Wally,直观来说,找到小的walley应该更容易解决。这表明我们的模型可能过度适合我们的训练数据,主要是因为训练图像较少。