【python实现卷积神经网络】批量归一化层实现
时间:2022-07-23
本文章向大家介绍【python实现卷积神经网络】批量归一化层实现,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
代码来源:https://github.com/eriklindernoren/ML-From-Scratch
卷积神经网络中卷积层Conv2D(带stride、padding)的具体实现:https://www.cnblogs.com/xiximayou/p/12706576.html
激活函数的实现(sigmoid、softmax、tanh、relu、leakyrelu、elu、selu、softplus):https://www.cnblogs.com/xiximayou/p/12713081.html
损失函数定义(均方误差、交叉熵损失):https://www.cnblogs.com/xiximayou/p/12713198.html
优化器的实现(SGD、Nesterov、Adagrad、Adadelta、RMSprop、Adam):https://www.cnblogs.com/xiximayou/p/12713594.html
卷积层反向传播过程:https://www.cnblogs.com/xiximayou/p/12713930.html
全连接层实现:https://www.cnblogs.com/xiximayou/p/12720017.html
class BatchNormalization(Layer):
"""Batch normalization.
"""
def __init__(self, momentum=0.99):
self.momentum = momentum
self.trainable = True
self.eps = 0.01
self.running_mean = None
self.running_var = None
def initialize(self, optimizer):
# Initialize the parameters
self.gamma = np.ones(self.input_shape)
self.beta = np.zeros(self.input_shape)
# parameter optimizers
self.gamma_opt = copy.copy(optimizer)
self.beta_opt = copy.copy(optimizer)
def parameters(self):
return np.prod(self.gamma.shape) + np.prod(self.beta.shape)
def forward_pass(self, X, training=True):
# Initialize running mean and variance if first run
if self.running_mean is None:
self.running_mean = np.mean(X, axis=0)
self.running_var = np.var(X, axis=0)
if training and self.trainable:
mean = np.mean(X, axis=0)
var = np.var(X, axis=0)
self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mean
self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
else:
mean = self.running_mean
var = self.running_var
# Statistics saved for backward pass
self.X_centered = X - mean
self.stddev_inv = 1 / np.sqrt(var + self.eps)
X_norm = self.X_centered * self.stddev_inv
output = self.gamma * X_norm + self.beta
return output
def backward_pass(self, accum_grad):
# Save parameters used during the forward pass
gamma = self.gamma
# If the layer is trainable the parameters are updated
if self.trainable:
X_norm = self.X_centered * self.stddev_inv
grad_gamma = np.sum(accum_grad * X_norm, axis=0)
grad_beta = np.sum(accum_grad, axis=0)
self.gamma = self.gamma_opt.update(self.gamma, grad_gamma)
self.beta = self.beta_opt.update(self.beta, grad_beta)
batch_size = accum_grad.shape[0]
# The gradient of the loss with respect to the layer inputs (use weights and statistics from forward pass)
accum_grad = (1 / batch_size) * gamma * self.stddev_inv * (
batch_size * accum_grad
- np.sum(accum_grad, axis=0)
- self.X_centered * self.stddev_inv**2 * np.sum(accum_grad * self.X_centered, axis=0)
)
return accum_grad
def output_shape(self):
return self.input_shape
批量归一化的过程:
前向传播的时候按照公式进行就可以了。需要关注的是BN层反向传播的过程。
accm_grad是上一层传到本层的梯度。反向传播过程:
- Spring基础篇——通过Java注解和XML配置装配bean
- Java多线程高并发学习笔记(二)——深入理解ReentrantLock与Condition
- 算法模板——线段树1(区间加法+区间求和)
- 【LeetCode 205】关关的刷题日记38 Isomorphic Strings
- JavaScript基础2---控制权DOM操作
- 算法模板——线段树3(区间覆盖值+区间求和)
- 算法模板——线段树4(区间加+区间乘+区间覆盖值+区间求和)
- 【LeetCode 204】关关的刷题日记39 Count Primes
- 算法模板——并查集 1
- Java 持久化操作之 --io流与序列化
- 算法模板——LCA(最近公共祖先)
- 算法模板——AC自动机
- UOJ #117. 欧拉回路
- 算法模板——左偏树(可并堆)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 基于TypeScript封装Axios笔记(八)
- springmvc之HttpMessageConverter<T>
- django-模板之静态文件加载(十四)
- springmvc之使用JstlView
- django-模板之include标签(十五)
- 【pytorch】改造mobilenet_v2进行multi-class classification(多标签分类)
- 走进STL - heap,小树芽
- 走进STL - 序列式容器(常用篇)
- springmvc之RequestMapping中的请求方式
- 拥抱STL - union,天作之秀
- 拥抱STL -typename该怎么理解
- 走近STL - map,只愿一键对一值
- springmvc之使用servlet原生API作为参数
- 走近STL - 填上list删除的大坑
- springmvc之RequestMapping中的请求参数和请求头