【深度学习】使用tensorflow实现VGG19网络
接上一篇AlexNet,本文讲述使用tensorflow实现VGG19网络。
VGG网络与AlexNet类似,也是一种CNN,VGG在2014年的 ILSVRC localization and classification 两个问题上分别取得了第一名和第二名。VGG网络非常深,通常有16-19层,卷积核大小为 3 x 3,16和19层的区别主要在于后面三个卷积部分卷积层的数量。第二个用tensorflow独立完成的小玩意儿......
同样先放上我的代码,由AlexNet的代码改过来的:https://github.com/hjptriplebee/VGG19_with_tensorflow
如果想运行代码,详细的配置要求都在上面链接的readme文件中了。本文建立在一定的tensorflow基础上,不会对太细的点进行说明。
模型结构
可以看到VGG的前几层为卷积和maxpool的交替,每个卷积包含多个卷积层,后面紧跟三个全连接层。激活函数采用Relu,训练采用了dropout,但并没有像AlexNet一样采用LRN(论文给出的理由是加LRN实验效果不好)。
模型定义
def maxPoolLayer(x, kHeight, kWidth, strideX, strideY, name, padding = "SAME"):
"""max-pooling"""
return tf.nn.max_pool(x, ksize = [1, kHeight, kWidth, 1],
strides = [1, strideX, strideY, 1], padding = padding, name = name)
def dropout(x, keepPro, name = None):
"""dropout"""
return tf.nn.dropout(x, keepPro, name)
def fcLayer(x, inputD, outputD, reluFlag, name):
"""fully-connect"""
with tf.variable_scope(name) as scope:
w = tf.get_variable("w", shape = [inputD, outputD], dtype = "float")
b = tf.get_variable("b", [outputD], dtype = "float")
out = tf.nn.xw_plus_b(x, w, b, name = scope.name)
if reluFlag:
return tf.nn.relu(out)
else:
return out
def convLayer(x, kHeight, kWidth, strideX, strideY,
featureNum, name, padding = "SAME"):
"""convlutional"""
channel = int(x.get_shape()[-1]) #获取channel数
with tf.variable_scope(name) as scope:
w = tf.get_variable("w", shape = [kHeight, kWidth, channel, featureNum])
b = tf.get_variable("b", shape = [featureNum])
featureMap = tf.nn.conv2d(x, w, strides = [1, strideY, strideX, 1], padding = padding)
out = tf.nn.bias_add(featureMap, b)
return tf.nn.relu(tf.reshape(out, featureMap.get_shape().as_list()), name = scope.name)
定义了卷积、pooling、dropout、全连接五个模块,使用了上一篇AlexNet中的代码,其中卷积模块去除了group参数,因为网络没有像AlexNet一样分成两部分。接下来定义VGG19。
class VGG19(object):
"""VGG model"""
def __init__(self, x, keepPro, classNum, skip, modelPath = "vgg19.npy"):
self.X = x
self.KEEPPRO = keepPro
self.CLASSNUM = classNum
self.SKIP = skip
self.MODELPATH = modelPath
#build CNN
self.buildCNN()
def buildCNN(self):
"""build model"""
conv1_1 = convLayer(self.X, 3, 3, 1, 1, 64, "conv1_1" )
conv1_2 = convLayer(conv1_1, 3, 3, 1, 1, 64, "conv1_2")
pool1 = maxPoolLayer(conv1_2, 2, 2, 2, 2, "pool1")
conv2_1 = convLayer(pool1, 3, 3, 1, 1, 128, "conv2_1")
conv2_2 = convLayer(conv2_1, 3, 3, 1, 1, 128, "conv2_2")
pool2 = maxPoolLayer(conv2_2, 2, 2, 2, 2, "pool2")
conv3_1 = convLayer(pool2, 3, 3, 1, 1, 256, "conv3_1")
conv3_2 = convLayer(conv3_1, 3, 3, 1, 1, 256, "conv3_2")
conv3_3 = convLayer(conv3_2, 3, 3, 1, 1, 256, "conv3_3")
conv3_4 = convLayer(conv3_3, 3, 3, 1, 1, 256, "conv3_4")
pool3 = maxPoolLayer(conv3_4, 2, 2, 2, 2, "pool3")
conv4_1 = convLayer(pool3, 3, 3, 1, 1, 512, "conv4_1")
conv4_2 = convLayer(conv4_1, 3, 3, 1, 1, 512, "conv4_2")
conv4_3 = convLayer(conv4_2, 3, 3, 1, 1, 512, "conv4_3")
conv4_4 = convLayer(conv4_3, 3, 3, 1, 1, 512, "conv4_4")
pool4 = maxPoolLayer(conv4_4, 2, 2, 2, 2, "pool4")
conv5_1 = convLayer(pool4, 3, 3, 1, 1, 512, "conv5_1")
conv5_2 = convLayer(conv5_1, 3, 3, 1, 1, 512, "conv5_2")
conv5_3 = convLayer(conv5_2, 3, 3, 1, 1, 512, "conv5_3")
conv5_4 = convLayer(conv5_3, 3, 3, 1, 1, 512, "conv5_4")
pool5 = maxPoolLayer(conv5_4, 2, 2, 2, 2, "pool5")
fcIn = tf.reshape(pool5, [-1, 7*7*512])
fc6 = fcLayer(fcIn, 7*7*512, 4096, True, "fc6")
dropout1 = dropout(fc6, self.KEEPPRO)
fc7 = fcLayer(dropout1, 4096, 4096, True, "fc7")
dropout2 = dropout(fc7, self.KEEPPRO)
self.fc8 = fcLayer(dropout2, 4096, self.CLASSNUM, True, "fc8")
def loadModel(self, sess):
"""load model"""
wDict = np.load(self.MODELPATH, encoding = "bytes").item()
#for layers in model
for name in wDict:
if name not in self.SKIP:
with tf.variable_scope(name, reuse = True):
for p in wDict[name]:
if len(p.shape) == 1:
#bias 只有一维
sess.run(tf.get_variable('b', trainable = False).assign(p))
else:
#weights
sess.run(tf.get_variable('w', trainable = False).assign(p))
buildCNN函数完全按照VGG的结构搭建网络。
loadModel函数从模型文件中读取参数,采用的模型文件见github上的readme说明。 至此,我们定义了完整的模型,下面开始测试模型。
模型测试
ImageNet训练的VGG有很多类,几乎包含所有常见的物体,因此我们随便从网上找几张图片测试。比如我直接用了之前做项目的图片,为了避免审美疲劳,我们不只用渣土车,还要用挖掘机、采沙船:
然后编写测试代码:
parser = argparse.ArgumentParser(description='Classify some images.')
parser.add_argument('mode', choices=['folder', 'url'], default='folder')
parser.add_argument('path', help='Specify a path [e.g. testModel]')
args = parser.parse_args(sys.argv[1:])
if args.mode == 'folder': #测试方式为本地文件夹
#get testImage
withPath = lambda f: '{}/{}'.format(args.path,f)
testImg = dict((f,cv2.imread(withPath(f))) for f in os.listdir(args.path) if os.path.isfile(withPath(f)))
elif args.mode == 'url': #测试方式为URL
def url2img(url): #获取URL图像
'''url to image'''
resp = urllib.request.urlopen(url)
image = np.asarray(bytearray(resp.read()), dtype="uint8")
image = cv2.imdecode(image, cv2.IMREAD_COLOR)
return image
testImg = {args.path:url2img(args.path)}
if testImg.values():
#some params
dropoutPro = 1
classNum = 1000
skip = []
imgMean = np.array([104, 117, 124], np.float)
x = tf.placeholder("float", [1, 224, 224, 3])
model = vgg19.VGG19(x, dropoutPro, classNum, skip)
score = model.fc8
softmax = tf.nn.softmax(score)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
model.loadModel(sess) #加载模型
for key,img in testImg.items():
#img preprocess
resized = cv2.resize(img.astype(np.float), (224, 224)) - imgMean #去均值
maxx = np.argmax(sess.run(softmax, feed_dict = {x: resized.reshape((1, 224, 224, 3))})) #网络输入为224*224
res = caffe_classes.class_names[maxx]
font = cv2.FONT_HERSHEY_SIMPLEX
cv2.putText(img, res, (int(img.shape[0]/3), int(img.shape[1]/3)), font, 1, (0, 255, 0), 2) #在图像上绘制结果
print("{}: {}n----".format(key,res)) #输出测试结果
cv2.imshow("demo", img)
cv2.waitKey(0)
如果你看完了我AlexNet的博客,那么一定会发现我这里的测试代码做了一些小的修改,增加了URL测试的功能,可以测试网上的图像 ,测试结果如下:
- 消灭假新闻:使用Scikit-Learn检测虚假新闻
- WPF 跳动的文字
- Sql批量修改帝国cms文章发布时间(需unix时间,否则会变为1970-01-01)
- socket 通信 多线程调用窗体(委托)的几个知识点,记录在案,以备查阅
- 微信公众平台改版 终于支持图文消息链接了
- 机器学习黑客系列:模型比较与选择
- asp.net :使用jquery 的ajax +WebService+json 实现无刷新去后台值
- 用A标签实现页面内容定位 点击链接跳到具体位置
- Seq2seq模型的一个变种网络:Pointer Network的简单介绍
- 框架设计原则和规范(三)
- 微信推广功能支持图片广告和投放外链广告
- PHP 高级编程之多线程
- 框架设计原则和规范(四)
- python 函数参数的传递(参数带星号的说明)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 二维树状数组-POJ 2155 Matrix
- 浙大版《C语言程序设计(第3版)》题目集 习题10-7 十进制转换二进制
- flink开发过程中遇到的问题集锦
- FLINK实战-使用CEP进行网站监控报警和报警恢复
- 浙大版《C语言程序设计(第3版)》题目集 习题10-8 递归实现顺序输出整数
- 浅谈DAO设计模式(示例)
- 浙大版《C语言程序设计(第3版)》题目集 习题11-1 输出月份英文名
- flink教程-flink 1.11 使用sql将流式数据写入hive
- flink实战-使用广播实现报警阈值动态更新
- 浅谈MVC设计模式(示例)
- flink实战-聊一聊flink中的聚合算子
- 浙大版《C语言程序设计(第3版)》题目集 习题11-2 查找星期
- flink实战-模拟简易双11实时统计大屏
- Android Studio首次运行指南
- PAT (Basic Level) Practice (中文)1047 编程团体赛