神经网络之反向传播
上节课给大家简单介绍了神经网络,并且重点介绍了神经网络的前向传播工作原理。可能有些同学觉得难,因为上节课涉及到一些矩阵运算,以前没有学过线性代数的同学可能就看不懂了。这里想告诉大家的是,深度学习确实是需要数学基础的,接下来还会有不少求导(偏导)、向量以及矩阵运算等等,要求掌握高数、线性代数等学科知识,所以深度学习不是人人都适合学的。
如果确实没具备这些数学知识但又很想入门深度学习的同学,老shi给大家的建议是,可以从现在开始有针对性地去补一下相关的知识,从现在开始什么都不晚!因为老shi不可能什么都教你,这些知识只能靠你自己课后去自学!记住,学习这件事,三分靠老shi,七分靠自己。
好了,说了这么多,言归正传,本节课会在上节课的基础上继续给大家介绍神经网络的反向传播工作原理。反向传播??没错,反向传播!
反向传播的计算
反向传播是在前向传播的基础上反向传递误差的过程,假设我们使用随机梯度下降的方式来学习神经网络的参数,损失函数定义为
,其中y是样本的真实标签。使用梯度下降进行参数的学习,我们必须计算出损失函数关于神经网络中各层参数(权重w和偏置b)的偏导数。
假设我们要对第k层隐藏层的参数
和
求偏导,即
和
,假设
代表第k层神经元的输入,即
,其中
为前一层神经元的输出,根据链式法则有:
因此我们只需要计算偏导数
、
和
。
1.1计算偏导数
由
可得,
1.2计算偏导数
因为偏置b为一个常数,所以其偏导数计算如下:
1.3计算偏导数
偏导数
又称为误差项,一般用符号
表示,例如
是第一层神经元的误差,它的大小代表了第一层神经元对最终总误差的影响大小。
根据上节课前向传播的计算公式,我们可以得到第k+1层的输入与第k层的输出之间的关系为:
又因为
,根据链式法则,我们可以得到
为:
由上式我们可以看到,第k层神经元的误差项
是由第k+1层的误差项乘以第k+1层的权重,再乘以第k层激活函数的导数(梯度)得到的。这就是误差的反向传播。
现在我们已经计算出了偏导数
、
和
,则
和
又可以分别表示为:
前面我们已经计算出了第k层的误差项
,现在我们要利用每一层的误差项和梯度来更新每一层的参数,权重w和偏置b的更新公式如下:
注:
表示学习率,它的大小表示参数w和b更新速度的快慢。
下图表达了反向传播误差的传递过程,图中的数字对应上节课前向传播网络图中的权重w和偏置b,有兴趣的同学可以自己代入公式计算。
Ok,枯燥的公式推导终于结束,哈哈~但并不表示接下来就轻松了?。还是那句话,纸上得来终觉浅,绝知此事要躬行!最后附上神经网络反向传播部分代码,跟着好好敲一遍代码,你的收获一定会比别人多!!
——————分割线————————
# 反向传播
def backward_propagation(parameters, cache, X, Y):
m = X.shape[1]
w1 = parameters['w1']
w2 = parameters['w2']
A1 = cache['A1']
A2 = cache['A2']
dZ2 = A2 - Y
dW2 = 1 / m * np.dot(dZ2, A1.T)
db2 = 1 / m * np.sum(dZ2, axis=1, keepdims=True)
dZ1 = np.dot(w2.T, dZ2) * (1 - np.power(A1, 2))
dW1 = 1 / m * np.dot(dZ1, X.T)
db1 = 1 / m * np.sum(dZ1, axis=1, keepdims=True)
grads = {'dW1': dW1,
'db1': db1,
'dW2': dW2,
'db2': db2}
return grads
# 更新参数
def update_parameters(parameters, grads, learning_rate=1.2):
w1 = parameters['w1']
b1 = parameters['b1']
w2 = parameters['w2']
b2 = parameters['b2']
dW1 = grads['dW1']
db1 = grads['db1']
dW2 = grads['dW2']
db2 = grads['db2']
parameters = {'w1': w1,
'b1': b1,
'w2': w2,
'b2': b2}
return parameters
# 模型训练
def my_model(X, Y, n_h, num_iterations=10000, print_cost=False):
np.random.seed(3)
n_x = layer_sizes(X, Y)[0]
n_y = layer_sizes(X, Y)[2]
parameters = initialize_parameters(n_x, n_h, n_y)
w1 = parameters['w1']
b1 = parameters['b1']
w2 = parameters['w2']
b2 = parameters['b2']
for i in range(0, num_iterations):
A2, cache = forward_propagation(X, parameters)
cost = compute_cost(A2, Y, parameters)
grads = backward_propagation(parameters, cache, X, Y)
parameters = update_parameters(parameters, grads, learning_rate=1.2)
if print_cost and i % 1000 == 0:
print('Cost after iteration %i: %f' % (i, cost))
return parameters
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- springBoot 入门(五)—— 使用 纯注解方式 的springboot+ mybatis+junit4 整合
- 常见加载类错误分析
- 常用的ClassLoader分析
- 如何实现自己的ClassLoader
- Hacking with iOS: SwiftUI Edition - 潜力客户名单项目(三)
- 大型项目技术栈第四讲 SQL语句构建器
- redis 入门(一)——Linux环境安装测试以及基本命令演示
- 大型项目技术栈第五讲 富文本编辑器
- weblogic 11g StuckThreadMaxTime 问题解决 以及 线程池、数据库连接池参数调优
- 大型项目技术栈第九讲 kaptcha的使用
- 大型项目技术栈第十讲 日志与性能监控
- Mybatis系列第三讲 Mybatis使用详解(1)
- Maven系列第二讲 安装、配置、mvn运行过程详解
- Maven第六讲 生命周期详解 高手必备!
- 鸿蒙 Ability 讲解(页面生命周期、后台服务、数据访问)