使用TPU训练模型
时间:2022-07-22
本文章向大家介绍使用TPU训练模型,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
如果想尝试使用Google Colab上的TPU来训练模型,也是非常方便,仅需添加6行代码。
在Colab笔记本中:修改->笔记本设置->硬件加速器 中选择 TPU
注:以下代码只能在Colab 上才能正确执行。
可通过以下colab链接测试效果《tf_TPU》:
https://colab.research.google.com/drive/1XCIhATyE1R7lq6uwFlYlRsUr5d9_-r1s
%tensorflow_version 2.x
import tensorflow as tf
print(tf.__version__)
from tensorflow.keras import *
一,准备数据
MAX_LEN = 300
BATCH_SIZE = 32
(x_train,y_train),(x_test,y_test) = datasets.reuters.load_data()
x_train = preprocessing.sequence.pad_sequences(x_train,maxlen=MAX_LEN)
x_test = preprocessing.sequence.pad_sequences(x_test,maxlen=MAX_LEN)
MAX_WORDS = x_train.max()+1
CAT_NUM = y_train.max()+1
ds_train = tf.data.Dataset.from_tensor_slices((x_train,y_train))
.shuffle(buffer_size = 1000).batch(BATCH_SIZE)
.prefetch(tf.data.experimental.AUTOTUNE).cache()
ds_test = tf.data.Dataset.from_tensor_slices((x_test,y_test))
.shuffle(buffer_size = 1000).batch(BATCH_SIZE)
.prefetch(tf.data.experimental.AUTOTUNE).cache()
二,定义模型
tf.keras.backend.clear_session()
def create_model():
model = models.Sequential()
model.add(layers.Embedding(MAX_WORDS,7,input_length=MAX_LEN))
model.add(layers.Conv1D(filters = 64,kernel_size = 5,activation = "relu"))
model.add(layers.MaxPool1D(2))
model.add(layers.Conv1D(filters = 32,kernel_size = 3,activation = "relu"))
model.add(layers.MaxPool1D(2))
model.add(layers.Flatten())
model.add(layers.Dense(CAT_NUM,activation = "softmax"))
return(model)
def compile_model(model):
model.compile(optimizer=optimizers.Nadam(),
loss=losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[metrics.SparseCategoricalAccuracy(),metrics.SparseTopKCategoricalAccuracy(5)])
return(model)
三,训练模型
#增加以下6行代码
import os
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
with strategy.scope():
model = create_model()
model.summary()
model = compile_model(model)
history = model.fit(ds_train,validation_data = ds_test,epochs = 10)
- 真正将标注文字遮盖的方法
- “AS3.0高级动画编程”学习:第四章 寻路(AStar/A星/A*)算法 (中)
- li浮动时ul高度为0,解决ul自适应高度的几种方法
- 使用正则表达式求完整路径中的文件名
- “AS3.0高级动画编程”学习:第四章 寻路(AStar/A星/A*)算法 (下)
- Centos下SFTP双机高可用环境部署记录
- as3:Function以及call,apply
- centos6下redis cluster集群部署过程
- centos6下ActiveMQ+Zookeeper消息中间件集群部署记录
- 发布一个轻量级的滑块控件
- as3:sprite作为容器使用时,最好不要指定width,height
- openssl版本升级操作记录
- 清除浮动(clearfix hack)
- Linux下环境变量配置方法梳理(.bash_profile和.bashrc的区别)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Android自定义LinearLayout布局显示不完整的解决方法
- android短信管理器SmsManager实例详解
- Android开发判断一个app应用是否在运行的方法详解
- 收割腾讯等十几个Offer后,揭秘进大厂的秘诀和Android技术面试题汇总!
- Flutter BLoC 异步通信、BlocBuilder的基本使用、BlocProvider的初探
- Android设备获取扫码枪扫描的内容与可能遇到的问题解决
- 3分钟短文:胆儿真肥!Laravel在命令行问用户要数据!
- 实战矿马:数据异常牵出的挖矿木马(.systemd-service.sh)
- leetcode之两个相同字符之间的最长子字符串
- 面试阿里被P8质问:ConcurrentHashMap真的线程安全吗?
- 腾讯云TKE-搭建prometheus监控(二)
- Qt音视频开发41-人脸识别嵌入式
- 浅析Android Studio 3.0 升级各种坑(推荐)
- Android EasyPermissions官方库高效处理权限相关教程
- 关于Android 6.0权限的动态适配详解