深度学习Matlab工具箱代码注释之cnntrain.m
时间:2022-04-24
本文章向大家介绍深度学习Matlab工具箱代码注释之cnntrain.m,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
%%=========================================================================
%函数名称:cnntrain()
%输入参数:net,神经网络;x,训练数据矩阵;y,训练数据的标签矩阵;opts,神经网络的相关训练参数
%输出参数:net,训练完成的卷积神经网络
%算法流程:1)将样本打乱,随机选择进行训练;
% 2)取出样本,通过cnnff2()函数计算当前网络权值和网络输入下网络的输出
% 3)通过BP算法计算误差对网络权值的导数
% 4)得到误差对权值的导数后,就通过权值更新方法去更新权值
%注意事项:1)使用BP算法计算梯度
%%=========================================================================
function net = cnntrain(net, x, y, opts)
m = size(x, 3); %m保存的是训练样本个数
disp(['样本总个数=' num2str(m)]);
numbatches = m / opts.batchsize; %numbatches表示每次迭代中所选取的训练样本数
if rem(numbatches, 1) ~= 0 %如果numbatches不是整数,则程序发生错误
error('numbatches not integer');
end
%%=====================================================================
%主要功能:CNN网络的迭代训练
%实现步骤:1)通过randperm()函数将原来的样本顺序打乱,再挑出一些样本来进行训练
% 2)取出样本,通过cnnff2()函数计算当前网络权值和网络输入下网络的输出
% 3)通过BP算法计算误差对网络权值的导数
% 4)得到误差对权值的导数后,就通过权值更新方法去更新权值
%注意事项:1)P = randperm(N),返回[1, N]之间所有整数的一个随机的序列,相当于把原来的样本排列打乱,
% 再挑出一些样本来训练
% 2)采用累积误差的计算方式来评估当前网络性能,即当前误差 = 以前误差 * 0.99 + 本次误差 * 0.01
% 使得网络尽可能收敛到全局最优
%%=====================================================================
net.rL = []; %代价函数值,也就是误差值
for i = 1 : opts.numepochs %对于每次迭代
disp(['epoch ' num2str(i) '/' num2str(opts.numepochs)]);
tic; %使用tic和toc来统计程序运行时间
%%%%%%%%%%%%%%%%%%%%取出打乱顺序后的batchsize个样本和对应的标签 %%%%%%%%%%%%%%%%%%%%
kk = randperm(m);
for l = 1 : numbatches
batch_x = x(:, :, kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize));
batch_y = y(:, kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize));
%%%%%%%%%%%%%%%%%%%%在当前的网络权值和网络输入下计算网络的输出(特征向量)%%%%%%%%%%%%%%%%%%%%
net = cnnff(net, batch_x); %卷积神经网络的前馈运算
%%%%%%%%%%%%%%%%%%%%通过对应的样本标签用bp算法来得到误差对网络权值的导数%%%%%%%%%%%%%%%%%%%%
net = cnnbp(net, batch_y); %卷积神经网络的BP算法
%%%%%%%%%%%%%%%%%%%%通过权值更新方法去更新权值%%%%%%%%%%%%%%%%%%%%
net = cnnapplygrads(net, opts);
if isempty(net.rL)
net.rL(1) = net.L; %代价函数值,也就是均方误差值 ,在cnnbp.m中计算初始值 net.L = 1/2* sum(net.e(:) .^ 2) / size(net.e, 2);
end
net.rL(end + 1) = 0.99 * net.rL(end) + 0.01 * net.L; %采用累积的方式计算累积误差
end
toc;
end
end
量化投资与机器学习
知识、能力、深度、专业
勤奋、天赋、耐得住寂寞
- 男程序员是不是都不会和女生表达交流?程序员的回答歪了
- Silverlight Telerik控件学习:主题Theme切换
- Silverlight自定义类库实现应用程序缓存
- Silverlight Telerik控件学习:TreeView数据绑定并初始化选中状态、PanelBar的Accordion效果、TabPanel、Frame基本使用
- 这或许是对小白最友好的python入门了吧——4,列表
- 每个人都应该知道的十个机器学习常识
- 重新带你了解React.js
- WebService又一个不爽的地方
- 劲爆!小程序又增新功能!为落地微信智慧零售方案做铺垫!
- 5G光传送网技术
- 突破封闭 Web 系统的技巧之正面冲锋
- 建立本地的Blast数据库
- [biztalk笔记]-1.Hello World!
- 人工智能对政府意味着什么
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Linux Used内存到底哪里去了?
- 浙大版《C语言程序设计(第3版)》题目集 习题6-6 使用函数输出一个整数的逆序数
- 浙大版《C语言程序设计(第3版)》题目集 练习8-2 计算两数的和与差
- SQL查找是否"存在",别再count了!
- 浙大版《C语言程序设计(第3版)》题目集 练习8-8 移动字母
- 超赞!墙裂推荐这款开源、轻量无 Agent 自动化运维平台
- 详解Docker中Image、Container与 Volume 的迁移
- 浙大版《C语言程序设计(第3版)》题目集 习题8-1 拆分实数的整数与小数部分
- 如何在 Linux 上恢复误删除的文件或目录
- 浙大版《C语言程序设计(第3版)》题目集 习题8-2 在数组中查找指定元素
- Pandas学习笔记之时间序列总结
- HTML+JS动态获取当前时间
- HTML+JS实现时钟
- SQL-spj库创建脚本
- Sublime Text3 通过Package Control安装插件时找不到可用安装包的解决方法