[深度学习] Pytorch学习(二)—— torch.nn 实践:训练分类器
时间:2019-10-20
本文章向大家介绍[深度学习] Pytorch学习(二)—— torch.nn 实践:训练分类器,主要包括[深度学习] Pytorch学习(二)—— torch.nn 实践:训练分类器使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
#%%
# 1.Loading and normalizing CIFAR10
import torch
import torchvision
import torchvision.transforms as transforms
batch_size = 16
transform = transforms.Compose( [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] )
# 对图像的预处理,用在加载数据时,当作函数传给transform参数
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#%%
import matplotlib.pyplot as plt
import numpy as np
# functions to show an image
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
print( npimg.shape )
plt.imshow(np.transpose(npimg, (1, 2, 0)))
print( np.transpose( npimg, (1, 2, 0) ).shape )
plt.show()
# get some random training images
dataiter = iter(trainloader)
# images torch.Size([16, 3, 32, 32]). labels torch.Size([16])
images, labels = dataiter.next()
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(batch_size)))
#%%
# 2.Define a Convolutional Neural Network
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = nn.DataParallel(net) # 多GPU
net.to(device) #GPU
#%%
# 3.Define a Loss Function and optimizer
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
#%%
# 4.Train the network
for epoch in range(2):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data # torch.Size([16, 3, 32, 32])
# GPU
inputs, labels = inputs.to(device), labels.to(device)
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 500 == 499:
print('[%d ,%5d] loss: %.3f' %
(epoch+1, i+1, running_loss/2000))
running_loss = 0.0
print("Finished Training")
# save trained model:
PATH = 'cifar_net.pth'
torch.save(net.state_dict(), PATH)
#%%
# 5.Test the network on the test data
# dataiter = iter(testloader)
# images, labels = dataiter.next()
# imshow(torchvision.utils.make_grid(images))
# print('GroundTruth: ',
# ''.join('%5s' % classes[labels[j]] for j in range(batch_size)))
# net = Net()
# net.load_state_dict(torch.load(PATH))
# # 输出的是能量能量越大的 是这个类的可能性越大
# outputs = net(images)
# # 按行取最大值
# _, predicted = torch.max(outputs, 1)
# print('Predicted: ',
# ''.join('%5s' % classes[predicted[j]] for j in range(batch_size)))
# Let us look at how the network performs on the whole dataset
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
# GPU
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%'
% (100 * correct / total))
# what are the classes that performed well,
# and the classes that did not perform well
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for data in testloader:
images, labels = data
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs, 1)
c = (predicted == labels).squeeze()
for i in range(batch_size):
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1
for i in range(10):
print('Accuracy of %5s : %2d %%' % (
classes[i], 100 * class_correct[i] / class_total[i]))
原文地址:https://www.cnblogs.com/importGPX/p/11706720.html
- Silverlight:针式打印机文字模糊的改善办法
- 大数据和云计算技术周报:NoSQL特辑
- 常用业务接口界面化 in python flask
- 打印机设置(PrintDialog)、页面设置(PageSetupDialog) 及 RDLC报表如何选择指定打印机
- 区块链推动支付革命
- MySQL常见的库操作,表操作,数据操作集锦及一些注意事项
- nohup命令
- 跨浏览器的剪贴板访问解决方案
- 装逼必备:大型分布式网站术语分析
- 年前爆炸一波!小程序视频功能来了!
- ubuntu13.04环境hadoop1.2.1单机模式安装
- silverlight:telerik RadControls中RadGridView的一个Bug及解决办法
- scope引起的问题
- JS正则表达式常用函数汇总
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 搭建高可用的Replication集群归档大量的冷数据
- Python 技术篇-文件操控:文件的移动和复制
- Python发邮件脚本,Python调用163邮箱SMTP服务实现邮件群发
- 为PXC集群引入Mycat并构建完整的高可用集群架构
- Python3 字典
- 安装Percona Server数据库(in CentOS 8)
- Python 基础篇-正斜杠("/")和反斜杠("")的用法
- 在CentOS8下搭建PXC集群
- Python 基础篇-相对路径、绝对路径的写法
- Python3 元组
- 关于MySQL的基准测试
- Python 技术篇-操作excel,对excel进行读取和写入
- Mycat 整合 MySQL 8.x 踩坑实践
- Python 技术篇-xlwt库不新建,直接读取已存在的excel并写入
- Python3 列表