pytorch 如何设置 可学习参数

时间:2022-07-22
本文章向大家介绍pytorch 如何设置 可学习参数,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

交流、咨询,有疑问欢迎添加QQ 2125364717,一起交流、一起发现问题、一起进步啊,哈哈哈哈哈

各位看官老爷,如果觉得对您有用麻烦赏个子,创作不易,0.1元就行了。下面是微信乞讨码:

添加描述

添加描述

如何根据自己需求设定,可学习参数,并进行初始化。

#比如cnn输出4个东西,你又不想concate到到一起,你想用权重加法,权重又不想自己设定,想让网络自己学

#requires_grad=True这个很重要

#设置前置网络及 可学习参数
self.cnn=cnn_output4()
self.fuse_weight_1 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)
self.fuse_weight_2 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)
self.fuse_weight_3 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)
self.fuse_weight_4 = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)
     
#初始化
self.fuse_weight_1.data.fill_(0.25)
self.fuse_weight_2.data.fill_(0.25)
self.fuse_weight_3.data.fill_(0.25)
self.fuse_weight_4.data.fill_(0.25)
     
def forward(x):
    x1,x2,x3,x4=self.cnn(x)
    return fuse_weight_1*x1+fuse_weight_2*x2+fuse_weight_3*x3+fuse_weight_4*x4