【猫狗数据集】对一张张图像进行预测(而不是测试集)
数据集下载地址:
链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw 提取码:2xq4
创建数据集:https://www.cnblogs.com/xiximayou/p/12398285.html
读取数据集:https://www.cnblogs.com/xiximayou/p/12422827.html
进行训练:https://www.cnblogs.com/xiximayou/p/12448300.html
保存模型并继续进行训练:https://www.cnblogs.com/xiximayou/p/12452624.html
加载保存的模型并测试:https://www.cnblogs.com/xiximayou/p/12459499.html
划分验证集并边训练边验证:https://www.cnblogs.com/xiximayou/p/12464738.html
使用学习率衰减策略并边训练边测试:https://www.cnblogs.com/xiximayou/p/12468010.html
利用tensorboard可视化训练和测试过程:https://www.cnblogs.com/xiximayou/p/12482573.html
从命令行接收参数:https://www.cnblogs.com/xiximayou/p/12488662.html
使用top1和top5准确率来衡量模型:https://www.cnblogs.com/xiximayou/p/12489069.html
使用预训练的resnet18模型:https://www.cnblogs.com/xiximayou/p/12504579.html
计算数据集的平均值和方差:https://www.cnblogs.com/xiximayou/p/12507149.html
读取数据集的第二种方式:https://www.cnblogs.com/xiximayou/p/12516735.html
epoch、batchsize、step之间的关系:https://www.cnblogs.com/xiximayou/p/12405485.html
首先我们上传一些图片到image文件夹中:
然后我们画出这些图片看看是什么样子:
import cv2
import matplotlib.pyplot as plt
import glob
# 使用matplotlib展示多张图片
def matplotlib_multi_pic1():
i=0
for img in glob.glob('/content/drive/My Drive/colab notebooks/image/*.jpg'):
img_name=img.split("/")[-1]
img = cv2.imread(img)
title=img_name
#行,列,索引
plt.subplot(3,3,i+1)
plt.imshow(img)
plt.title(title,fontsize=8)
plt.xticks([])
plt.yticks([])
i+=1
plt.show()
matplotlib_multi_pic1()
接着在test文件夹中新建一个test_from_image.py。
import torchvision
import sys
import torch
import torch.nn as nn
from PIL import Image
sys.path.append("/content/drive/My Drive/colab notebooks")
import glob
import numpy as np
import torchvision.transforms as transforms
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model=torchvision.models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features,2,bias=False)
model.to(device)
model.eval()
save_path="/content/drive/My Drive/colab notebooks/output/resnet18_best.t7"
checkpoint = torch.load(save_path)
model.load_state_dict(checkpoint['model'])
print("当前模型准确率为:",checkpoint["epoch_acc"])
images_path="/content/drive/My Drive/colab notebooks/image/"
transform = transforms.Compose([transforms.Resize((224,224))])
def predict():
true_labels=[]
output_labels=[]
for image in glob.glob(images_path+"/*.jpg"):
print(image)
if "cat" in image.split("/")[-1]:
tmp=0
else:
tmp=1
true_labels.append(tmp)
image=Image.open(image)
image=image.resize((224,224))
tensor=torch.from_numpy(np.asarray(image)).permute(2,0,1).float()/255.0
print(tensor.shape)
tensor=tensor.reshape((1,3,224,224))
tensor=tensor.to(device)
#print(tensor.shape)
output=model(tensor)
print(output)
_, pred = torch.max(output.data,1)
output_labels.append(pred.item())
return true_labels,output_labels
true_labels,output_labels=predict()
print("正确的标签是:")
print(true_labels)
print("预测的标签是:")
print(output_labels)
说明:这里需要注意的地方有:
- 图像要调整到网络输入一致的大小,即224×224
- 将【高,宽,通道】要转换成【通道,高,宽】的格式
- 输入的是【batchsize,C,H,W】,因此我们要增加一个batchsize维度
- 之前训练好的模型是使用cuda(),因此要将模型和数据放在GPU中
- 一定要转换状态,即model.eval()
结果:
下一节,可视化相应的特征图。
- ExtJs学习笔记(24)-Drag/Drop拖动功能
- 人工智能尚处探索阶段,为何我们对此异常焦虑
- ExtJs学习笔记(22)-XTemplate + WCF 打造无刷新数据分页
- 同步服务器系统时间操作记录
- kvm虚拟化管理平台WebVirtMgr部署-完整记录(安装Windows虚拟机)-(4)
- ExtJs学习笔记(11)_Absolute布局和Accordion布局
- ExtJs学习笔记(9)_Window的基本用法
- DateTime在ExtJs中无法正确序列化的问题
- ELK实时日志分析平台环境部署--完整记录
- 梳理Linux下OSI七层网络与TCP/IP五层网络架构
- 字符编码-使用c#研究
- iframe高度自适应的IE解决方案
- javascript读写本机文本文件
- 崔立鹏:腾讯云为知识竞技游戏提供解决方案
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- dubbo学习之本地存根实践
- vue3.0 加载json的“另类”方法(非ajax) 定义组件.vue文件
- pyhon3 安装 web 出错 ERROR: No matching distribution found for web
- tensorflow gpu 运行出现OOM错误
- jquery getJSON不执行问题解决
- python函数——创建文件夹
- 数据结构算法操作试题(C++/Python)—— 组合总和
- Android大厂收割秘籍:太难了,准备半年,腾讯/快手/美团外卖面试中的那些辛酸坎坷史
- leetcode链表之合并两个排序的链表
- 2020-09-09:裸写算法:两个线程轮流打印数字1-100。
- python中线程池使用
- 还在手动部署SpringBoot应用?试试这个自动化插件!
- Julia简易教程——3_复数和分数
- 怎么理解int main(int argc, const char *argv[])
- Julia简易教程——2_julia数学运算及其基本功能