PyTorch中基于TPU的FastAI多类图像分类
作者 | PRUDHVI VARMA
编译 | VK
来源 | Analytics Indiamag
计算机视觉因其广泛的应用而成为人工智能领域中最具发展趋势的子领域之一。在某些领域,甚至它们在快速准确地识别图像方面超越了人类的智能。
在本文中,我们将演示最流行的计算机视觉应用之一-多类图像分类问题,使用fastAI库和TPU作为硬件加速器。TPU,即张量处理单元,可以加速深度学习模型的训练过程。
「本文涉及的主题」:
- 多类图像分类
- 常用的图像分类模型
- 使用TPU并在PyTorch中实现
多类图像分类
我们使用图像分类来识别图像中的对象,并且可以用于检测品牌logo、对对象进行分类等。但是这些解决方案有一个局限性,即只能识别对象,但无法找到对象的位置。但是与目标定位相比,图像分类模型更容易实现。
图像分类的常用模型
我们可以使用VGG-16/19,Resnet,Inception v1,v2,v3,Wideresnt,Resnext,DenseNet等,它们是卷积神经网络的高级变体。这些是流行的图像分类网络,并被用作许多最先进的目标检测和分割算法的主干。
基于FasAI库和TPU硬件的图像分类
我们将在以下方面开展这项工作步骤:
1.选择硬件加速器
这里我们使用Google Colab来实现。要在Google Colab中使用TPU,我们需要打开edit选项,然后打开notebook设置,并将硬件加速器更改为TPU。
通过运行下面的代码片段,你可以检查你的Notebook是否正在使用TPU。
import os
assert os.environ['COLAB_TPU_ADDR']
Path = 'grpc://'+os.environ['COLAB_TPU_ADDR']
print('TPU Address:', Path)
![](http://qiniu.aihubs.net/Screenshot -158.png)
2.加载FastAI库
在下面的代码片段中,我们将导入fastAI库。
from fastai.vision import *
from fastai.metrics import error_rate, accuracy
3.定制数据集
在下面的代码片段中,你还可以尝试使用自定义数据集。
PATH = '/content/images/dataset'
np.random.seed(24)
tfms = get_transforms(do_flip=True)
data = ImageDataBunch.from_folder(PATH, valid_pct=0.2, ds_tfms=tfms, size=299, bs=16).normalize(imagenet_stats)
data.show_batch(rows=4, figsize=(8, 8))
4.加载预训练的深度学习模型
在下面的代码片段中,我们将导入VGG-19 batch_normalisation模型。我们将把它作为fastAI的计算机视觉学习模块的一个实例。
learn = cnn_learner(data, models.vgg19_bn, metrics=accuracy)
5.训练模型
在下面的代码片段中,我们尝试使用一个epoch。
learn.fit_one_cycle(1)
在输出中,我们可以看到我们得到了0.99的准确度,它花了1分2秒。
在下面的代码片段中,我们使用混淆矩阵显示结果。
con_matrix = ClassificationInterpretation.from_learner(learn)
con_matrix.plot_confusion_matrix()
6.利用模型进行预测
在下面的代码片段中,我们可以通过在test_your_image中给出图像的路径来测试我们自己的图像。
test_your_image='/content/images (3).jpg'
test = open_image(test_your_image)
test.show()
在下面的代码片段中,我们可以得到输出张量及其所属的类。
learn.predict(test)
正如我们在上面的输出中看到的,模型已经预测了输入图像的类标签,它属于“flower”类别。
结论
在上面的演示中,我们使用带TPU的fastAI库和预训练VGG-19模型实现了一个多类的图像分类。在这项任务中,我们在对验证数据集进行分类时获得了0.99的准确率。
原文链接:https://analyticsindiamag.com/fastai-with-tpu-in-pytorch-for-multiclass-image-classification/
- 手把手教你使用sklearn快速入门机器学习
- 【 关关的刷题日记48】Leetcode 58. Length of Last Word
- RESTful API 设计指南
- 洛谷P1043 数字游戏
- 使用“空”对象替代引用是否为空判断
- 真是绝了!史上最详细的Jupyter Notebook入门教程
- 10.socket网络编程
- BZOJ1269: [AHOI2006]文本编辑器editor
- 开发人员为何需要企业服务总线?
- 搭建Visual Studio Code+Python开发环境1.对象简介2. 搭建步骤3.小结
- 洛谷P3835 【模板】可持久化平衡树
- 17.HTML
- 洛谷P2925 [USACO08DEC]干草出售Hay For Sale
- Numpy 修炼之道 (13)—— 将python函数向量化
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 致敬Vue3: 1.1万字从零解读Vue3.0源码响应式系统
- APP自动化测试系列之Desired Capabilities详解
- Kafka分区分配策略(Partition Assignment Strategy)
- 内网渗透-代理篇(一)
- java学习应用篇|逃不掉的HelloWorld
- java学习原理篇|java程序运行套路
- 架构师成长之路系列(二)
- 前端性能优化 24 条建议(2020)
- 【Flutter 实战】大量复杂数据持久化
- GBDT+LR:Practical Lessons from Predicting Clicks on Ads
- 告别setState()! 优雅的UI与Model绑定 Flutter DataBus使用~
- k8s etcd 的实现原理
- iOS动态View的探索
- 安卓开发的瑞士军刀“Retrofit2框架”
- R语言中的广义线性模型(GLM)和广义相加模型(GAM):多元(平滑)回归分析保险资金投资组合信用风险敞口