Caffe-Python-自定义网络 原
时间:2022-06-19
本文章向大家介绍Caffe-Python-自定义网络 原,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
这里我们用一个例子先来体验一下
- 首先定义一下我们的环境变量 $PYTHONPATH,我这儿是Windows开发环境,至于Windows Caffe怎么编译由读者自己下去搞定
我使用的控制台是 Windows PowerShell
添加环境变量
$env:PATHPATH="F:caffe-pythonpython;F:caffe-windowswindowsinstallpython"
这里F:caffe-pythonpython
是我的新Layer的路径F:caffe-windowswindowsinstallpython
是我的Caffe编译以后install的路径
编写自己的TestLayer
import caffe
import numpy as np
class TestLayer(caffe.Layer):
def setup(self, bottom, top):
if len(bottom) != 1:
raise Exception("Need two inputs to compute distance.")
def reshape(self, bottom, top):
print("-----------------1---------------------")
top[0].reshape(1)
def forward(self, bottom, top):
top[0].data[...] = bottom[0].data
print("-----------------2---------------------")
def backward(self, top, propagate_down, bottom):
bottom[...].data=top[0].data
pass
- 官方给出的一个例子
import caffe
import numpy as np
class EuclideanLossLayer(caffe.Layer):
def setup(self, bottom, top):
# 输入检查
if len(bottom) != 2:
raise Exception("Need two inputs to compute distance.")
def reshape(self, bottom, top):
# 输入检查
if bottom[0].count != bottom[1].count:
raise Exception("Inputs must have the same dimension.")
# 初始化梯度差分zeros_like函数的意义是创建一个与参数等大小的全0矩阵
self.diff = np.zeros_like(bottom[0].data, dtype=np.float32)
# loss 输出(loss是一个标量)
top[0].reshape(1)
#前向传播(计算loss bottom[0].data是第一个输入 bottom[1].data是第二个输入)
#注意:前向传播是输出top
def forward(self, bottom, top):
self.diff[...] = bottom[0].data - bottom[1].data
top[0].data[...] = np.sum(self.diff**2) / bottom[0].num / 2.
#后向传播
#注意:前向传播是输出到bottom
def backward(self, top, propagate_down, bottom):
for i in range(2):
if not propagate_down[i]:
continue
if i == 0:
sign = 1
else:
sign = -1
#误差向后扩散
bottom[i].diff[...] = sign * self.diff / bottom[i].num
编写完我们的ayers以后写出网络结构
name: "TEST"
layer {
name: "cifar"
type: "Data"
top: "data"
top: "label"
include {
phase: TRAIN
}
transform_param {
mean_file: "examples/cifar10/Release/cifar10/mean.binaryproto"
}
data_param {
source: "examples/cifar10/Release/cifar10/cifar10_train_lmdb"
batch_size: 100
backend: LMDB
}
}
layer {
name: "cifar"
type: "Data"
top: "data"
top: "label"
include {
phase: TEST
}
transform_param {
mean_file: "examples/cifar10/Release/cifar10/mean.binaryproto"
}
data_param {
source: "examples/cifar10/Release/cifar10/cifar10_test_lmdb"
batch_size: 100
backend: LMDB
}
}
layer {
name: "test1"
type: "Python"
bottom: "data"
top: "test1"
python_param {
module: "test_layer"
layer: "Test_Layer"
}
}
可视化我们的网络结构以后如图
编写solver
net: "F:/caffe-python/python/test_layer.prototxt"
base_lr: 0.001
lr_policy: "fixed"
max_iter: 10
solver_mode: CPU
接下来在powershell里面去启动caffe 先cd到caffe所在的目录 我的目录是这样的
cd F:Smart_Classroom3rdpartyALLPLATHFORMcaffe-windowswindowsexamplescifar10Release
然后执行caffe
./caffe.exe train --solver=F:/caffe-python/python/test_python_layer_solver.prototxt
如下图:
在后向和前向传播的过程中我们成功的调用了两个print 至此,编写自己的Caffe层就成功了
PS: 编写的时候严格注意路径否则会出现以下报错
- JavaScript中removeEventListener()使用注意事项
- dubbox REST服务使用fastjson替换jackson
- struts2(二)之配置文件详解与结果视图
- CSS魔法堂:你真的懂text-align吗?
- 黑客可以利用传感器数据来破解手机密码
- spring-boot 速成(3) actuator
- 利用sharding-jdbc分库分表
- 利用sharding-jdbc分库分表
- 协议森林17 我和你的悄悄话 (SSL/TLS协议)
- spring-boot 速成(1) helloworld
- spring-boot 速成(1) helloworld
- 协议森林16 小美的桌号(DHCP协议)
- struts2(一)之初识struts2
- AI聊天机器人备受青睐 专家呼吁少卖萌
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 跟牛老师一起学WEBGIS——WEBGIS基础(地图切片)
- Go语言 | 并发设计中的同步锁与waitgroup用法
- LeetCode 99 | 如何不用递归遍历二叉搜索树?MT方法给你答案
- 以攻击者角度学习某风控设备指纹产品
- 高并发系统三大利器之缓存
- 前端测试题:(解析)js中关于类(class)的继承的说法,下面错误的是?
- 程序员深夜惨遭老婆鄙视,原因竟是CAS原理太简单?| 每一张图都力求精美
- MySQL数据延迟跳动的问题分析
- Python GUI项目实战(八)修改密码功能的实现
- Prometheus监控神器-Alertmanager篇(3)
- Prometheus监控神器-Alertmanager篇(4)
- 71-STM32+ESP8266+AIR202基本控制篇-移植使用-移植微信小程序MQTT底层包到自己的工程项目
- 目标检测 | Anchor free之CornerNet网络深度解析
- 手把手教你 3 分钟搞定个人网站 http 免费升级到 https
- 设计模式(四):通过做蛋糕理解构建模式及Android中的变种