C-SATS工程副总裁教你如何用TensorFlow分类图像 part1

时间:2022-05-03
本文章向大家介绍C-SATS工程副总裁教你如何用TensorFlow分类图像 part1,主要内容包括运行原理、训练和分类、安装配置、训练模型、分类、基本概念、基础应用、原理机制和需要注意的事项等,并结合实例形式分析了其使用技巧,希望通过本文能帮助到大家理解应用这部分内容。

最近在深度学习算法和硬件性能方面的最新进展使研究人员和公司在图像识别,语音识别,推荐引擎和机器翻译等领域取得了巨大的进步。六年前,首次机器在视觉模式识别方面的表现首次超过人类。两年前,Google Brain团队发布了TensorFlow,让深度学习可以应用于大众。TensorFlow超越了许多用于深度学习的复杂工具。

有了TensorFlow,你可以访问具有强大功能的复杂特征。它之所以如此强大,是因为TensorFlow的易用性非常好。

本文由两部分组成,我将介绍如何快速创建用于实际图像识别的卷积神经网络。计算步骤是Embarrassingly parallel和可部署执行逐帧视频分析和temporal-aware视频分析。

这个系列直接讲解最重要的地方。对于命令行和Python的基本理解需要你自己研究。写这篇文章的目的是让大家可以快速入门,并激励大家创建自己的项目。

运行原理

我们将按照以下步骤操作:

1. 标记是管理训练数据的过程。对于花卉,将雏菊的图像拖入“雏菊”目录,将玫瑰拖入“玫瑰”目录等等,以便根据需要选择许多不同的花朵。如果我们不去标记“蕨类植物”,分类器也永远不会返回“蕨类植物”。每个类型都需要大量的例子,所以这是一个重要的但很耗时的过程。为了省时,在这里我们使用预先标记好的数据。

2. 训练是将标记后的数据(图像)输入到模型中。工具将抓取一组随机图像,使用模型来猜测每种花的类型,测试猜测的准确性,并重复此过程,直到大部分训练数据被使用。最后一部分未过使用的图像用于计算训练模型的准确性。

3. 分类是使用模型分类新的图像。例如,输入:IMG207.JPG,输出:雏菊。这是最快,最简单的一步。

训练和分类

在本教程中,我们将训练图像分类器来识别不同类型的花朵。深度学习需要大量的训练数据,所以我们需要大量的分类好的花卉图像。值得庆幸的是,我有现成的,所以我会使用带有很好脚本的分类后的数据集,并使用一个现有的、经过完全训练的图像分类模型,并重新训练模型的最后几层。这种技术被称为迁移学习。

我们正在进行再培训的模型被称为Inception v3,它的介绍论文如下。

  • 介绍论文:https://arxiv.org/abs/1512.00567

从不知道如何从雏菊中分辨出郁金香到训练后可以成功分辨,大约需要20分钟。这就是深度学习的“学习”部分。

安装配置

首先在你选择的平台上安装Docker。

  • https://www.docker.com/community-edition#/download

docker是唯一一个依赖项。在许多TensorFlow教程中也用到了docker(这应该表明这是一个合理的方法)。我也更喜欢这种安装TensorFlow的方法,因为它通过不需要安装一堆依赖项,可以保持主机(笔记本电脑或桌面)的整洁。

安装Docker后,我们准备启动一个TensorFlow容器(container)进行训练和分类。创建一个工作目录在你的硬盘上准备2GB的空闲空间。创建一个名为local的子目录并记录访问这个目录的完整路径。

docker run -v /path/to/local:/notebooks/local --rm -it --name tensorflow 
tensorflow/tensorflow:nightly /bin/bash

以下是这个命令详细解释。

  • -v /path/to/local:/notebooks/local加载你刚刚创建的local目录到容器中合适的位置。如果你使用RHEL,Fedora或其他支持SELinux的系统,附加:Z到允许容器访问目录。(https://www.projectatomic.io/blog/2015/06/using-volumes-with-docker-can-cause-problems-with-selinux/)
  • –rm 告诉Docker在完成后删除容器。
  • -it 附加我们的输入和输出以使容器有交互性。
  • –name tensorflow将我们的容器命名为tensorflow
  • tensorflow/tensorflow:表示从Docker Hub(公共镜像库)的tensorflow/tensorflow中运行nightly而不是最新的镜像(默认是运行最新的)。之所以不用最新的,是因为在撰写本文时最新的包含了破坏TensorBoard的bug。而我们稍后要用TensorBoard进行可视化。
  • /bin/bash表示不运行默认命令;而是运行一个Bash shell。

训练模型

在容器内部,运行这些命令下载并检查训练数据。

curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
echo 'db6b71d5d3afff90302ee17fd1fefc11d57f243f  flower_photos.tgz' | sha1sum -c

如果你没有看到消息flower_photos.tgz: OK,则表示没有正确的文件。如果上述curl或sha1sum步骤失败,请手动下载并分解主机local目录中的训练数据tarball(SHA-1 checksum: db6b71d5d3afff90302ee17fd1fefc11d57f243f)。

现在把训练数据放在适当的地方,然后下载和理智检查再训练脚本。

mv flower_photos.tgz local/
cd local
curl -O https://raw.githubusercontent.com/tensorflow/tensorflow/
10cf65b48e1b2f16eaa82
6d2793cb67207a085d0/tensorflow/examples/image_retraining/retrain.py
echo 'a74361beb4f763dc2d0101cfe87b672ceae6e2f5  retrain.py' | sha1sum -c

查到并确认retrain.py具有正确内容。你会看到retrain.py: OK。

运行再训练脚本。

python retrain.py --image_dir flower_photos --output_graph output_graph.pb 
--output_labels output_labels.txt

如果遇到以下错误,忽略即可。

TypeError: not all arguments converted during string formatting Logged from file
tf_logging.py, line 82

执行retrain.py后,训练图像被自动成的训练、测试和验证数据集。

在输出中,我们希望“训练准确性”和“验证准确性”高一些,“交叉熵”低一些。有关这些术语的详细解释,请访问下方链接。在较好的硬件上的训练需要大约30分钟。

  • 术语:https://www.tensorflow.org/tutorials/image_retraining

看一看你的控制台输出的最后一行:

INFO:tensorflow:Final test accuracy = 89.1% (N=340)

这说明我们的模型十次中有九次能够正确地猜出给定图像中显示的使五种花型中的哪一种。由于训练过程中加入了随机性,你的准确性可能会有所不同。

分类

再加上一个小脚本,我们可以将新的花朵图像添加到模型中,并输出它的猜测。这就是图像分类。

在主机上的local目录中将以下代码保存成classify.py:

import tensorflow as tf, sys
 
image_path = sys.argv[1]
graph_path = 'output_graph.pb'
labels_path = 'output_labels.txt'
 
# Read in the image_data
image_data = tf.gfile.FastGFile(image_path, 'rb').read()
 
# Loads label file, strips off carriage return
label_lines = [line.rstrip() for line
    in tf.gfile.GFile(labels_path)]
 
# Unpersists graph from file
with tf.gfile.FastGFile(graph_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')
 
# Feed the image_data as input to the graph and get first prediction
with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    predictions = sess.run(softmax_tensor, 
    {'DecodeJpeg/contents:0': image_data})
    # Sort to show labels of first prediction in order of confidence
    top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
    for node_id in top_k:
         human_string = label_lines[node_id]
         score = predictions[0][node_id]
         print('%s (score = %.5f)' % (human_string, score))

要测试自己的图像,将其在你的local目录中保存为test.jpg,并在容器中运行python classify.py test.jpg。输出结果如下所示:

sunflowers (score = 0.78311)
daisy (score = 0.20722)
dandelion (score = 0.00605)
tulips (score = 0.00289)
roses (score = 0.00073)

数字表明自信程度。模型有78.311%的确定图像中的花是向日葵。得分越高表示图像越匹配结果。请注意,只显示一个匹配。多标签分类需要不同的方法。

欲了解更多详情,查看此大线,由线解释的classify.py。

分类器脚本中的图形加载代码损坏了,所以我应用了graph_def = tf.GraphDef()等图形加载代码。

我们创造了一个还可以的花朵图像分类器,可以在笔记本电脑上每秒钟处理大约五个图像。

在下一期中,我们将用到这些知识训练不同的图像分类器,并使用TensorBoard观察它。如果你想试试TensorBoard,请保持容器的运行,并确保docker运行没有被终止。