4.训练模型之准备训练数据
终于要开始训练识别熊猫的模型了, 第一步是准备好训练数据,这里有三件事情要做:
- 收集一定数量的熊猫图片。
- 将图片中的熊猫用矩形框标注出来。
- 将原始图片和标注文件转换为TFRecord格式的文件。
数据标注
收集熊猫的图片和标注熊猫位置的工作称之为“Data Labeling”,这可能是整个机器学习领域内最低级、最机械枯燥的工作了,有时候大量的 Data Labeling 工作会外包给专门的 Data Labeling 公司做, 以加快速度和降低成本。 当然我们不会把这个工作外包给别人,要从最底层的工作开始!收集熊猫图片倒不是太难,从谷歌和百度图片上收集 200 张熊猫的图片,应该足够训练一个可用的识别模型了。然后需要一些工具来做标注,我使用的是 Mac 版的 RectLabel,常用的还有 LabelImg 和 LabelMe 等。
RectLabel 标注时的界面大概是这样的:
当我们标注完成的时候,它会在 annotations 目录下生产和图片文件名相同的后缀名为 .json 的标注文件。
打开一个标注文件,其内容大概是这样的:
{
"filename" : "61.jpg",
"folder" : "panda_images",
"image_w_h" : [
453,
340
],
"objects" : [
{
"label" : "panda",
"x_y_w_h" : [
90,
104,
364,
233
]
}
]
}
- image_w_h:图片的宽和高。
- objects:图片的中的物体信息、数组。
- label:在标注的时候指定的物体名称。
- x_y_w_h:物体位置的矩形框:(xmin、ymin、width、height)。
接下来要做的是耐心的在这 200 张图片上面标出熊猫的位置,这个稍微要花点时间,可以在 这里 找已经标注好的图片数据。
生成 TFRecord
接下来需要一点 Python 代码来将图片和标注文件生成为 TFRecord 文件,TFRecord 文件是由很多tf.train.Example对象序列化以后组成的,先写由一个单独的图片文件生成tf.train.Example对象的函数:
def create_sample(image_filename, data_dir):
image_path = os.path.join(data_dir, image_filename)
annotation_path = os.path.join(data_dir, 'annotations', os.path.splitext(image_filename)[0] + ".json")
with tf.gfile.GFile(image_path, 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
with open(annotation_path) as fid:
image_annotation = json.load(fid)
width = image_annotation['image_w_h'][0]
height = image_annotation['image_w_h'][1]
xmins = []
ymins = []
xmaxs = []
ymaxs = []
classes = []
classes_text = []
for obj in image_annotation['objects']:
classes.append(1)
classes_text.append('panda')
box = obj['x_y_w_h']
xmins.append(float(box[0]) / width)
ymins.append(float(box[1]) / height)
xmaxs.append(float(box[0] + box[2] - 1) / width)
ymaxs.append(float(box[1] + box[3] - 1) / height)
filename = image_annotation['filename'].encode('utf8')
tf_example = tf.train.Example(features=tf.train.Features(feature={
'image/height': dataset_util.int64_feature(height),
'image/width': dataset_util.int64_feature(width),
'image/filename': dataset_util.bytes_feature(filename),
'image/source_id': dataset_util.bytes_feature(filename),
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
'image/object/class/label': dataset_util.int64_list_feature(classes),
}))
return tf_example
在这里简单说明一下:
- 通过图片文件名找到对应的标注文件,并读入标注信息。
- 因为图片中标注的物体都是熊猫,用数字 1 来代表,所以 class 数组里的元素值都为 1,class_text数组的里的元素值都为‘panda’。
- Object Detection API 里面接受的矩形框输入格式为 (xmin, ymin, xmax, ymax) 和标注文件的 (xmin, ymin, width, height) 不一样,所以要做一下转换。同时需要将这些值归一化:将数值投影到 (0, 1] 的区间内。
- 将特征组成{特征名:特征值}的 dict 作为参数来创建tf.train.Example。
接下来将tf.train.Example对象序列化,我们写一个可以由图片文件列表生成对应 TFRecord 文件的的函数:
def create_tf_record(example_file_list, data_dir, output_file_path):
writer = tf.python_io.TFRecordWriter(output_file_path)
for filename in example_file_list:
tf_example = create_sample(filename, data_dir)
writer.write(tf_example.SerializeToString())
writer.close()
依次调用create_sample函数然后将生成的tf.train.Example对象依次序列化即可。
最后需要将数据集切分为训练集合测试集,将图片文件打乱,然后按照 7:3 的比例进行切分:
random.seed(42)
random.shuffle(all_examples)
num_examples = len(all_examples)
num_train = int(0.7 * num_examples)
train_examples = all_examples[:num_train]
val_examples = all_examples[num_train:]
create_tf_record(train_examples, data_dir, os.path.join(output_dir, 'train.record'))
create_tf_record(val_examples, data_dir, os.path.join(output_dir, 'val.record'))
写完这个脚本以后,最好再写一个测试用例来验证这个脚本,因为我们将会花很长的时间来训练,到时候再发现脚本有 bug 就太浪费时间了,我们主要测试create_sample方法有没有根据输入数据生成正确的tf.train.Example对象:
def test_dict_to_tf_example(self):
image_file = '61.jpg'
data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test_data')
example = create_sample(image_file, data_dir)
self._assertProtoEqual(
example.features.feature['image/height'].int64_list.value, [340])
self._assertProtoEqual(
example.features.feature['image/width'].int64_list.value, [453])
self._assertProtoEqual(
example.features.feature['image/filename'].bytes_list.value,
[image_file])
self._assertProtoEqual(
example.features.feature['image/source_id'].bytes_list.value,
[image_file])
self._assertProtoEqual(
example.features.feature['image/format'].bytes_list.value, ['jpeg'])
self._assertProtoEqual(
example.features.feature['image/object/bbox/xmin'].float_list.value,
[90.0 / 453])
self._assertProtoEqual(
example.features.feature['image/object/bbox/ymin'].float_list.value,
[104.0/340])
self._assertProtoEqual(
example.features.feature['image/object/bbox/xmax'].float_list.value,
[1.0])
self._assertProtoEqual(
example.features.feature['image/object/bbox/ymax'].float_list.value,
[336.0/340])
self._assertProtoEqual(
example.features.feature['image/object/class/text'].bytes_list.value,
['panda'])
self._assertProtoEqual(
example.features.feature['image/object/class/label'].int64_list.value,
[1])
后台回复“准备训练数据”关键字可以获取全部源码。
完成之后运行脚本,传入图片和标注的文件夹路径和输出文件路径:
python create_tf_record.py --image_dir=PATH_OF_IMAGE_SET --output_dir=OUTPUT_DIR
执行完成后会在由output_dir参数指定的目录生成train.record和val.record文件, 分别为训练集和测试集。
生成 label map 文件
最后还需要一个 label map 文件,很简单,因为我们只有一种物体:熊猫
label_map.pbtxt:
item {
id: 1
name: 'panda'
}
训练一个熊猫识别模型所需要的训练数据就准备完了,接下来开始在 GPU 主机上面开始训练。
- Lua table之弱引用
- 看吧,这就是现代化 PHP 该有的样子
- 从web图片裁剪出发:了解H5中的Blob
- Android子线程更新UI主线程方法之Handler
- Drawable.Bitmap.Canvas.Paint.Matrix
- 关于JSON.stringify和Unicode编码,需要注意的几点
- 用 PHP 的方式实现的各类算法合集
- Nginx 反向代理解决前后端联调跨域问题
- JavaScript对象length
- Go1.8.4和Go1.9.1版本发布
- Javascript数组操作
- Tensorflow官方语音识别入门教程 | 附Google新语音指令数据集
- jQuery VS JavaScript原生API
- 居于H5的多文件、大文件、多线程上传解决方案
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- lib-flexible引入到Vue做移动端rem布局
- 微信小程序引入VantWeapp开发
- 手把手教你,嘴对嘴传达------源码编译LNMP部署及应用 , 手动搭建discuz论坛
- 微信小程序封装api接口
- 手把手教你,嘴对嘴传达------Nginx常规的优化(隐藏版本号,缓存时间,日志切割,网页压缩,防盗链优化)
- Vue Router 实现多种页面跳转
- Vue实现输入框自动聚焦
- 手把手教你,嘴对嘴传达------Apache --ab测试
- Css实现内容溢出添加横向滚动条
- 手把手教你,嘴对嘴传达------深入介绍Nginx的rewrite模块(理论加实验)
- jQuery实现点击图片弹出视频并自动播放
- 机器学习之决策树一-ID3原理与代码实现
- jQuery点击返回顶部
- 手把手教你,嘴对嘴传达------Nginx实现动静分离的两种方式
- Vue实现push数组并删除方法