pytorch读取一张图像进行分类预测需要注意的问题(opencv、PIL)
时间:2022-07-23
本文章向大家介绍pytorch读取一张图像进行分类预测需要注意的问题(opencv、PIL),主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
读取图像一般是两个库:opencv和PIL
1、使用opencv读取图像
import cv2
image=cv2.imread("/content/drive/My Drive/colab notebooks/image/cat1.jpg")
print(image.shape)
(490, 410, 3)
2、使用PIL读取图像
import PIL
image=PIL.Image.open("/content/drive/My Drive/colab notebooks/image/cat1.jpg")
print(image.shape)
这里会报错:
AttributeError Traceback (most recent call last)
<ipython-input-30-807ec7af434b> in <module>()
1 import PIL
2 image=PIL.Image.open("/content/drive/My Drive/colab notebooks/image/cat1.jpg")
----> 3 print(image.shape)
AttributeError: 'JpegImageFile' object has no attribute 'shape'
我们要输出要这么做:
import numpy as np
print(np.array(image).shape)
(490, 410, 3)
需要注意的是:
使用opencv读取图像之后是BGR格式的,使用PIL读取图像之后是RGB格式的。
3、opencv格式的和PIL格式的之间的转换
这里参考:https://www.cnblogs.com/enumx/p/12359850.html
(1)opencv格式转换为PIL格式
import cv2
from PIL import Image
import numpy
img = cv2.imread("plane.jpg")
cv2.imshow("OpenCV",img)
image = Image.fromarray(cv2.cvtColor(img,cv2.COLOR_BGR2RGB))
image.show()
cv2.waitKey()
(2)PIL格式转换为opencv格式
import cv2
from PIL import Image
import numpy
image = Image.open("plane.jpg")
image.show()
img = cv2.cvtColor(numpy.asarray(image),cv2.COLOR_RGB2BGR)
cv2.imshow("OpenCV",img)
cv2.waitKey()
4、使用pytorch读取一张图片并进行分类预测
需要注意两个问题:
- 输入要转换为:[1,channel,H,W]
- 对输入的图像进行数据增强时要求是PIL.Image格式的
import torchvision
import sys
import torch
import torch.nn as nn
from PIL import Image
sys.path.append("/content/drive/My Drive/colab notebooks")
import glob
import numpy as np
import torchvision.transforms as transforms
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model=torchvision.models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features,4,bias=False)
model.to(device)
model.eval()
save_path="/content/drive/My Drive/colab notebooks/checkpoint/resnet18_best_v2.t7"
checkpoint = torch.load(save_path)
model.load_state_dict(checkpoint['model'])
print("当前模型准确率为:",checkpoint["epoch_acc"])
images_path="/content/drive/My Drive/colab notebooks/data/dataset/test/four"
transform = transforms.Compose([transforms.Resize((224,224))])
def predict():
true_labels=[]
output_labels=[]
for image in glob.glob(images_path+"/*.png"):
print(image)
true_labels.append(0)
#image=Image.open(image)
#image=image.resize((224,224))
image=cv2.imread(image)
image=cv2.resize(image,(224,224))
image = Image.fromarray(cv2.cvtColor(image,cv2.COLOR_BGR2RGB))
#print(np.array(image).shape)
tensor=torch.from_numpy(np.asarray(image)).permute(2,0,1).float()/255.0
tensor=tensor.reshape((1,3,224,224))
tensor=tensor.to(device)
#print(tensor.shape)
output=model(tensor)
print(output)
_, pred = torch.max(output.data,1)
output_labels.append(pred.item())
return true_labels,output_labels
true_labels,output_labels=predict()
print("正确的标签是:")
print(true_labels)
print("预测的标签是:")
print(output_labels)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 分布式 | DBLE 3.20.07.0 来啦!
- C语言三剑客之《C陷阱与缺陷》一书精华提炼
- Linux进程间通信(上)之管道、消息队列实践
- FPGA上电时序
- 更新Kubernetes APIServer证书
- R语言连续时间马尔科夫链模拟案例 Markov Chains
- 如何用R语言在机器学习中建立集成模型?
- 从零开始Kubernetes Operator
- TiKV源码解析系列文章(二十)Region Split源码解析
- scrapy爬虫框架和selenium的使用:对优惠券推荐网站数据LDA文本挖掘
- 单性状动物模型矩阵形式计算BLUP值
- 如何计算一般配合力和特殊配合力
- 【29期】Java集合框架 10 连问,你有被问过吗?
- 学徒数据挖掘之谁说生存分析一定要按照表达量中位值或者平均值分组呢?
- 软件质量的黄金准则