梯度是如何计算的
引言
深度学习模型的训练本质上是一个优化问题,而常采用的优化算法是梯度下降法(SGD)。对于SGD算法,最重要的就是如何计算梯度。此时,估计跟多人会告诉你:采用BP(backpropagation)算法,这没有错,因为神经网络曾经的一大进展就是使用BP算法计算梯度提升训练速度。但是从BP的角度,很多人陷入了推导公式的深渊。如果你学过微积分,我相信你一定知道如何计算梯度,或者说计算导数。对于深度网络来说,其可以看成多层非线性函数的堆积,即:
而我们知道深度学习模型的优化目标L一般是output的函数,如果要你求L关于各个参数的导数,你会不假思索地想到:链式法则。因为output是一个复合函数。在微积分里面,求解复合函数的导数采用链式法则再合适不过了。其实本质上BP算法就是链式法则的一个调用。让我们先忘记BP算法,从链式法则开始说起。
链式法则
链式法则无非是将一个复杂的复合函数从上到下逐层求导,比如你要求导式子:f(x,y,z)=(x+y)(x+z)。当然这个例子是足够简单的,但是我们要使用链式法则的方式来求导。首先可以将f(x,y,z)看成两个函数p(x,y)=(x+y)与q(x,z)=(x+z)的复合:f=pq。假如你要求df/dz,首先我们先要求出df/dp与df/dq。很显然,df/dp=q,df/dq=p。然后要求dp/dx与dq/dx,显然,dp/dx=1.0,dq/dx=1.0。这个时候已经求到最底层了,可以利用链式法则求出最终的结果了:df/dx=(df/dp)(dp/dx)+(df/dq)(dq/dx)=q+p。同样的方法,可以求出:df/dy=q,df/dz=p。如果大家细致观察的话,可以看到要求出最终的导数,你需要计算出中间结果:p与q。计算中间结果的过程一般是前向(forward)过程,然后再反向(backward)计算出最终的导数。过程如下:
#程序1#
#输入
x, y, z = -3, 2, 5
#执行前向过程
p = x + y #p = -1
q = x + z #q = 2
#执行反向过程计算梯度
#第一个层反向:f = pq
dfdp = q #df / dp = 2
dfdq = p #df / dq = -1
#第二个层反向,并累积第一层梯度:
#p = x + y, q = x + z
dfdx = 1.0 * dfdp + 1.0 * dfdq #df / dx = 1
dfdy = 1.0 * dfdp #df / dy = 2
dfdz = 1.0 * dfdq #df / dz = -1
上面的一个过程就是BP算法,包含两个过程:前向forward)过程与反向(backword)过程。前向过程是从输入计算得到输出,而反向过程就是一个梯度累积的过程,或者说是BP,即误差反向传播。这就是BP的思想。上面的例子应该是比较简单的,而对于深度学习模型来说,其只不过是函数复杂一点罢了,但是如果你严格按照链式法则来去推导,只要你会基本求导方法,应该都不是什么难事了。
矩阵运算
其实对于深度学习模型来说,其运算都是基于矩阵运算的。对于新手来说,矩阵运算的求导可能会是一件比较头疼的事。其实矩阵运算求导是一个纸老虎。对于元素级的矩阵运算来说,比如激活函数这种,你完全可以把看成普通的求导。但是对于矩阵乘法,你需要特别注意,这里先抛出例子:
#程序2#
import numpy as np
#前向过程
W = np.random.randn(5,10)
X = np.random.randn(10,3)
D = W.dot(X)
#反向过程
#假定dD是后面传播过来的梯度项
dD = np.random.randn(*D.shape)
dW = dD.dot(X.T)
dX = W.T.dot(dD)
如果你认真推导的话,是可以得到上面的结果的。但是这里有其它捷径。对于两个矩阵相乘的话,在反向传播时反正是另外一个项与传播过来的梯度项相乘。差别就在于位置以及翻转。这里有个小窍门,就是最后计算出梯度肯定要与原来的矩阵是同样的shape。那么这就容易了,反正组合不多。比如你要计算dW,你知道要用dD与X两个矩阵相乘就可以得到。W的shape是[5,10],而dD的shape是[5,3],X的shape是[10,3]。要保证dW与W的shape一致,好吧,此时只能用dD.dot(X.T),真的没有其它选择了,那这就是对了。
活学活用:
实现一个简单的神经网络
上面我们讲了链式法则,也讲了BP的思想,并且也讲了如何对矩阵运算求梯度。下面我们基于Python中的Numpy库实现一个简单的神经网络模型,代码如下:
#程序3#
"""
一个简单两层神经网络回归模型
"""
import numpy as np
# batch size
N = 32
# 输入维度
D = 100
# 隐含层单元数
H = 200
# 输出维度
O = 10
# 训练样本(这里随机生成)
X = np.random.randn(N, D)
y = np.random.randn(N, O)
# 初始化参数
W1 = np.random.randn(D, H)
b1 = np.zeros((H,))
W2 = np.random.randn(H, O)
b2 = np.zeros((O,))
# 训练参数
learning_rate = 1e-02
iterations = 200
# 训练过程
for t in range(iterations): # 前向过程
h = X.dot(W1) + b1
h_relu = np.maximum(h, 0)
pred = h_relu.dot(W2) + b2
#定义loss,采用均方差
loss = np.sum(np.square(y - pred))
print("Iteration %d loss: %f" % (t, loss))
# 反向过程计算梯度
dpred = 2.0 * (pred - y)
db2 = np.sum(dpred, axis=0)
dW2 = h_relu.T.dot(db2)
dh_relu = db2.dot(W2.T)
dh = (h > 0) * dh_relu
db1 = np.sum(dh, axis=0)
dW1 = X.T.dot(dh)
# SGD更新梯度
params = [W1, b1, W2, b2]
grads = [dW1, db1, dW2, db2]
for p, g in zip(params, grads):
p += -learning_rate * g
总结
这里我们简单介绍了梯度下降法中最重要的一部分,就是如何计算梯度。相信通过本文,大家对BP算法以及链式法则有更深刻的理解。
参考资料
cs231n教程:http://cs231n.github.io/optimization-2/
- 大数据平台搭建 Hadoop-2.7.4 + Spark-2.2.0 快速搭建
- ajax 设置Access-Control-Allow-Origin实现跨域访问
- gradle新建工程,多项目依赖,聚合工程
- Apache Hive-2.3.0 快速搭建与使用
- HBase-1.3.1 集群搭建 - 报错整理
- 分布式唯一ID生成器Twitter 的 Snowflake idworker java版本
- 使用 Phoenix-4.11.0连接 Hbase 集群 ,并使用 JDBC 查询测试
- 高并发分布式系统中生成全局唯一Id汇总
- ZooKeeper 可视化监控 zkui
- 关于RBAC(Role-Base Access Control)的理解
- Spring Boot 中使用 Kafka
- 如何评价一段代码
- java系统高并发的解决方案
- Spring Boot 中使用 Redis
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法