Tensor Tensor("predictions/Softmax:0",shape=(?,4),dtype=float32) is not an element of this graph
时间:2019-07-10
本文章向大家介绍Tensor Tensor("predictions/Softmax:0",shape=(?,4),dtype=float32) is not an element of this graph,主要包括Tensor Tensor("predictions/Softmax:0",shape=(?,4),dtype=float32) is not an element of this graph使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
ValueError : Tensor Tensor("predictions/Softmax:0",shape=(?,4),dtype=float32) is not an element of this graph
原始问题及解决方案
https://github.com/keras-team/keras/issues/2397#issuecomment-254919212
问题描述:
在keras+tensorflow框架下训练神经网络并得到权重h5文件。
在之后需要调用的python代码中读取权重和图片并预测
在C#多线程的子线程调用python代码时出现以下报错
“ValueError : Tensor Tensor(“predictions/Softmax:0”, shape=(?, 2), dtype=float32) is not an element of this graph
解决方法:
主要是在读取权重后增加一行graph = tf.get_default_graph()
model = load_model()
graph = tf.get_default_graph()
并在需要预测时前加 with graph.as_default():
原始py文件代码
#-*- coding:utf-8 -*-
from keras.applications.vgg16 import preprocess_input,VGG16
from keras.layers import Dense
from keras.models import Model
import numpy as np
from PIL import Image
from keras.optimizers import SGD
import time
import cv2
from math import *
from scipy.stats import mode
import tensorflow as tf
from keras import backend as K
import os
def get_session(gpu_fraction=1.0):
num_threads = os.environ.get('OMP_NUM_THREADS')
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)
if num_threads:
return tf.Session(config=tf.ConfigProto(
gpu_options=gpu_options, intra_op_parallelism_threads=num_threads))
else:
return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
def load_model():
sgd = SGD(lr=0.00001, momentum=0.9)
model = VGG16(weights=None, classes=4)
# 加载模型权重
model.load_weights('./models/modelAngle.h5', by_name=True)
# 编译模型,以较小的学习率进行训练
model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])
return model
# 加载模型
K.set_session(get_session())
model = load_model()
def predict(im):
"""
图片文字方向预测
"""
ROTATE = [0, 270, 180, 90]
# im = cv2.imread(img)
w, h, _ = im.shape
thesh = 0.05
xmin, ymin, xmax, ymax = int(thesh * w), int(thesh * h), w - int(thesh * w), h - int(thesh * h)
#im = im.crop((xmin, ymin, xmax, ymax)) # 剪切图片边缘,清楚边缘噪声
im = im[ymin:ymax, xmin:xmax]
# im = im.resize((224, 224))
im = cv2.resize(im, (224, 224))
img = np.array(im)
img = preprocess_input(img.astype(np.float32))
pred = model.predict(np.array([img]))
index = np.argmax(pred, axis=1)[0]
return ROTATE[index]
def rotate(image, angle, center=None, scale=1.0): #1
filled_color = -1
if filled_color == -1:
filled_color = mode([image[0, 0], image[0, -1],
image[-1, 0], image[-1, -1]]).mode[0]
if np.array(filled_color).shape[0] == 2:
if isinstance(filled_color, int):
filled_color = (filled_color, filled_color, filled_color)
else:
filled_color = tuple([int(i) for i in filled_color])
(h, w) = image.shape[:2] #2
height, width = image.shape[:2]
heightNew = int(width * fabs(sin(radians(angle))) + height * fabs(cos(radians(angle)))) # 这个公式参考之前内容
widthNew = int(height * fabs(sin(radians(angle))) + width * fabs(cos(radians(angle))))
matRotation = cv2.getRotationMatrix2D((width / 2, height / 2), angle, scale) # 逆时针旋转 degree
matRotation[0, 2] += (widthNew - width) / 2 # 因为旋转之后,坐标系原点是新图像的左上角,所以需要根据原图做转化
matRotation[1, 2] += (heightNew - height) / 2
imgRotation = cv2.warpAffine(image, matRotation, (widthNew, heightNew), borderValue=filled_color)
# imgRotation = cv2.warpAffine(transform_img, matRotation, (widthNew, heightNew), borderVal
return imgRotation #7
if __name__ == "__main__":
t = time.time()
img = cv2.imread("st2.png")
degree = predict(img.copy())
transform_img = rotate(img, degree)
cv2.imwrite("st2a.png", transform_img)
# print("旋转角度:"+str(predict("37.jpg"))+"°")
# print(type(predict("37.jpg")))
print("旋转检测时间:{:.2f}秒".format(time.time() - t))
修改后的代码
#-*- coding:utf-8 -*-
from keras.applications.vgg16 import preprocess_input,VGG16
from keras.layers import Dense
from keras.models import Model
import numpy as np
from PIL import Image
from keras.optimizers import SGD
import time
import cv2
from math import *
from scipy.stats import mode
import tensorflow as tf
from keras import backend as K
import os
def get_session(gpu_fraction=1.0):
num_threads = os.environ.get('OMP_NUM_THREADS')
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)
if num_threads:
return tf.Session(config=tf.ConfigProto(
gpu_options=gpu_options, intra_op_parallelism_threads=num_threads))
else:
return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
def load_model():
sgd = SGD(lr=0.00001, momentum=0.9)
model = VGG16(weights=None, classes=4)
# 加载模型权重
model.load_weights('./models/modelAngle.h5', by_name=True)
# 编译模型,以较小的学习率进行训练
model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])
return model
# 加载模型
K.set_session(get_session(0.5))
model = load_model()
graph = tf.get_default_graph()
def predict(im):
"""
图片文字方向预测
"""
with graph.as_default():
ROTATE = [0, 270, 180, 90]
# im = cv2.imread(img)
w, h, _ = im.shape
thesh = 0.05
xmin, ymin, xmax, ymax = int(thesh * w), int(thesh * h), w - int(thesh * w), h - int(thesh * h)
#im = im.crop((xmin, ymin, xmax, ymax)) # 剪切图片边缘,清楚边缘噪声
im = im[ymin:ymax, xmin:xmax]
# im = im.resize((224, 224))
im = cv2.resize(im, (224, 224))
img = np.array(im)
img = preprocess_input(img.astype(np.float32))
pred = model.predict(np.array([img]))
index = np.argmax(pred, axis=1)[0]
return ROTATE[index]
def rotate(image, angle, center=None, scale=1.0): #1
filled_color = -1
if filled_color == -1:
filled_color = mode([image[0, 0], image[0, -1],
image[-1, 0], image[-1, -1]]).mode[0]
if np.array(filled_color).shape[0] == 2:
if isinstance(filled_color, int):
filled_color = (filled_color, filled_color, filled_color)
else:
filled_color = tuple([int(i) for i in filled_color])
(h, w) = image.shape[:2] #2
height, width = image.shape[:2]
heightNew = int(width * fabs(sin(radians(angle))) + height * fabs(cos(radians(angle)))) # 这个公式参考之前内容
widthNew = int(height * fabs(sin(radians(angle))) + width * fabs(cos(radians(angle))))
matRotation = cv2.getRotationMatrix2D((width / 2, height / 2), angle, scale) # 逆时针旋转 degree
matRotation[0, 2] += (widthNew - width) / 2 # 因为旋转之后,坐标系原点是新图像的左上角,所以需要根据原图做转化
matRotation[1, 2] += (heightNew - height) / 2
imgRotation = cv2.warpAffine(image, matRotation, (widthNew, heightNew), borderValue=filled_color)
# imgRotation = cv2.warpAffine(transform_img, matRotation, (widthNew, heightNew), borderVal
return imgRotation #7
if __name__ == "__main__":
t = time.time()
img = cv2.imread("st2.png")
degree = predict(img.copy())
transform_img = rotate(img, degree)
cv2.imwrite("st2a.png", transform_img)
# print("旋转角度:"+str(predict("37.jpg"))+"°")
# print(type(predict("37.jpg")))
print("旋转检测时间:{:.2f}秒".format(time.time() - t))
然后错误解决了
- HDUOJ-------1753大明A+B(大数之小数加法)
- HDUOJ---1754 I Hate It (线段树之单点更新查区间最大值)
- HDUOJ----1166敌兵布阵(线段树单点更新)
- poj----2155 Matrix(二维树状数组第二类)
- poj------2352 Stars(树状数组)
- HDUOJ-----2852 KiKi's K-Number(树状数组+二分)
- nyoj----522 Interval (简单树状数组)
- HDUOJ-----2838Cow Sorting(组合树状数组)
- HDUOJ---2642Stars(二维树状数组)
- HDUOJ -----Color the ball
- poj-----Ultra-QuickSort(离散化+树状数组)
- HDUOJ---1241Oil Deposits(dfs)
- HDUOJ------2398Savings Account
- HDUOJ-----2399GPA
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Docker容器ElasticSearch-Head创建索引无响应406
- springboot监控&springboot配置https
- 面试中最长常问到的 HashMap,你都知道多少?
- spring security 使用自定义AuthenticationFailureHandler无法跳转failureUrl
- Android studio 下载安装教程和第一个程序运行最新,多图详解
- ubuntu16.04下qt5.14报错:/home/zhangfakai/Qt5.14.1/5.14.1/gcc_64/include/QtGui/qopengl.h:141: error: GL/
- 每天手撕一道算法-64. 最小路径和
- 每日手撕一道算法题-322.零钱兑换
- 每天手撕一道算法题-130. 被围绕的区域
- TKE上部署metrics-server
- Docker-Compose搭建mysql、redis、zookeeper、rabbitmq、consul、elasticsearch环境
- MDK更改配色方案
- Apache通过多端口配置多站点
- FatFs-目录下文件扫描
- Python之Bilibili自动更新邮件提醒并任务栏图标「完整代码」