深度学习Matlab工具箱代码注释之cnnbp.m
时间:2022-04-24
本文章向大家介绍深度学习Matlab工具箱代码注释之cnnbp.m,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
%%=========================================================================
%函数名称:cnnbp()
%输入参数:net,呆训练的神经网络;y,训练样本的标签,即期望输出
%输出参数:net,经过BP算法训练得到的神经网络
%主要功能:通过BP算法训练神经网络参数
%实现步骤:1)将输出的残差扩展成与最后一层的特征map相同的尺寸形式
% 2)如果是卷积层,则进行上采样
% 3)如果是下采样层,则进行下采样
% 4)采用误差传递公式对灵敏度进行反向传递
%注意事项:1)从最后一层的error倒推回来deltas,和神经网络的BP十分相似,可以参考“UFLDL的反向传导算法”的说明
% 2)在fvd里面保存的是所有样本的特征向量(在cnnff.m函数中用特征map拉成的),所以这里需要重新换回来特征map的形式,
% d保存的是delta,也就是灵敏度或者残差
% 3)net.o .* (1 - net.o))代表输出层附加的非线性函数的导数,即sigm函数的导数
%%=========================================================================
function net = cnnbp(net, y)
n = numel(net.layers); %网络层数
net.e = net.o - y; %实际输出与期望输出之间的误差
net.L = 1/2* sum(net.e(:) .^ 2) / size(net.e, 2); %代价函数,采用均方误差函数作为代价函数
net.od = net.e .* (net.o .* (1 - net.o)); %输出层的灵敏度或者残差,(net.o .* (1 - net.o))代表输出层的激活函数的导数
net.fvd = (net.ffW' * net.od); %残差反向传播回前一层,net.fvd保存的是残差
if strcmp(net.layers{n}.type, 'c') %只有卷积层采用sigm函数
net.fvd = net.fvd .* (net.fv .* (1 - net.fv)); %net.fv是前一层的输出(未经过simg函数),作为输出层的输入
end
%%%%%%%%%%%%%%%%%%%%将输出的残差扩展成与最后一层的特征map相同的尺寸形式%%%%%%%%%%%%%%%%%%%%
sa = size(net.layers{n}.a{1}); %最后一层特征map的大小。这里的最后一层都是指输出层的前一层
fvnum = sa(1) * sa(2); %因为是将最后一层特征map拉成一条向量,所以对于一个样本来说,特征维数是这样
for j = 1 : numel(net.layers{n}.a) %最后一层的特征map的个数
net.layers{n}.d{j} = reshape(net.fvd(((j - 1) * fvnum + 1) : j * fvnum, :), sa(1), sa(2), sa(3));
end
for l = (n - 1) : -1 : 1 %对于输出层前面的层(与输出层计算残差的方式不同)
if strcmp(net.layers{l}.type, 'c') %如果是卷积层,则进行上采样
for j = 1 : numel(net.layers{l}.a) %该层特征map的个数
%%=========================================================================
%主要功能:卷积层的灵敏度误差传递
%注意事项:1)net.layers{l}.d{j} 保存的是 第l层 的 第j个 map 的 灵敏度map。 也就是每个神经元节点的delta的值
% expand的操作相当于对l+1层的灵敏度map进行上采样。然后前面的操作相当于对该层的输入a进行sigmoid求导
% 这条公式请参考 Notes on Convolutional Neural Networks
%%=========================================================================
net.layers{l}.d{j} = net.layers{l}.a{j} .* (1 - net.layers{l}.a{j}) .* (expand(net.layers{l + 1}.d{j}, [net.layers{l + 1}.scale net.layers{l + 1}.scale 1]) / net.layers{l + 1}.scale ^ 2);
end
elseif strcmp(net.layers{l}.type, 's') %如果是下采样层,则进行下采样
%%=========================================================================
%主要功能:下采样层的灵敏度误差传递
%注意事项:1)这条公式请参考 Notes on Convolutional Neural Networks
%%=========================================================================
for i = 1 : numel(net.layers{l}.a) %第i层特征map的个数
z = zeros(size(net.layers{l}.a{1}));
for j = 1 : numel(net.layers{l + 1}.a) %第l+1层特征map的个数
z = z + convn(net.layers{l + 1}.d{j}, rot180(net.layers{l + 1}.k{i}{j}), 'full');
end
net.layers{l}.d{i} = z;
end
end
end
%%=========================================================================
%主要功能:计算梯度
%实现步骤:
%注意事项:1)这里与Notes on Convolutional Neural Networks中不同,这里的子采样层没有参数,也没有
% 激活函数,所以在子采样层是没有需要求解的参数的
%%=========================================================================
for l = 2 : n
if strcmp(net.layers{l}.type, 'c')
for j = 1 : numel(net.layers{l}.a)
for i = 1 : numel(net.layers{l - 1}.a)
%%%%%%%%%%%%%%%%%%%%dk保存的是误差对卷积核的导数%%%%%%%%%%%%%%%%%%%%
net.layers{l}.dk{i}{j} = convn(flipall(net.layers{l - 1}.a{i}), net.layers{l}.d{j}, 'valid') / size(net.layers{l}.d{j}, 3);
end
%%%%%%%%%%%%%%%%%%%%db保存的是误差对于bias基的导数%%%%%%%%%%%%%%%%%%%%
net.layers{l}.db{j} = sum(net.layers{l}.d{j}(:)) / size(net.layers{l}.d{j}, 3);
end
end
end
%%%%%%%%%%%%%%%%%%%%最后一层perceptron的gradient的计算%%%%%%%%%%%%%%%%%%%%
net.dffW = net.od * (net.fv)' / size(net.od, 2);
net.dffb = mean(net.od, 2);
function X = rot180(X)
X = flipdim(flipdim(X, 1), 2);
end
end
量化投资与机器学习
知识、能力、深度、专业
勤奋、天赋、耐得住寂寞
- GET/POST/g和钩子函数(hook)
- cookie和session
- Python Flask模块
- Java直接内存与非直接内存性能测试
- Elasticsearch——multi termvectors的用法
- Elasticsearch增删改查 之 —— Delete删除
- Elasticsearch增删改查 之 —— Get查询
- 实现两个N*N矩阵的乘法,矩阵由一维数组表示
- Elasticsearch入门必备——ES中的字段类型以及常用属性
- C++容器与算法
- Effective c++ 小结
- Java程序员的日常—— Properties文件的读写
- Java程序员的日常——经验贴(纯干货)二
- Elasticsearch——使用_cat查看Elasticsearch状态
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 统计字符串中字符出现的次数-Python
- 类和对象的哲学思考
- MySQL进阶篇(03):合理的使用索引结构和查询
- 动态查看及加载PHP扩展
- 少有人知的 Python "重试机制"
- Gradle之恋(14)-实战spring mvc插件式多项目构建
- Spring Boot 开箱即用
- Qt音视频开发2-vlc回调处理
- cmake学习
- 数据分析:在缓慢变化中寻找跳变——基于缓慢变化维度的用户分群
- 02 Confluent_Kafka权威指南 第二章:安装kafka
- Kafka集群搭建过程(kafka2.5+eagle)
- 07 Confluent_Kafka权威指南 第七章: 构建数据管道
- java中的reference(二): jdk1.8中Reference的源码阅读
- 08 Confluent_Kafka权威指南 第八章:跨集群数据镜像