动态分组卷积-Dynamic Group Convolution for Accelerating Convolutional Neural Networks
时间:2022-07-24
本文章向大家介绍动态分组卷积-Dynamic Group Convolution for Accelerating Convolutional Neural Networks,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
论文地址:https://arxiv.org/pdf/2007.04242.pdf
GitHub 代码:https://github.com/zhuogege1943/dgc/
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
class DynamicMultiHeadConv(nn.Module):
    """Multi-head convolution whose input channels are gated per sample.

    The input is batch-normalised and passed through ReLU. Each of the
    ``heads`` ``HeadConv`` sub-modules (registered as ``headconv_0`` ...
    ``headconv_{heads-1}``) then produces its own masked copy of that
    activation. The masked copies are concatenated along the channel axis
    and convolved in one grouped convolution (one group per head) using the
    concatenated head weights; finally the output channels are interleaved
    across heads.

    While the module is in training mode, the number of input channels each
    head zeroes out is updated on every forward pass from the class
    attribute ``global_progress`` and stored in a buffer, so the value is
    carried in the state dict and reused in eval mode.

    Args:
        in_channels: Number of input channels; must be divisible by ``heads``.
        out_channels: Total number of output channels; must be divisible by
            ``heads`` (each head produces ``out_channels // heads``).
        kernel_size: Kernel size forwarded to each head's convolution.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        heads: Number of heads.
        squeeze_rate: Reduction rate forwarded to each ``HeadConv``.
        gate_factor: Must be ``<= 1.0``. The final number of deactivated
            input channels is ``round(in_channels * (1 - gate_factor))``.
    """

    # Shared by all instances and only read here; presumably the training
    # loop sets it to the fraction of training completed -- TODO confirm.
    global_progress = 0.0

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, heads=4, squeeze_rate=16, gate_factor=0.25):
        super(DynamicMultiHeadConv, self).__init__()
        self.norm = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.squeeze_rate = squeeze_rate
        self.gate_factor = gate_factor
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        # Not read anywhere in this class.
        self.is_pruned = True
        # Stored as a buffer so the current count is saved with the model.
        self.register_buffer('_inactive_channels', torch.zeros(1))

        ### Check if arguments are valid
        assert self.in_channels % self.heads == 0, (
            "head number can not be divided by input channels")
        assert self.out_channels % self.heads == 0, (
            "head number can not be divided by output channels")
        assert self.gate_factor <= 1.0, "gate factor is greater than 1"

        # Heads are registered under explicit attribute names so the
        # state-dict keys stay ``headconv_<i>.*``.
        for i in range(self.heads):
            setattr(self, 'headconv_%1d' % i,
                    HeadConv(in_channels, out_channels // self.heads, squeeze_rate,
                             kernel_size, stride, padding, dilation, 1, gate_factor))

    def forward(self, x):
        """Apply the gated multi-head convolution.

        The code here is just a coarse implementation.
        The forward process can be quite slow and memory consuming, need to
        be optimized.

        Args:
            x: 4-D input tensor (it is fed to ``BatchNorm2d``).

        Returns:
            A two-element list ``[out, lasso_loss]``: the convolved and
            channel-interleaved output, and the sum over heads of each
            head's mean mask value.
        """
        if self.training:
            progress = DynamicMultiHeadConv.global_progress
            max_inactive = self.in_channels * (1 - self.gate_factor)
            # gradually deactivate input channels: nothing changes up to
            # progress 1/12, then a linear ramp that reaches the full count
            # at progress 3/4.
            if 1.0 / 12 < progress < 3.0 / 4:
                self.inactive_channels = round(
                    max_inactive * 3.0 / 2 * (progress - 1.0 / 12))
            elif progress >= 3.0 / 4:
                self.inactive_channels = round(max_inactive)

        _lasso_loss = 0.0
        x = self.norm(x)
        x = self.relu(x)
        x_averaged = self.avg_pool(x)

        # Read the buffer once: the getter converts a tensor element to int.
        inactive = self.inactive_channels

        x_mask = []
        weight = []
        for i in range(self.heads):
            head = getattr(self, 'headconv_%1d' % i)
            i_x, i_lasso_loss = head(x, x_averaged, inactive)
            x_mask.append(i_x)
            # Only the head's conv weight is used; the convolution itself is
            # done once below for all heads.
            weight.append(head.conv.weight)
            _lasso_loss = _lasso_loss + i_lasso_loss

        x_mask = torch.cat(x_mask, dim=1)  # batch_size, heads x C_in, H, W
        weight = torch.cat(weight, dim=0)  # out_channels, C_in, k, k

        # One group per head: head i's weights see only head i's masked input.
        out = F.conv2d(x_mask, weight, None, self.stride,
                       self.padding, self.dilation, self.heads)

        # Interleave output channels across heads.
        b, c, h, w = out.size()
        out = out.view(b, self.heads, c // self.heads, h, w)
        out = out.transpose(1, 2).contiguous().view(b, c, h, w)
        return [out, _lasso_loss]

    @property
    def inactive_channels(self):
        """Current number of deactivated input channels, as an ``int``."""
        return int(self._inactive_channels[0])

    @inactive_channels.setter
    def inactive_channels(self, val):
        # Written in place so the registered buffer object is preserved.
        self._inactive_channels.fill_(val)
class HeadConv(nn.Module):
    """One head: holds a conv weight and a per-sample channel gate.

    ``forward`` does not apply ``self.conv``; it only computes a channel
    mask from the pooled input with a two-layer bottleneck
    (``fc1 -> ReLU -> fc2 -> ReLU``) and returns the masked input. The
    convolution layer exists so its ``weight`` can be read by the caller.

    Args:
        in_channels: Number of input channels (also the mask length).
        out_channels: Number of output channels of ``self.conv``.
        squeeze_rate: Bottleneck reduction rate. It is halved (integer
            division, but never below 1) when ``in_channels < 80``.
        kernel_size: Kernel size of ``self.conv``.
        stride: Stride of ``self.conv``.
        padding: Padding of ``self.conv``.
        dilation: Dilation of ``self.conv``.
        groups: Accepted but ignored; ``self.conv`` always uses ``groups=1``.
        gate_factor: Stored as ``target_pruning_rate``; not otherwise used
            in this class.
    """

    def __init__(self, in_channels, out_channels, squeeze_rate, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, gate_factor=0.25):
        super(HeadConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              padding, dilation, groups=1, bias=False)
        self.target_pruning_rate = gate_factor
        if in_channels < 80:
            # Clamp to 1: a rate of 1 would otherwise halve to 0 and make
            # the bottleneck width below a division by zero.
            squeeze_rate = max(1, squeeze_rate // 2)
        hidden = in_channels // squeeze_rate
        self.fc1 = nn.Linear(in_channels, hidden, bias=False)
        self.relu_fc1 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden, in_channels, bias=True)
        self.relu_fc2 = nn.ReLU(inplace=True)

        nn.init.kaiming_normal_(self.fc1.weight)
        nn.init.kaiming_normal_(self.fc2.weight)
        # Bias of 1 so the mask starts around 1 before the fc weights matter.
        nn.init.constant_(self.fc2.bias, 1.0)

    def forward(self, x, x_averaged, inactive_channels):
        """Scale ``x`` channel-wise by the predicted mask.

        Args:
            x: 4-D tensor ``(b, c, H, W)``.
            x_averaged: Tensor with ``b * c`` elements (reshaped to
                ``(b, c)``) that is fed to the gate.
            inactive_channels: Number of lowest-scoring channels to zero
                per sample; ``0`` (or less) disables zeroing.

        Returns:
            Tuple ``(masked_x, lasso_loss)`` where ``lasso_loss`` is the
            mean of the mask before any channel is zeroed.
        """
        b, c, _, _ = x.size()
        x_averaged = x_averaged.view(b, c)
        y = self.fc1(x_averaged)
        y = self.relu_fc1(y)
        y = self.fc2(y)
        mask = self.relu_fc2(y)  # b, c
        _lasso_loss = mask.mean()

        mask_c = mask
        if inactive_channels > 0:
            # Channel selection is done on a detached copy, so it does not
            # take part in autograd.
            mask_d = mask.detach()
            mask_c = mask.clone()
            lowest, _ = mask_d.topk(inactive_channels, dim=1, largest=False, sorted=False)
            # Per-sample threshold: the largest of the k smallest values.
            clamp_max, _ = lowest.max(dim=1, keepdim=True)
            # `<=` also zeroes channels tied with the threshold, so more
            # than `inactive_channels` channels may be dropped.
            mask_c[mask_d.le(clamp_max)] = 0

        mask_c = mask_c.view(b, c, 1, 1)
        x = x * mask_c.expand_as(x)
        return x, _lasso_loss
class Conv(nn.Sequential):
    """BatchNorm -> ReLU -> Conv2d stack (pre-activation ordering).

    The three layers are registered under the names ``norm``, ``relu`` and
    ``conv`` and run in that order.

    Args:
        in_channels: Channels of the input (also the BatchNorm width).
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        groups: Number of convolution groups.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, groups=1):
        super(Conv, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(inplace=True))
        # The convolution itself has no bias term.
        self.add_module('conv', nn.Conv2d(in_channels, out_channels,
                                          kernel_size=kernel_size,
                                          stride=stride,
                                          padding=padding, bias=False,
                                          groups=groups))
- 数据结构之二叉树
- 微信快速开发框架(六)-- 微信快速开发框架(WXPP QuickFramework)V2.0版本上线--源码已更新至github
- 数据结构之数组
- Android资源动态加载以及相关原理分析
- 微信快速开发框架(七)--发送客服信息,版本更新至V2.2 代码已更新至github
- 微信快速开发框架(八)-- V2.3--增加语音识别及网页获取用户信息,代码已更新至Github
- 微信公众平台快速开发框架 For Core 2.0 beta –JCSoft.WX.Core 5.2.0 beta发布
- Android系统层Watchdog机制源码分析
- 算法之插入排序
- Android Studio环境下搭建ReactNative
- Android实现两个ScrollView互相联动,同步滚动的效果
- 一个可以拖动的自定义Gridview代码
- android图片加载库Glide
- 密码最短长度为7,其中必须包含以下非字母数字字符1 完美解决方案
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Linux 解决Deepin无法在root用户启动Google Chrome浏览器的问题
- 在Linux上安装和使用Docker的方法
- centOS7 NET模式设置静态Ip的方法步骤
- CentOS搭建PHP服务器环境简明教程
- CentOS7.2安装MySql5.7并开启远程连接授权的教程
- linux查看防火墙状态与开启关闭命令详解
- linux防火墙iptables规则的查看、添加、删除和修改方法总结
- Linux expect实现自动登录脚本实例代码
- scRNA-seq marker identification(一)
- 关于linux权限s权限和t权限详解
- centOS7 桥接模式设置静态Ip的方法步骤
- linux环境下卸载oracle 11g的过程
- Seurat包基本分析实战—文献图表复现
- ubuntu配置tftp服务的步骤小结
- CentOS7下GitLab跨大版本升级的方法