调参过程中的参数 学习率,权重衰减,冲量(learning_rate , weight_decay , momentum)

时间:2022-05-06
本文章向大家介绍调参过程中的参数 学习率,权重衰减,冲量(learning_rate , weight_decay , momentum),主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

无论是深度学习还是机器学习,大多情况下训练中都会遇到这几个参数,今天依据我自己的理解具体的总结一下,可能会存在错误,还请指正.

learning_rate , weight_decay , momentum这三个参数的含义. 并附上demo.

我们会使用一个例子来说明一下:

            比如我们有一堆数据

,我们只知道这对数据是从一个

黑盒中得到的,我们现在要寻找到那个具体的函数f(x),我们定义为目标函数T.

          我们现在假定有存在这个函数并且这个函数为:

         我们现在要使用这对数据来训练目标函数. 我们可以设想如果存在一个这个函数,必定满足{x,y}所有的关系,也就是说:     

         那么最理想的情况下 :  

 ,那么我们不妨定义这样一个优化目标函数:

        对于这堆数据,我们认为当Loss(W)对于所有的pair{x,y}都满足 Loss(W)趋近于或者等于0时,我们认为我们找到这个理想的目标函数T. 也就是此时 

.

      以上,我们发现寻找的目标函数的问题,已经成功的转移为求解: 

      也就是Loss 越小, f(x)越接近我们寻找的目标函数T.

那么说了这么多,这个和我们说的学习率learning_rate有什么关系呢?

                既然我们知道了我们当前的f(x)和目标函数的T的误差,那么我们可以将这个误差转移到每一个参数上,也就是变成每一个参数w和目标函数T的参数w_t的误差. 然后我们就以一定的幅度stride来缩小和真实值的距离,我们称这个stride为学习率learning_rate 而且我们就是这么做的.

                我们用公式表述就是:

                        我们的误差(损失)Loss:    

                我们这一个凸函数. 我们先对这个函数进行各个分量求偏导.

对于w0的偏导数:

那么对于分量w0承担的误差为:

  并且这个误差带方向.

那么我们需要使我们当前的w0更加接近目标函数的T的w0_t参数.我们需要做运算:

(梯度下降算法)

来更新wo的值. 同理其他参数w,而这个学习率就是来控制我们每次靠近真实值的幅度,为什么要这么做呢?

因为我们表述的误差只是一种空间表述形式我们可以使用均方差也可以使用绝对值,还可以使用对数,以及交叉熵等等,所以只能大致的反映,并不精确,就想我们问路一样,别人告诉我们直走五分钟,有的人走的快,有的人走的慢,所以如果走的快的话,当再次问路的时候,就会发现走多了,而折回来,这就是我们训练过程中的loss曲线震荡严重的原因之一. 所以学习率要设置在合理的大小.


好了说了这么多,这是学习率. 那么什么是权重衰减weight_decay呢? 有什么作用呢?

          我们接着看上面的这个Loss(w),我们发现如果参数过多的话,对于高位的w3,我们对其求偏导:

我们发现w3开始大于1的时候,w3会调节的很快,幅度很大,从而使得特征x3变为异常敏感.从而出现过拟合(overfitting).

       这个时候,我们需要约束一下w2,w3等高阶参数的大小,于是我们对Loss增加一个惩罚项,使得Loss的正反方向,不应该只由f(x) -y 决定,而还应该加上一个

;于是Loss变成了:

我们继续对Loss求解偏导数:

对wo求偏导:

对w3求偏导:

我们发现当x3值过大时,会改变Loss的导数的方向.而来抑制w2,w3等高阶函数的继续增长. 然而这样抑制并不是很灵活,所以我们在前面加入一个系数

,这个系数在数学上称之为拉格朗日乘子系数,也就是我们用到的weight_decay. 这样我们可以通过调节weight_decay系数,来调节w3,w2等高阶的增长程度。加入weight_decay后的公式:

从公式可以看出 ,weight_decay越大,抑制越大,w2,w3等系数越小,weight_decay越小,抑制越小,w2,w3等系数越大


那么冲量momentum又是啥?

     我们在使用梯度下降法,来调整w时公式是这样的:

我们每一次都是计算当前的梯度:

这样会发现对于那些梯度比较小的地方,参数w更新的幅度比较小,训练变得漫长,或者收敛慢.有时候遇到非最优的凸点,会出现冲不出去的现象.

而冲量加进来是一种快速效应.借助上一次的势能来和当前的梯度来调节当前的参数w.

公式表达为:

这样可以有效的避免掉入凸点无法冲出来,而且收敛速度也快很多.

附上demo: 使用mxnet编码.

  1 //
  2 // Created by xijun1 on 2017/12/14.
  3 //
  4 
  5 #include <iostream>
  6 #include <vector>
  7 #include <string>
  8 #include <mxnet/mxnet-cpp/MxNetCpp.h>
  9 #include <mxnet/mxnet-cpp/op.h>
 10 
 11 namespace  mlp{
 12     class MlpNet{
 13     public:
 14         static mx_float OutputAccuracy(mx_float* pred, mx_float* target) {
 15             int right = 0;
 16             for (int i = 0; i < 128; ++i) {
 17                 float mx_p = pred[i * 10 + 0];
 18                 float p_y = 0;
 19                 for (int j = 0; j < 10; ++j) {
 20                     if (pred[i * 10 + j] > mx_p) {
 21                         mx_p = pred[i * 10 + j];
 22                         p_y = j;
 23                     }
 24                 }
 25                 if (p_y == target[i]) right++;
 26             }
 27             return right / 128.0;
 28         }
 29 
 30        static void net(){
 31             using mxnet::cpp::Symbol;
 32             using mxnet::cpp::NDArray;
 33 
 34             Symbol x = Symbol::Variable("X");
 35             Symbol y = Symbol::Variable("label");
 36 
 37             std::vector<std::int32_t> shapes({512 , 10});
 38             //定义一个两层的网络. wx + b
 39             Symbol weight_0 = Symbol::Variable("weight_0");
 40             Symbol biases_0 = Symbol::Variable("biases_0");
 41 
 42             Symbol fc_0 = mxnet::cpp::FullyConnected("fc_0",x,weight_0,biases_0
 43                     ,512);
 44 
 45             Symbol output_0 = mxnet::cpp::LeakyReLU("relu_0",fc_0,mxnet::cpp::LeakyReLUActType::kLeaky);
 46 
 47             Symbol weight_1 = Symbol::Variable("weight_1");
 48             Symbol biases_1 = Symbol::Variable("biases_1");
 49             Symbol fc_1 = mxnet::cpp::FullyConnected("fc_1",output_0,weight_1,biases_1,10);
 50             Symbol output_1 = mxnet::cpp::LeakyReLU("relu_1",fc_1,mxnet::cpp::LeakyReLUActType::kLeaky);
 51             Symbol pred = mxnet::cpp::SoftmaxOutput("softmax",output_1,y);  //目标函数,loss函数
 52             mxnet::cpp::Context ctx = mxnet::cpp::Context::cpu( 0);
 53 
 54             //定义输入数据
 55             std::shared_ptr< mx_float > aptr_x(new mx_float[128*28] , [](mx_float* aptr_x){ delete [] aptr_x ;});
 56             std::shared_ptr< mx_float > aptr_y(new mx_float[128] , [](mx_float * aptr_y){ delete [] aptr_y ;});
 57 
 58             //初始化数据
 59             for(int i=0 ; i<128 ; i++){
 60                 for(int j=0;j<28 ; j++){
 61                     //定义x
 62                     aptr_x.get()[i*28+j]= i % 10 +0.1f;
 63                 }
 64 
 65                 //定义y
 66                 aptr_y.get()[i]= i % 10;
 67             }
 68            std::map<std::string, mxnet::cpp::NDArray> args_map;
 69            //导入数据
 70            NDArray arr_x(mxnet::cpp::Shape(128,28),ctx, false);
 71            NDArray arr_y(mxnet::cpp::Shape( 128 ),ctx,false);
 72            //将数据转换到NDArray中
 73            arr_x.SyncCopyFromCPU(aptr_x.get(),128*28);
 74            arr_x.WaitToRead();
 75 
 76            arr_y.SyncCopyFromCPU(aptr_y.get(),128);
 77            arr_y.WaitToRead();
 78 
 79            args_map["X"]=arr_x.Slice(0,128).Copy(ctx) ;    
 80            args_map["label"]=arr_y.Slice(0,128).Copy(ctx);
 81            NDArray::WaitAll();
 82             //绑定网络
 83            mxnet::cpp::Executor *executor = pred.SimpleBind(ctx,args_map);
 84             //选择优化器
 85            mxnet::cpp::Optimizer *opt = mxnet::cpp::OptimizerRegistry::Find("sgd");
 86            mx_float learning_rate = 0.0001; //学习率
 87            mx_float weight_decay = 1e-4; //权重
 88            opt->SetParam("momentum", 0.9)
 89                    ->SetParam("lr", learning_rate)
 90                    ->SetParam("wd", weight_decay);
 91            //定义各个层参数的数组
 92            NDArray arr_w_0(mxnet::cpp::Shape(512,28),ctx, false);
 93            NDArray arr_b_0(mxnet::cpp::Shape( 512 ),ctx,false);
 94            NDArray arr_w_1(mxnet::cpp::Shape(10 , 512 ) , ctx , false);
 95            NDArray arr_b_1(mxnet::cpp::Shape( 10 ) , ctx , false);
 96 
 97            //初始化权重参数
 98            arr_w_0 = 0.01f;
 99            arr_b_1 = 0.01f;
100            arr_w_1 = 0.01f;
101            arr_b_1 = 0.01f;
102 
103             //初始化参数
104             executor->arg_dict()["weight_0"]=arr_w_0;
105             executor->arg_dict()["biases_0"]=arr_b_0;
106             executor->arg_dict()["weight_1"]=arr_w_1;
107             executor->arg_dict()["biases_1"]=arr_b_1;
108 
109             mxnet::cpp::NDArray::WaitAll();
110             //训练
111             std::cout<<" Training "<<std::endl;
112 
113             int max_iters = 20000;  //最大迭代次数
114            //获取训练网络的参数列表
115            std::vector<std::string>  args_name = pred.ListArguments();
116             for (int iter = 0; iter < max_iters ; ++iter) {
117                 executor->Forward(true);
118                 executor->Backward();
119 
120                 if(iter % 100 == 0){
121                     std::vector<NDArray> & out = executor->outputs;
122                     std::shared_ptr<mx_float> tp_x( new mx_float[128*28] ,
123                                                     [](mx_float * tp_x){ delete [] tp_x ;});
124                     out[0].SyncCopyToCPU(tp_x.get(),128*10);
125                     NDArray::WaitAll();
126                     std::cout<<"epoch "<<iter<<"  "<<"Accuracy: "<<  OutputAccuracy(tp_x.get() , aptr_y.get())<<std::endl;
127                 }
128                 //args_name.
129                 for(size_t arg_ind=0; arg_ind<args_name.size(); ++arg_ind){
130                     //执行
131                     if(args_name[arg_ind]=="X" || args_name[arg_ind]=="label")
132                         continue;
133 
134                     opt->Update(arg_ind,executor->arg_arrays[arg_ind],executor->grad_arrays[arg_ind]);
135                 }
136                 NDArray::WaitAll();
137 
138             }
139 
140 
141         }
142     };
143 }
144 
145 int main(int argc , char * argv[]){
146     mlp::MlpNet::net();
147     MXNotifyShutdown();
148     return EXIT_SUCCESS;
149 }

结果:

Training 
epoch 0  Accuracy: 0.09375
epoch 100  Accuracy: 0.304688
epoch 200  Accuracy: 0.195312
epoch 300  Accuracy: 0.203125
epoch 400  Accuracy: 0.304688
epoch 500  Accuracy: 0.296875
epoch 600  Accuracy: 0.304688
epoch 700  Accuracy: 0.304688
epoch 800  Accuracy: 0.398438
epoch 900  Accuracy: 0.5
epoch 1000  Accuracy: 0.5
epoch 1100  Accuracy: 0.40625
epoch 1200  Accuracy: 0.5
epoch 1300  Accuracy: 0.398438
epoch 1400  Accuracy: 0.40625
epoch 1500  Accuracy: 0.703125
epoch 1600  Accuracy: 0.609375
epoch 1700  Accuracy: 0.507812
epoch 1800  Accuracy: 0.703125
epoch 1900  Accuracy: 0.703125
epoch 2000  Accuracy: 0.804688
epoch 2100  Accuracy: 0.703125
epoch 2200  Accuracy: 0.804688
epoch 2300  Accuracy: 0.804688
epoch 2400  Accuracy: 0.804688
epoch 2500  Accuracy: 0.90625
epoch 2600  Accuracy: 0.90625
epoch 2700  Accuracy: 0.90625
epoch 2800  Accuracy: 1
epoch 2900  Accuracy: 1