WeightNet之Rethinking

时间:2022-07-22
本文章向大家介绍WeightNet之Rethinking,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

笔者昨天针对WeightNet一文进行了Rethinking,但是理解不够深入,可能对不少同学造成了误导,深表歉意。笔者今天与WeightNet一作就其中的一些疑惑点进行了讨论,在此将关于WeightNet的思考记录如下。

WeightNet

我们还是先看一下WeightNet的结构,见下图。它包含两个分支:(1)一个分支用于卷积权值预测;(2)一个分支类似Identity。将两个分支按照Convolution方式进行融合。

SENet

在SENet一文中,SE模块它一般置于卷积之后,而WeightNet也是进行了类似的桥接。笔者简单绘制了下图这样的桥接,左图为Conv+SE,右图为桥接变换后等价的WeightNet。看上去这两个是等价的,WeightNet是将Attention并入到前接的卷积权值中,从这个角度来看,笔者也认为两者是等价的。

在实现过程中,有两种形式可以完成上述转换(文章后面有code层面的解析说明,WeightNet采用的是隐式融合,笔者是按照显式方式进行思考,角度的不同导致了理解上的偏差):

CondConv

接下来,我们再来看一下CondConv与WeightNet的之间的关联性吧。从笔者的角度来看,它采用了“三个臭皮匠顶个诸葛亮”的思想,通过注意力机制自适应融合权值。

B, Ci, Co, H, W, k = 4, 32, 64, 16, 16, 3
inputs = torch.randn(B, Ci, H, W)
inputs = inputs.view(1, -1, H, W)

attention = torch.randn(B, Co, 1, 1, 1)
weight = torch.randn(Co, Ci, k, k)
weight = weight.unsqueeze(0)
weight = attention * weight
weight = weight.view(B * Co, Ci, k, k)

output = F.conv2d(inputs, weight, padding=1, groups=4)
output = output.view(B, Co, H, W)
print(output.size())

关于WeightNet的解读与分析到底结束。对WeightNet感兴趣的同学可以在下方留言一起讨论。