深度学习与R语言
对于R语言用户来说,深度学习还没有生产级的解决方案(除了MXNET)。这篇文章介绍了R语言的Keras接口,以及如何使用它来执行图像分类。文章结尾会通过提供一些代码片段显示Keras的直观和强大
Tensorflow
去年1月,R语言中的Tensorflow 发布了,它提供了从R语言中获得的Tensorflow API的方法。这是很重要的,因为Tensorflow是最受欢迎的深度学习库。然而,对于大多数R语言用户来说,R语言的Tensorflow接口和R语言并不是很像。下面是训练模型的代码块。
cross_entropy <- tf$reduce_mean(-tf$reduce_sum(y_* tf$log(y_conv), reduction_indices=1L))
train_step <- tf$train$AdamOptimizer(1e-4)$minimize(cross_entropy)
correct_prediction <- tf$equal(tf$argmax(y_conv,1L), tf$argmax(y_,1L))
accuracy <- tf$reduce_mean(tf$cast(correct_prediction, tf$float32))
sess$run(tf$global_variables_initializer())
for (iin 1:20000) {
batch <- mnist$train$next_batch(50L)
if (i%% 100 == 0) {
train_accuracy <- accuracy$eval(feed_dict= dict(
x= batch[[1]], y_= batch[[2]], keep_prob= 1.0))
cat(sprintf("step %d, training accuracy %gn", i, train_accuracy))
}
train_step$run(feed_dict= dict(
x= batch[[1]], y_= batch[[2]], keep_prob= 0.5))
}
test_accuracy <- accuracy$eval(feed_dict= dict(
x= mnist$test$images, y_= mnist$test$labels, keep_prob= 1.0))
cat(sprintf("test accuracy %g", test_accuracy))
除非你熟悉Tensorflow,,否则你可能不清楚发生了什么。Github的快速搜索发现使用tensorflow为R语言提供的代码不到100个。
Keras
所有这一切都将随着Keras和R语言而改变。
Keras是一个用于实验的高级神经网络API,可以在Tensorflow上运行。Keras是科学家们喜欢使用的数据。Keras越来越受欢迎,并在越来越多的平台上得到支持,包括Tensorflow,CNTK,Apple的CoreML,Theano。在深度学习中越来越重要。
举一个简单的例子,在Keras中训练模型的代码如下:
model_top%>% fit(
x= train_x, y= train_y,
epochs=epochs,
batch_size=batch_size,
validation_data=valid)
用Keras进行图像分类
让我告诉你如何使用R语言、Keras和Tensorflow构建深度学习模型。你会发现一个Github repo在https://github.com/rajshah4/image_keras/,其中包含你需要的代码和数据。通过R notebook(和Python notebooks)构建一个图像分类器,可以很容易地应用到其他图像上。用于构建深度学习工作的高级方法包括:
增加的数据
使用预先训练的网络的瓶颈特征
对预先训练的网络顶层进行微调
保存模型的权重
Keras的代码片段
Keras的R语言接口确实可以很容易地在R语言中构建深度学习模型,这里有基于构建图像分类器一些代码片段,以说明R语言中Keras的直观和有用
加载folder:
train_generator <- flow_images_from_directory(train_directory, generator= image_data_generator(), target_size= c(img_width, img_height), color_mode= "rgb",
class_mode= "binary", batch_size= batch_size, shuffle= TRUE,
seed= 123)
定义一个简单的卷积神经网络:
model <- keras_model_sequential()
model%>%
layer_conv_2d(filter = 32, kernel_size= c(3,3), input_shape= c(img_width, img_height,3))%>%
layer_activation("relu")%>%
layer_max_pooling_2d(pool_size= c(2,2))%>%
layer_conv_2d(filter = 32, kernel_size= c(3,3))%>%
layer_activation("relu")%>%
layer_max_pooling_2d(pool_size= c(2,2))%>%
layer_conv_2d(filter = 64, kernel_size= c(3,3))%>%
layer_activation("relu")%>%
layer_max_pooling_2d(pool_size= c(2,2))%>%
layer_flatten()%>%
layer_dense(64)%>%
layer_activation("relu")%>%
layer_dropout(0.5)%>%
layer_dense(1)%>%
layer_activation("sigmoid")
增加数据:
augment <- image_data_generator(rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=TRUE)
加载预先训练的网络:
model_vgg <- application_vgg16(include_top= FALSE, weights= "imagenet")
保存模型权重:
save_model_weights_hdf5(model_ft,'finetuning_30epochs_vggR.h5', overwrite= TRUE)
R语言中的Keras接口使R语言用户更容易构建和细化深度学习模型。不再强迫使用Python构建、精炼和测试深度学习模型。这应该向对使用python有点担心的受众开放。
首先,您可以应用我的repo,启动RStudio(或者您选择的IDE),然后使用Keras构建一个简单的分类器。
本文为编译文章,作者Rajiv Shah,原网址为
http://projects.rajivshah.com/blog/2017/06/04/deeplearningR/
- 两个四字母域名均以五位数被交易
- Flash/Flex学习笔记(15):FMS 3.5之远程共享对象(Remote Shared Object)
- Android Fragment完全解析
- Centos下堡垒机Jumpserver V3.0环境部署完整记录(2)-配置篇
- Flash/Flex学习笔记(53):利用FMS快速创建一个文本聊天室
- 28家银行用户体验调研报告:洞见银行业的“进化论”
- 性能计数器数据收集服务
- SQL SERVER 内存分配及常见内存问题 DMV查询
- 6 利用Docker .NET应用程序模板制作您的容器应用程序(第2部分)
- Mesos+Zookeeper+Marathon的Docker管理平台部署记录(1)
- git review报错一例
- Nginx采用https加密访问后出现的问题
- 对比git rm和rm的使用区别
- Gerrit日常操作命令收集
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Spring 基于注解(annotation)的配置之@Required注解
- 由一个系统激活工具引起的一次简单测试
- Golang channel 快速入门
- 潘石屹用Python解决100个问题 | 素数
- Spring 自动装配模式之构造函数装配方式
- 安全狗 {safedog} 最新版注入bypass
- C语言定时关机小程序
- 深入k8s:Pod对象中重要概念及用法
- Golang语言排序的几种方式
- 性能分析(1)- Java 进程导致 CPU 使用率升高,问题怎么定位?
- 安全服务之安全基线及加固(三)Apache篇
- 使用docsify来管理文献
- Cypress系列(41)- Cypress 的测试报告
- SSRF绕过
- 性能测试必备知识(6)- 如何查看“CPU 上下文切换”