奉献pytorch 搭建 CNN 卷积神经网络训练图像识别的模型,配合numpy 和matplotlib 一起使用调用 cuda GPU进行加速训练
时间:2019-08-24
本文章向大家介绍奉献pytorch 搭建 CNN 卷积神经网络训练图像识别的模型,配合numpy 和matplotlib 一起使用调用 cuda GPU进行加速训练,主要包括奉献pytorch 搭建 CNN 卷积神经网络训练图像识别的模型,配合numpy 和matplotlib 一起使用调用 cuda GPU进行加速训练使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
1、Torch构建简单的模型
# coding:utf-8 import torch class Net(torch.nn.Module): def __init__(self,img_rgb=3,img_size=32,img_class=13): super(Net, self).__init__() self.conv1 = torch.nn.Sequential( torch.nn.Conv2d(in_channels=img_rgb, out_channels=img_size, kernel_size=3, stride=1,padding= 1), # torch.nn.ReLU(), torch.nn.MaxPool2d(2), # torch.nn.Dropout(0.5) ) self.conv2 = torch.nn.Sequential( torch.nn.Conv2d(28, 64, 3, 1, 1), torch.nn.ReLU(), torch.nn.MaxPool2d(2) ) self.conv3 = torch.nn.Sequential( torch.nn.Conv2d(64, 64, 3, 1, 1), torch.nn.ReLU(), torch.nn.MaxPool2d(2) ) self.dense = torch.nn.Sequential( torch.nn.Linear(64 * 3 * 3, 128), torch.nn.ReLU(), torch.nn.Linear(128, img_class) ) def forward(self, x): conv1_out = self.conv1(x) conv2_out = self.conv2(conv1_out) conv3_out = self.conv3(conv2_out) res = conv3_out.view(conv3_out.size(0), -1) out = self.dense(res) return out CUDA = torch.cuda.is_available() model = Net(1,28,13) print(model) optimizer = torch.optim.Adam(model.parameters()) loss_func = torch.nn.MultiLabelSoftMarginLoss()#nn.CrossEntropyLoss() if CUDA: model.cuda() def batch_training_data(x_train,y_train,batch_size,i): n = len(x_train) left_limit = batch_size*i right_limit = left_limit+batch_size if n>=right_limit: return x_train[left_limit:right_limit,:,:,:],y_train[left_limit:right_limit,:] else: return x_train[left_limit:, :, :, :], y_train[left_limit:, :]
2、奉献训练过程的代码
# coding:utf-8 import time import os import torch import numpy as np from data_processing import get_DS from CNN_nework_model import cnn_face_discern_model from torch.autograd import Variable from use_torch_creation_model import optimizer, model, loss_func, batch_training_data,CUDA from sklearn.metrics import accuracy_score os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' st = time.time() # 获取训练集与测试集以 8:2 分割 x_,y_,y_true,label = get_DS() label_number = len(label) x_train,y_train = x_[:960,:,:,:].reshape((960,1,28,28)),y_[:960,:] x_test,y_test = x_[960:,:,:,:].reshape((340,1,28,28)),y_[960:,:] y_test_label = y_true[960:] print(time.time() - st) print(x_train.shape,x_test.shape) batch_size = 100 n = int(len(x_train)/batch_size)+1 for epoch in range(100): global loss for batch in range(n): x_training,y_training = batch_training_data(x_train,y_train,batch_size,batch) batch_x,batch_y = Variable(torch.from_numpy(x_training)).float(),Variable(torch.from_numpy(y_training)).float() if CUDA: batch_x=batch_x.cuda() batch_y=batch_y.cuda() out = model(batch_x) loss = loss_func(out, batch_y) optimizer.zero_grad() loss.backward() optimizer.step() # 测试精确度 if epoch%9==0: global x_test_tst if CUDA: x_test_tst = Variable(torch.from_numpy(x_test)).float().cuda() y_pred = model(x_test_tst) y_predict = np.argmax(y_pred.cpu().data.numpy(),axis=1) acc = accuracy_score(y_test_label,y_predict) print("loss={} aucc={}".format(loss.cpu().data.numpy(),acc))
3、总结
通过博主通过TensorFlow、keras、pytorch进行训练同样的模型同样的图像数据,结果发现,pyTorch快了很多倍,特别是在导入模型的时候比TensorFlow快了很多。合适部署接口和集成在项目中。
原文地址:https://www.cnblogs.com/wuzaipei/p/11406450.html
- [大数据之Spark]——Transformations转换入门经典实例
- 字符串的排列
- 斐波那契额数列及青蛙跳台阶问题
- 在Mac OS X上配置Apache2
- 合并两个排序的链表
- 还有5天,你的比特币最重要的孩子UB-UBTC 可能就永远不属于你了
- Spark SQL 用户自定义函数UDF、用户自定义聚合函数UDAF 教程(Java踩坑教学版)
- Webpack多入口文件、热更新等体验
- 从hello world 解析程序运行机制
- 万达大量员工“被”辞职?曲德君回应:万达网科没有倒
- iOS Programming – 触摸事件处理(2)
- 洋葱海外仓融资2亿元 官网启用msyc.cc域名
- Webpack单元测试,e2e测试
- [看图说话] 基于Spark UI性能优化与调试——初级篇
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Android日期和时间选择器实现代码
- Android开发实现ImageView加载摄像头拍摄的大图功能
- Android开发实现的Intent跳转工具类实例
- Android开发中的文件操作工具类FileUtil完整实例
- Android开发中超好用的正则表达式工具类RegexUtil完整实例
- Android ijkplayer的使用方法解析
- Android开发实现查询远程服务器的工具类QueryUtils完整实例
- 解决android studio 3.0 加载项目过慢问题–maven仓库选择
- Android实现朋友圈点赞列表
- Kotlin基本类型自动装箱一点问题剖析
- Kotlin入门教程之开发环境搭建
- Android:Field can be converted to a local varible.的解决办法
- Android使用多线程进行网络聊天室通信
- android实现banner轮播图无限轮播效果
- Android CheckBox中设置padding无效解决办法