【猫狗数据集】从命令行接收参数
数据集下载地址:
链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw 提取码:2xq4
创建数据集:https://www.cnblogs.com/xiximayou/p/12398285.html
读取数据集:https://www.cnblogs.com/xiximayou/p/12422827.html
进行训练:https://www.cnblogs.com/xiximayou/p/12448300.html
保存模型并继续进行训练:https://www.cnblogs.com/xiximayou/p/12452624.html
加载保存的模型并测试:https://www.cnblogs.com/xiximayou/p/12459499.html
划分验证集并边训练边验证:https://www.cnblogs.com/xiximayou/p/12464738.html
使用学习率衰减策略并边训练边测试:https://www.cnblogs.com/xiximayou/p/12468010.html
利用tensorboard可视化训练和测试过程:https://www.cnblogs.com/xiximayou/p/12482573.html
epoch、batchsize、step之间的关系:https://www.cnblogs.com/xiximayou/p/12405485.html
本节我们要在命令行接收参数,包括batch_size的值以及网络的类型。
基本上我们只需要修改main.py就行了:
main.py
import sys
sys.path.append("/content/drive/My Drive/colab notebooks")
from utils import rdata
from model import resnet
import torch.nn as nn
import torch
import numpy as np
import torchvision
import train
import torch.optim as optim
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def main(batch_size,baseline):
train_loader,val_loader,test_loader=rdata.load_dataset(batch_size)
if baseline:
model =torchvision.models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features,2,bias=False)
if torch.cuda.is_available():
model.cuda()
#定义训练的epochs
num_epochs=100
#定义学习率
learning_rate=0.1
#定义损失函数
criterion=nn.CrossEntropyLoss()
#定义优化方法,简单起见,就是用带动量的随机梯度下降
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9,
weight_decay=1*1e-4)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [40,80], 0.1)
print("训练集有:",len(train_loader.dataset))
#print("验证集有:",len(val_loader.dataset))
print("测试集有:",len(test_loader.dataset))
trainer=train.Trainer(criterion,optimizer,model)
trainer.loop(num_epochs,train_loader,val_loader,test_loader,scheduler)
if __name__ == "__main__":
import argparse
p=argparse.ArgumentParser()
p.add_argument("--batch_size",type=int,default=64)
p.add_argument("--baseline",action="store_true")
args=p.parse_args()
main(args.batch_size,args.baseline)
说明:我们将读取数据集、定义损失、优化器等代码放入到main()函数中,然后给main传入batch_size和baseline。使用argparse可以从命令行接收参数。add_argument()函数中,第一个参数是参数的名称,第二个是参数的类型,default是默认值,即不在命令行输入--batch_size 具体值,则会使用默认值。需要关注的是action="store_true",该参数的意思是默认baseline为False,如果在命令行中加入了--baseline,则baseline的值就为True。
结果如图所示:
没有加--batch_size,则batch_size默认为64,也就是18255/64约等于286。然后我们使用了--baseline,即默认使用resnet18模型。
由于图像分类一般考虑的衡量指标是top1和top5,下一节就是加上计算top5的代码了。
- DBA和开发同事的一些代沟(四) (r7笔记第36天)
- python获取文件所在目录和文件名,以及检索当前文件名的方法
- 数据同步中的误导(r7笔记第34天)
- java读取xml文件
- 优化算法——粒子群算法(PSO)
- Java开发画板
- Python—numpy模块下函数介绍(一)numpy.ones、empty等
- Tomcat用户权限设置
- 优化算法——模拟退火算法
- 绘制动态心形图案::R语言绘制心形图
- 物化视图中的统计信息导致的查询问题分析和修复 (r7笔记第47天)
- R语言之系统聚类(层次)分析之图谱形式完整版
- Java操作数据库Spring(1)
- python基础知识——内置数据结构(集合)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- MyBatis结果集映射
- Hadoop分布式集群环境搭建
- 使用Hadoop统计日志数据
- Mybatis在接口上使用注解配置SQL语句以及接口与xml一起使用
- 分布式计算框架MapReduce
- 安装webpack后,执行webpack -v命令时报错:SyntaxError: Block-sc
- SpringMVC数据类型转换器与国际化配置
- 分布式资源调度——YARN框架
- 在SpringMVC中使用数据验证组件——hibernate-validator
- 我的 2020 iOS BAT 面试心得
- Java操作HDFS开发环境搭建以及HDFS的读写流程
- HDFS伪分布式环境搭建
- 初识Hadoop
- SpringMVC返回JSON数据以及文件上传、过滤静态资源
- SpringMVC返回数据到视图