Using cross-validation in LightGBM

Date: 2022-07-24
This article shows how to run K-fold cross-validation with LightGBM. It includes a worked example and a short summary of the parameters that help prevent overfitting.

These are notes on the LightGBM cross-validation approach I tried out while learning the library.

import numpy as np
import lightgbm as lgb
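from sklearn.model_selection import StratifiedKFold

params = {'num_leaves': 10,
          # min_child_samples is an alias of min_data_in_leaf. The original dict also
          # set 'min_data_in_leaf': 0, which conflicts with it (LightGBM keeps only one
          # of the two values), so only min_child_samples is kept here
          'min_child_samples': 20,
          'objective': 'binary',
          'max_depth': 3,
          'learning_rate': 0.01,
          "min_sum_hessian_in_leaf": 15,
          "boosting": "gbdt",
          "feature_fraction": 0.3,
          "bagging_freq": 1,
          "bagging_fraction": 0.8,
          "bagging_seed": 1,
          "lambda_l1": 0.3,     # L1 regularization
          'lambda_l2': 0.01,    # L2 regularization
          "verbosity": -1,
          "nthread": -1,
          'metric': ['binary_logloss', 'auc'],  # metrics reported on the validation set
          "random_state": 1,
          }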

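# train_data / test_data: the training and test DataFrames (e.g. Titanic),
# assumed to be loaded and encoded as numeric features beforehand
train_label = train_data["Survived"]
train_features = train_data.drop(columns=["Survived"])  # keep the label out of the features
test = test_data.copy()

NFOLDS = 10
kfold = StratifiedKFold(n_splits=NFOLDS, shuffle=True, random_state=1)
kf = kfold.split(train_features, train_label)
cv_pred = np.zeros(test.shape[0])  # test-set predictions, averaged over folds
valid_best = 0                     # best validation AUC, averaged over folds

for i, (train_fold, validate) in enumerate(kf):
    X_train, X_validate = train_features.iloc[train_fold, :], train_features.iloc[validate, :]
    label_train, label_validate = train_label.iloc[train_fold], train_label.iloc[validate]

    dtrain = lgb.Dataset(X_train, label_train)
    dvalid = lgb.Dataset(X_validate, label_validate, reference=dtrain)

    # stop once the validation metrics have not improved for 500 rounds;
    # the early_stopping_rounds argument of lgb.train was removed in LightGBM 4.0
    bst = lgb.train(params, dtrain, num_boost_round=10000, valid_sets=[dvalid],
                    callbacks=[lgb.early_stopping(stopping_rounds=500)])

    # predict the test set with this fold's best iteration and accumulate
    cv_pred += bst.predict(test, num_iteration=bst.best_iteration)
    valid_best += bst.best_score['valid_0']['auc']

cv_pred /= NFOLDS
valid_best /= NFOLDS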
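The loop above follows a common pattern. It trains one model per fold, averages the test-set predictions over the folds (cv_pred), and averages a per-fold validation score (valid_best). The sketch below reproduces that pattern with NumPy only, so the bookkeeping can be run without LightGBM. It rests on two loud assumptions. First, the LightGBM booster is swapped for a hypothetical toy scorer, `centroid_fit_predict`, which is not a LightGBM API. Second, `stratified_folds` is a simplified round-robin split, not sklearn's exact StratifiedKFold algorithm. The helper names `auc` and `cv_average` are likewise illustrative.

```python
import numpy as np


def stratified_folds(y, n_splits, seed=0):
    """Assign each sample a fold id so every fold keeps roughly the class ratio.

    Simplified stand-in for sklearn's StratifiedKFold (assumption): shuffle the
    indices of each class, then deal them round-robin across the folds.
    """
    rng = np.random.default_rng(seed)
    fold_id = np.empty(len(y), dtype=int)
    for cls in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == cls))
        fold_id[idx] = np.arange(len(idx)) % n_splits
    return fold_id


def centroid_fit_predict(X_tr, y_tr, X_eval):
    """Hypothetical toy model standing in for lgb.train + bst.predict."""
    mu_pos = X_tr[y_tr == 1].mean(axis=0)
    mu_neg = X_tr[y_tr == 0].mean(axis=0)
    d_pos = np.linalg.norm(X_eval - mu_pos, axis=1)
    d_neg = np.linalg.norm(X_eval - mu_neg, axis=1)
    return d_neg - d_pos  # higher score = closer to the positive centroid


def auc(y_true, score):
    """ROC AUC by comparing every positive/negative pair (ties count 1/2)."""
    pos, neg = score[y_true == 1], score[y_true == 0]
    diff = pos[:, None] - neg[None, :]
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size


def cv_average(X, y, X_test, n_splits=5, seed=0, fit_predict=centroid_fit_predict):
    """Return (test predictions averaged over folds, mean per-fold validation AUC)."""
    fold_id = stratified_folds(y, n_splits, seed)
    test_pred = np.zeros(len(X_test))
    mean_auc = 0.0
    for k in range(n_splits):
        tr, va = fold_id != k, fold_id == k
        mean_auc += auc(y[va], fit_predict(X[tr], y[tr], X[va]))  # like valid_best
        test_pred += fit_predict(X[tr], y[tr], X_test)           # like cv_pred
    return test_pred / n_splits, mean_auc / n_splits


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    y = (rng.random(200) < 0.4).astype(int)
    X = rng.normal(size=(200, 3)) + y[:, None]  # positives shifted by +1
    X_test = rng.normal(size=(50, 3))
    test_pred, mean_auc = cv_average(X, y, X_test)
    print(test_pred.shape, round(mean_auc, 3))
```

Averaging the per-fold AUC, as valid_best does, is one choice. An alternative is to collect out-of-fold predictions for every training row and compute a single AUC over all of them. The two numbers are usually close but not identical.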

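Parameters that help prevent overfitting:
- max_depth: the maximum tree depth; don't set it too large.
- num_leaves: should be smaller than 2^(max_depth). LightGBM grows trees leaf-wise, so a large num_leaves can produce very deep trees that overfit.
- min_child_samples (alias of min_data_in_leaf): the minimum number of samples in a leaf. Larger values keep the tree from growing too deep and reduce overfitting, but a value that is too large can cause underfitting.
- min_sum_hessian_in_leaf: the minimum sum of hessians in a leaf; setting it larger helps prevent overfitting.
- feature_fraction and bagging_fraction: train each iteration on a random subset of features or rows, and both reduce overfitting. Note that bagging_fraction only takes effect when bagging_freq > 0.
- lambda_l1 (reg_alpha) and lambda_l2 (reg_lambda): L1 and L2 regularization; larger values regularize more strongly.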
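Many of these parameters can be written under several names (lambda_l1 is also reg_alpha, for example). If two names for the same parameter end up in one params dict, LightGBM uses only one of the values, which is easy to miss. Likewise, with max_depth = 3 a tree holds at most 2^3 = 8 leaves, so a num_leaves of 10 can never be reached. Below is a small sanity-check sketch for both issues. `check_params` is a hypothetical helper, not part of LightGBM. The `ALIASES` table is a hand-copied subset of the aliases listed in the LightGBM parameter documentation and may be incomplete for your version.

```python
# Hand-copied subset of LightGBM parameter aliases (assumption: verify against
# the Parameters page of the LightGBM version you use).
ALIASES = {
    "min_data_in_leaf": {"min_data_per_leaf", "min_data", "min_child_samples", "min_samples_leaf"},
    "min_sum_hessian_in_leaf": {"min_sum_hessian_per_leaf", "min_sum_hessian",
                                "min_hessian", "min_child_weight"},
    "lambda_l1": {"reg_alpha", "l1_regularization"},
    "lambda_l2": {"reg_lambda", "lambda", "l2_regularization"},
    "feature_fraction": {"sub_feature", "colsample_bytree"},
    "bagging_fraction": {"sub_row", "subsample", "bagging"},
    "bagging_freq": {"subsample_freq"},
    "num_leaves": {"num_leaf", "max_leaves", "max_leaf", "max_leaf_nodes"},
}


def check_params(params):
    """Hypothetical helper: return warnings about a LightGBM params dict."""
    warnings = []
    # 1) the same parameter given under more than one name
    for main, aliases in ALIASES.items():
        used = sorted(name for name in {main, *aliases} if name in params)
        if len(used) > 1:
            warnings.append(f"{main} set more than once via {used}")
    # 2) num_leaves larger than a depth-limited tree can hold
    #    (only the main names are read; LightGBM defaults: max_depth=-1, num_leaves=31)
    depth = params.get("max_depth", -1)
    leaves = params.get("num_leaves", 31)
    if depth > 0 and leaves > 2 ** depth:
        warnings.append(f"num_leaves={leaves} exceeds 2**max_depth={2 ** depth}")
    return warnings


if __name__ == "__main__":
    demo = {"num_leaves": 10, "max_depth": 3,
            "min_child_samples": 20, "min_data_in_leaf": 0}
    for w in check_params(demo):
        print(w)
```

For the demo dict, this reports the min_child_samples/min_data_in_leaf clash and that num_leaves=10 exceeds the 8 leaves allowed by max_depth=3.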
