除了写烂的手写数据分类,你会不会做自定义图像数据集的识别?!
网上看的很多教程都是几个常见的例子,从内置模块或在线download数据集,要么是iris,要么是MNIST手写识别数字,或是UCI ,数据集不需要自己准备,所以不关心如何读取数据、做数据预处理相关的内容,但是实际做项目的时候做数据预处理感觉一头雾水。
本文从图片下载,到生成数据集列表,建立模型,最后到预测,将整个图片分类的实操流程详细讲解。 代码基于百度开源的深度学习框架 paddlepaddle,该框架安装及其简单:
pip install paddlepaddle
mac版 安装后使用如果报错:
1 Fatal Python error: PyThreadState_Get: no current thread
2 Abort trap: 6
解决方案:
1.运行otool,可以看到pip安装之后的_swig_paddle.so依赖/usr/local/opt/python/Frameworks/Python.framework/Versions/2.7/Python,但实际系统中不存在该路径
1 otool -L /anaconda/lib/python2.7/site-packages/py_paddle/_swig_paddle.so
2 /anaconda/lib/python2.7/site-packages/py_paddle/_swig_paddle.so:
3 /System/Library/Frameworks/CoreFoundation.framework/Versions/A/CoreFoundation (compatibility version 150.0.0, current version 1445.12.0)
4 /System/Library/Frameworks/Security.framework/Versions/A/Security (compatibility version 1.0.0, current version 58286.20.16)
5 /usr/local/opt/python/Frameworks/Python.framework/Versions/2.7/Python (compatibility version 2.7.0, current version 2.7.0)
6 /usr/lib/libc++.1.dylib (compatibility version 1.0.0, current version 400.9.0)
7 /usr/lib/libSystem.B.dylib (compatibility version 1.0.0, current version 1252.0.0)
2.利用install_name_tool来替换依赖
1 install_name_tool -change /usr/local/opt/python/Frameworks/Python.framework/Versions/2.7/Python ~/anaconda/lib/libpython2.7.dylib ~/anaconda/lib/python2.7/site-packages/py_paddle/_swig_paddle.so
标颜色的地方要根据自己电脑修改
3.替换成功后,可以看到第五条已经成功的换成anaconda下的路径了
1 otool -L /anaconda/lib/python2.7/site-packages/py_paddle/_swig_paddle.so
2 /anaconda/lib/python2.7/site-packages/py_paddle/_swig_paddle.so:
3 /System/Library/Frameworks/CoreFoundation.framework/Versions/A/CoreFoundation (compatibility version 150.0.0, current version 1445.12.0)
4 /System/Library/Frameworks/Security.framework/Versions/A/Security (compatibility version 1.0.0, current version 58286.20.16)
5 /anaconda/lib/libpython2.7.dylib (compatibility version 2.7.0, current version 2.7.0)
6 /usr/lib/libc++.1.dylib (compatibility version 1.0.0, current version 400.9.0)
7 /usr/lib/libSystem.B.dylib (compatibility version 1.0.0, current version 1252.0.0)
现在再运行paddle.init就不会有问题了
下载图片的代码
这个程序可以从百度图片中下载图片,可以多个类别一起下载,还可以指定下载数量
本文代码、及测试图片在公众号 datadw 里 回复 图片分类 即可获取。
# -*- coding:utf-8 -*-
import re
import uuid
import requests
import os
class DownloadImages:
def __init__(self,download_max,key_word):
self.download_sum = 0
self.download_max = download_max
self.key_word = key_word
self.save_path = '../images/download/' + key_word
def start_download(self):
self.download_sum = 0
gsm = 80
str_gsm = str(gsm)
pn = 0
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
while self.download_sum < self.download_max:
str_pn = str(self.download_sum)
url = 'http://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&'
'word=' + self.key_word + '&pn=' + str_pn + '&gsm=' + str_gsm + '&ct=&ic=0&lm=-1&width=0&height=0'
print url
result = requests.get(url)
self.downloadImages(result.text)
print '下载完成'
def downloadImages(self,html):
img_urls = re.findall('"objURL":"(.*?)",', html, re.S)
print '找到关键词:' + self.key_word + '的图片,现在开始下载图片...'
for img_url in img_urls:
print '正在下载第' + str(self.download_sum + 1) + '张图片,图片地址:' + str(img_url)
try:
pic = requests.get(img_url, timeout=50)
pic_name = self.save_path + '/' + str(uuid.uuid1()) + '.jpg'
with open(pic_name, 'wb') as f:
f.write(pic.content)
self.download_sum += 1
if self.download_sum >= self.download_max:
break
except Exception, e:
print '【错误】当前图片无法下载,%s' % e
continue
if __name__ == '__main__':
key_word_max = input('请输入你要下载几个类别:')
key_words = []
for sum in range(key_word_max):
key_words.append(raw_input('请输入第%s个关键字:' % str(sum+1)))
max_sum = input('请输入每个类别下载的数量:')
for key_word in key_words:
downloadImages = DownloadImages(max_sum, key_word)
downloadImages.start_download()
数据集介绍
如果我们要训练自己的数据集的话,就需要先建立图像列表文件,下面的代码是Myreader.py
读取图像数据集的一部分,从这些代码中可以看出,图像列表中,图像的路径和标签是以t
来分割的,所以我们在生成这个列表的时候,使用t
就可以了.
def train_reader(self,train_list, buffered_size=1024):
def reader():
with open(train_list, 'r') as f:
lines = [line.strip() for line in f]
for line in lines:
img_path, lab = line.strip().split('t')
yield img_path, int(lab)
生成的图像列表的结构是这样的:
../images/vegetables/lotus_root/1515827057517.jpg 2
../images/vegetables/lotus_root/1515827057582.jpg 2
../images/vegetables/lotus_root/1515827057616.jpg 2
../images/vegetables/lettuce/1515827015922.jpg 1
../images/vegetables/lettuce/1515827015983.jpg 1
../images/vegetables/lettuce/1515827016045.jpg 1
../images/vegetables/cuke/1515827008337.jpg 0
../images/vegetables/cuke/1515827008370.jpg 0
../images/vegetables/cuke/1515827008402.jpg 0
生成图像列表
所以我们要编写一个程序可以为我们生成这样的图像列表
在这个程序中,我们只要把一个大类的文件夹路径传进去就可以了,该程序会把里面的每个小类别都迭代,生成固定格式的列表.比如我们把蔬菜类别的根目录传进去../images/vegetables
本文代码、及测试图片在公众号 datadw 里 回复 图片分类 即可获取。
运行这个程序之后,会生成在data文件夹中生成一个单独的大类文件夹,比如我们这次是使用到蔬菜类,所以我生成一个vegetables文件夹,在这个文件夹下有3个文件:
文件名 |
作用 |
---|---|
trainer.list |
用于训练的图像列表 |
test.list |
用于测试的图像列表 |
readme.json |
该数据集的json格式的说明,方便以后使用 |
readme.json
文件的格式如下,可以很清楚看到整个数据的图像数量,总类别名称和类别数量,还有每个类对应的标签,类别的名字,该类别的测试数据和训练数据的数量:
{
"all_class_images": 3300,
"all_class_name": "vegetables",
"all_class_sum": 3,
"class_detail": [
{
"class_label": 1,
"class_name": "cuke",
"class_test_images": 110,
"class_trainer_images": 990
},
{
"class_label": 2,
"class_name": "lettuce",
"class_test_images": 110,
"class_trainer_images": 990
},
{
"class_label": 3,
"class_name": "lotus_root",
"class_test_images": 110,
"class_trainer_images": 990
}
]}
读取数据
通过这个程序可以将上一部分的图像列表读取,生成训练和测试使用的reader,在生成reader前,要传入一个图像的大小,PaddlePaddle会帮我们按照这个大小随机裁剪一个方形的图像,这是种随机裁剪也是数据增强的一种方式.
使用PaddlePaddle开始训练
导入依赖包
首先要先导入依赖包,其中有PaddlePaddle的V2包和上面定义的Myreader.py
读取数据的程序
# coding:utf-8
import sys
import os
import numpy as np
import paddle.v2 as paddle
from MyReader import MyReader
初始化Paddle
然后我们创建一个类,再在类中创建一个初始化函数,在初始化函数中来初始化我们的PaddlePaddle
class PaddleUtil:
# ***********************初始化操作*********************
def __init__(self):
# 初始化paddpaddle,只是用CPU,把GPU关闭
paddle.init(use_gpu=False, trainer_count=2)
定义神经网络模型
这里使用的是VGG神经网络,跟上一篇文章用到的VGG又有一点不同,这里可以看到conv_with_batchnorm=False
,我是把BN
关闭了,在这里不使用BN
层,笔者也不知道为什么如果加上BN
层之后就办法正常训练了,根本就没办法正常收敛。
创建分类器
通过数据输入数据的大小和上面获得的神经模型,使用Softmax输出全连接,得到分类器
获取参数
该函数可以通过输入是否是参数文件路径,或者是损失函数,如果是参数文件路径,就使用之前训练好的参数生产参数.如果不传入参数文件路径,那就使用传入的损失函数生成参数
创建训练器
创建训练器要3个参数,分别是损失函数,参数,优化方法.通过图像的标签信息和分类器生成损失函数.参数可以选择是使用之前训练好的参数,然后在此基础上再进行训练,又或者是使用损失函数生成初始化参数.然后再生成优化方法.就可以创建一个训练器了.
开始训练
要启动训练要4个参数,分别是训练数据,训练的轮数,训练过程中的事件处理,输入数据和标签的对应关系.
训练数据:这次的训练数据是我们自定义的数据集. 训练轮数:表示我们要训练多少轮,次数越多准确率越高,最终会稳定在一个固定的准确率上.不得不说的是这个会比MNIST数据集的速度慢很多 事件处理:训练过程中的一些事件处理,比如会在每个batch打印一次日志,在每个pass之后保存一下参数和测试一下测试数据集的预测准确率. 输入数据和标签的对应关系:说明输入数据是第0维度,标签是第1维度
然后在main
中调用相应的函数,开始训练,可以看到通过myReader.train_reader
来生成一个reader
输出日志如下:’
Pass 0, Batch 0, Cost 1.162887, Error 0.6171875
.....................
Test with Pass 0, Classification_Error 0.353333324194
使用PaddlePaddle预测
该函数需要输入3个参数, 第一个是需要预测的图像,图像传入之后,会经过load_image函数处理,大小会变成32*32大小,训练是输入数据的大小一样. 第二个就是训练好的参数 第三个是通过神经模型生成的分类器
然后在main
中调用相应的函数,开始预测,这个可以同时传入多个数据,可以同时预测
本文代码、及测试图片在公众号 datadw 里 回复 图片分类 即可获取。
输出的结果是:
预测结果为:0,可信度为:0.699004
预测结果为:0,可信度为:0.546674
预测结果为:2,可信度为:0.756389
via http://blog.csdn.net/qq_33200967/article/details/79095265?%3E
- spark2 sql读取json文件的格式要求
- 容器化RDS|调度策略
- Go语言并发编程总结
- hdu------(4302)Holedox Eating(树状数组+二分)
- spark2的SparkSession思考与总结2:SparkSession有哪些函数及作用是什么
- GO语言并发编程之互斥锁、读写锁详解
- spark2.2 SparkSession思考与总结1
- 【译】Spring 官方教程:Spring Security 架构
- hdu----(4301)Divide Chocolate(状态打表)
- hdu------(4300)Clairewd’s message(kmp)
- TensorFlow ML cookbook 第一章7、8节 实现激活功能和使用数据源
- Go语言struct类型详解
- spark1.x升级spark2如何升级及需要考虑的问题
- 使用 kubeadm 创建一个 kubernetes 集群
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Spring Web MVC 请求消息
- Codeforces Round #633 (Div. 2) A Filling Diamonds (假题,观察)
- 如何管理和组织一个机器学习项目
- Spring Web MVC 简单使用
- Spring 中的 JDBC
- IDEA 快键键:展开所有文件夹、折叠所有文件夹(自定义)
- mysql 数据库的悲观锁和乐观锁
- C语言 二维数组和指针的一些笔记
- Java SpringBoot2.3.4 配置redis 基于lettuce 同时支持集群与单机 配置密码加密 并使用redisson分布式锁
- 使用elasticsearch-dump迁移elasticsearch集群数据
- Python爬虫之scrapy的入门使用
- 告别传统工业互联网,提高数字管控思维:三维组态分布式能源站
- 爱奇艺iOS移动端网络优化实践:请求成功率优化
- Java数据类型
- Python爬虫之scrapy构造并发送请求