mxnet框架样本,使用C++接口
时间:2022-05-06
本文章向大家介绍mxnet框架样本,使用C++接口,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
哇塞,好久么有跟进mxnet啦,python改版了好多好多啊,突然发现C++用起来才是最爽的. 贴一个mxnet中的C++Example中的mlp网络和实现,感觉和python对接毫无违和感。真是一级棒呐.
//
// Created by xijun1 on 2017/12/8.
//
#include <iostream>
#include <vector>
#include <string>
#include <mxnet/mxnet-cpp/MxNetCpp.h>
#include <mxnet/mxnet-cpp/op.h>
namespace mlp{
template < typename T , typename U >
class MLP{
public:
static mx_float OutputAccuracy(mx_float* pred, mx_float* target) {
int right = 0;
for (int i = 0; i < 128; ++i) {
float mx_p = pred[i * 10 + 0];
float p_y = 0;
for (int j = 0; j < 10; ++j) {
if (pred[i * 10 + j] > mx_p) {
mx_p = pred[i * 10 + j];
p_y = j;
}
}
if (p_y == target[i]) right++;
}
return right / 128.0;
}
static bool train(T x , U y);
static bool predict(T x);
static bool net() {
using mxnet::cpp::Symbol;
using mxnet::cpp::NDArray;
Symbol x = Symbol::Variable("X");
Symbol y = Symbol::Variable("label");
std::vector<std::int32_t> shapes({512 , 10});
//定义一个两层的网络. wx + b
Symbol weight_0 = Symbol::Variable("weight_0");
Symbol biases_0 = Symbol::Variable("biases_0");
Symbol fc_0 = mxnet::cpp::FullyConnected("fc_0",x,weight_0,biases_0
,512);
Symbol output_0 = mxnet::cpp::LeakyReLU("relu_0",fc_0,mxnet::cpp::LeakyReLUActType::kLeaky);
Symbol weight_1 = Symbol::Variable("weight_1");
Symbol biases_1 = Symbol::Variable("biases_1");
Symbol fc_1 = mxnet::cpp::FullyConnected("fc_1",output_0,weight_1,biases_1,10);
Symbol output_1 = mxnet::cpp::LeakyReLU("relu_1",fc_1,mxnet::cpp::LeakyReLUActType::kLeaky);
Symbol pred = mxnet::cpp::SoftmaxOutput("softmax",output_1,y); //目标函数,loss函数
//定义使用计算驱动
mxnet::cpp::Context ctx = mxnet::cpp::Context::cpu( 0);
NDArray arr_x(mxnet::cpp::Shape( 128 , 28 ) , ctx , false);
NDArray arr_y(mxnet::cpp::Shape(128) , ctx , false );
//定义输入数据
std::shared_ptr< mx_float > aptr_x(new mx_float[128*28] , [](mx_float* aptr_x){ delete [] aptr_x ;});
std::shared_ptr< mx_float > aptr_y(new mx_float[128] , [](mx_float * aptr_y){ delete [] aptr_y ;});
//初始化数据
for(int i=0 ; i<128 ; i++){
for(int j=0;j<28 ; j++){
//定义x
aptr_x.get()[i*28+j]= i % 10 +0.1f;
}
//定义y
aptr_y.get()[i]= i % 10;
}
//将数据转换到NDArray中
arr_x.SyncCopyFromCPU(aptr_x.get(),128*28);
arr_x.WaitToRead();
arr_y.SyncCopyFromCPU(aptr_y.get(),128);
arr_y.WaitToRead();
//定义各个层参数的数组
NDArray arr_w_0(mxnet::cpp::Shape(512,28),ctx, false);
NDArray arr_b_0(mxnet::cpp::Shape( 512 ),ctx,false);
NDArray arr_w_1(mxnet::cpp::Shape(10 , 512 ) , ctx , false);
NDArray arr_b_1(mxnet::cpp::Shape( 10 ) , ctx , false);
//初始化权重参数
arr_w_0 = 0.01f;
arr_b_1 = 0.01f;
arr_w_1 = 0.01f;
arr_b_1 = 0.01f;
//求解梯度
NDArray arr_w_0_g(mxnet::cpp::Shape( 512 , 28 ),ctx, false);
NDArray arr_b_0_g(mxnet::cpp::Shape( 512 ) , ctx , false);
NDArray arr_w_1_g(mxnet::cpp::Shape( 10 , 512 ) , ctx , false);
NDArray arr_b_1_g(mxnet::cpp::Shape( 10 ) , ctx , false);
//将数据绑定到网络图中.
//输入数据参数
std::vector< NDArray > bind_data;
bind_data.push_back( arr_x );
bind_data.push_back( arr_w_0 );
bind_data.push_back( arr_b_0 );
bind_data.push_back( arr_w_1 );
bind_data.push_back( arr_b_1 );
bind_data.push_back( arr_y );
//所有的梯度参数
std::vector< NDArray > arg_grad_store;
arg_grad_store.push_back( NDArray() ); //不需要输入的梯度
arg_grad_store.push_back( arr_w_0_g );
arg_grad_store.push_back( arr_b_0_g );
arg_grad_store.push_back( arr_w_1_g );
arg_grad_store.push_back( arr_b_1_g );
arg_grad_store.push_back( NDArray() ); //不需要输出 loss 的梯度
//如何操作梯度.
std::vector< mxnet::cpp::OpReqType > grad_req_type;
grad_req_type.push_back(mxnet::cpp::kNullOp);
grad_req_type.push_back(mxnet::cpp::kWriteTo);
grad_req_type.push_back(mxnet::cpp::kWriteTo);
grad_req_type.push_back(mxnet::cpp::kWriteTo);
grad_req_type.push_back(mxnet::cpp::kWriteTo);
grad_req_type.push_back(mxnet::cpp::kNullOp);
//定义一个状态数组
std::vector< NDArray > aux_states;
std::cout<<" make the Executor"<<std::endl;
std::shared_ptr<mxnet::cpp::Executor > executor
= std::make_shared<mxnet::cpp::Executor>(
pred,
ctx,
bind_data,
arg_grad_store,
grad_req_type,
aux_states );
//训练
std::cout<<" Training "<<std::endl;
int max_iters = 20000; //最大迭代次数
mx_float learning_rate = 0.0001; //学习率
for (int iter = 0; iter < max_iters ; ++iter) {
executor->Forward(true);
if(iter % 100 == 0){
std::vector<NDArray> & out = executor->outputs;
std::shared_ptr<mx_float> tp_x( new mx_float[128*28] ,
[](mx_float * tp_x){ delete [] tp_x ;});
out[0].SyncCopyToCPU(tp_x.get(),128*10);
NDArray::WaitAll();
std::cout<<"epoch "<<iter<<" "<<"Accuracy: "<< OutputAccuracy(tp_x.get() , aptr_y.get())<<std::endl;
}
//依据梯度更新参数
executor->Backward();
for (int i = 1; i <5 ; ++i) {
bind_data[i] -= arg_grad_store[i]*learning_rate;
}
NDArray::WaitAll();
}
}
static bool SetDriver();
};
template <typename T , typename U >
bool MLP<T,U>::SetDriver() {
return true;
}
template <typename T , typename U >
bool MLP<T,U>::train(T x, U y) {
return true;
}
template <typename T , typename U >
bool MLP<T,U>::predict(T x) {
return true;
}
}
int main(int argc , char * argv[]){
mlp::MLP<mx_float ,mx_uint>::net();
MXNotifyShutdown();
return 0;
}
结果:
poch 18900 Accuracy: 0.703125 epoch 19000 Accuracy: 0.703125 epoch 19100 Accuracy: 0.703125 epoch 19200 Accuracy: 0.703125 epoch 19300 Accuracy: 0.703125 epoch 19400 Accuracy: 0.703125 epoch 19500 Accuracy: 0.703125 epoch 19600 Accuracy: 0.703125 epoch 19700 Accuracy: 0.703125 epoch 19800 Accuracy: 0.703125 epoch 19900 Accuracy: 0.703125
- android 减少图片出现oom错误
- android分包方案
- 系统负载能力浅析
- 微软正式发布了Microsoft.Bcl.Async
- parcel和parcelable
- Windows Phone 7 WebBrowser 中文乱码问题
- Java并发包类总览
- 作业调度框架 Quartz.NET 2.0 beta 发布
- 系统捕获异常并发送到服务器
- 当调用GetAuthorizationGroups() 的错误-“试图访问卸载的应用程序域“(Exception from HRESULT: 0x80131014)解决方案
- WCF 4.0路由服务Routing Service
- ExpandableListView简单应用及listview模拟ExpandableListView
- 文件句柄与文件描述符
- android GifView分享
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- SpringMVC系列 MVC设计模式介绍+ SpringMVC的作用及其基本使用+组件解析+注解解析
- Spring系列之事务的控制 注解实现+xml实现+事务的隔离等级
- Greenplum集群扩容总结
- Leetcode刷题 237. 删除链表中的节点 两行代码实现
- Leetcode刷题 206. 反转链表 递归迭代两种方法实现
- MySQL索引和查询优化
- Elasticsearch:Index 生命周期管理入门
- springboot面试杀手锏-自动配置原理
- flink 1.11.2 学习笔记(1)-wordCount
- 我是如何开发维护8千多行代码组件的
- 我对JS延迟异步脚本的思考
- 大数据表查询优化 - 表分区
- 日志系统rsync和日志切割logrotate-Linux每日一练(9)
- Canvas 绘制点线相交
- Canvas监测雷达