【python实现卷积神经网络】池化层实现
代码来源:https://github.com/eriklindernoren/ML-From-Scratch
卷积神经网络中卷积层Conv2D(带stride、padding)的具体实现:https://www.cnblogs.com/xiximayou/p/12706576.html
激活函数的实现(sigmoid、softmax、tanh、relu、leakyrelu、elu、selu、softplus):https://www.cnblogs.com/xiximayou/p/12713081.html
损失函数定义(均方误差、交叉熵损失):https://www.cnblogs.com/xiximayou/p/12713198.html
优化器的实现(SGD、Nesterov、Adagrad、Adadelta、RMSprop、Adam):https://www.cnblogs.com/xiximayou/p/12713594.html
卷积层反向传播过程:https://www.cnblogs.com/xiximayou/p/12713930.html
全连接层实现:https://www.cnblogs.com/xiximayou/p/12720017.html
批量归一化层实现:https://www.cnblogs.com/xiximayou/p/12720211.html
包括D的平均池化和最大池化:
class PoolingLayer(Layer): """A parent class of MaxPooling2D and AveragePooling2D """ def __init__(self, pool_shape=(2, 2), stride=1, padding=0): self.pool_shape = pool_shape self.stride = stride self.padding = padding self.trainable = True def forward_pass(self, X, training=True): self.layer_input = X batch_size, channels, height, width = X.shape _, out_height, out_width = self.output_shape() X = X.reshape(batch_size*channels, 1, height, width) X_col = image_to_column(X, self.pool_shape, self.stride, self.padding) # MaxPool or AveragePool specific method output = self._pool_forward(X_col) output = output.reshape(out_height, out_width, batch_size, channels) output = output.transpose(2, 3, 0, 1) return output def backward_pass(self, accum_grad): batch_size, _, _, _ = accum_grad.shape channels, height, width = self.input_shape accum_grad = accum_grad.transpose(2, 3, 0, 1).ravel() # MaxPool or AveragePool specific method accum_grad_col = self._pool_backward(accum_grad) accum_grad = column_to_image(accum_grad_col, (batch_size * channels, 1, height, width), self.pool_shape, self.stride, 0) accum_grad = accum_grad.reshape((batch_size,) + self.input_shape) return accum_grad def output_shape(self): channels, height, width = self.input_shape out_height = (height - self.pool_shape[0]) / self.stride + 1 out_width = (width - self.pool_shape[1]) / self.stride + 1 assert out_height % 1 == 0 assert out_width % 1 == 0 return channels, int(out_height), int(out_width) class MaxPooling2D(PoolingLayer): def _pool_forward(self, X_col): arg_max = np.argmax(X_col, axis=0).flatten() output = X_col[arg_max, range(arg_max.size)] self.cache = arg_max return output def _pool_backward(self, accum_grad): accum_grad_col = np.zeros((np.prod(self.pool_shape), accum_grad.size)) arg_max = self.cache accum_grad_col[arg_max, range(accum_grad.size)] = accum_grad return accum_grad_col class AveragePooling2D(PoolingLayer): def _pool_forward(self, X_col): output = np.mean(X_col, axis=0) return output def _pool_backward(self, accum_grad): accum_grad_col = np.zeros((np.prod(self.pool_shape), accum_grad.size)) accum_grad_col[:, range(accum_grad.size)] = 1. / accum_grad_col.shape[0] * accum_grad return accum_grad_col
需要注意的是池化层是没有可学习的参数的(如果不利用带步长的卷积来代替池化的作用),还有就是池化层反向传播的过程,这里参考:https://blog.csdn.net/Jason_yyz/article/details/80003271
为了结合代码看直观些,就将其内容摘了下来:
原文地址:https://www.cnblogs.com/xiximayou/p/12720324.html
- yaffs_bitmap
- 原创 | 实战:R环境下Echart的8种可视化
- Yarn(MapReduce 2.0)下分布式缓存(DistributedCache)的注意事项
- Yaffs_guts(三)
- 浅谈 python multiprocessing(多进程)下如何共享变量
- 文件地址映射之yaffs_GetTnode
- bash/shell 解析命令行参数工具:getopts/getopt
- ssh 双机互信:免密码登录设置步骤及常见问题
- yaffs_guts(一)
- 聊聊 Java 中 HashMap 初始化的另一种方式
- 基于 Hive 的文件格式:RCFile 简介及其应用
- MapReduce 计数器简介
- 流水线乘法器
- Hive 基础(2):库、表、字段、交互式查询的基本操作
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 谁告诉的你们Python是强类型语言!站出来,保证不打你!
- 项目内容细分
- 『算法理论学』深度学习推理加速方法之网络层与算子融合
- 【剑指Offer】链表中倒数第k个字节
- pyplot只有两个数值做barplot
- 两个矩阵对应位置相除
- 基于openresty的URL 断路器/熔断器 -- URL-fuse
- 温故知新——Spring AOP(二)
- 2020年,你应该知道 23 个非常有用的 NodeJs 库
- 面试题系列第4篇:重写了equals方法,为什么还要重写hashCode方法?
- LPC17XX之SSP0/1接口
- 宝塔打开ssl面板后打不开登录界面的解决方法
- 记录一下vuedraggable clone的坑,获取数据
- Go语言|go version命令的高级用法
- 基于 git flow + gitlab 协作开发:01