基于单层决策树下 adaboost算法的实现代码
时间:2019-03-18
本文章向大家介绍基于单层决策树下 adaboost算法的实现代码,主要包括基于单层决策树下 adaboost算法的实现代码使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
本文采用的数据是 Python 数据分析实战中的数据 ‘horseColicTest2.txt’,
网上都可以找到数据文件,若找不到可以私信我发给你们。
from numpy import *
def loadSimpData():
datMat = matrix([[ 1. , 2.1],
[ 2. , 1.1],
[ 1.3, 1. ],
[ 1. , 1. ],
[ 2. , 1. ]])
classLabels = [1.0, 1.0, -1.0, -1.0, 1.0]
return datMat,classLabels
################################# 单层决策树生成函数 ##############################
def stumpClassify(dataMatrix, dimen, threshVal, threshIneq):
retArray = ones((shape(dataMatrix)[0], 1))
if threshIneq == 'lt': # lt:less than ; gt: greater than
retArray[dataMatrix[:,dimen] <= threshVal] = -1.0 # 在小于的分类条件下, 某列值 <= 阈值 时标签为 -1
else:
retArray[dataMatrix[:, dimen] > threshVal] = -1.0 # 在大于的分类条件下, 某列值 > 阈值 时标签为 -1
return retArray
def buildStump(dataArr, classLables, D):
dataMatrix = mat(dataArr); lableMat = mat(classLables).T
m,n = shape(dataMatrix)
numSteps = 10.0; bestStump = {}; beastClasEst = mat(zeros((m,1)))
minError = inf
for i in range(n): # dimen 的取值
rangeMin = dataMatrix[:, i].min()
rangeMax = dataMatrix[:, i].max()
stepSzie = (rangeMax - rangeMin)/ numSteps
for j in range(-1, int(numSteps)):
for inequal in ['lt', 'gt']:
threshVal = (rangeMin + float(j) * stepSzie) # 确定阈值
predictVals = stumpClassify(dataMatrix, i, threshVal, inequal)
errArr = mat(ones((m,1)))
errArr[predictVals == lableMat] = 0
weightError = D.T*errArr
# print('split: dim %d, thresh %.2f, thresh inequal: %s, the weighted error is %.3f'%(i, threshVal, inequal, weightError))
if weightError < minError:
minError = weightError
bestClasEst = predictVals.copy()
bestStump['dim'] = i
bestStump['thresh'] = threshVal
bestStump['ineq'] = inequal
return bestStump, minError, bestClasEst
def adaBoostTrainDS(dataArr, classLabels, numIt = 4.0):
'''
:param dataArr: 训练数据集
:param classLabels: 训练数据集对应的标签
:param numIt: 最大迭代次数,该数值需要用户自定义,默认为4
:return: 返回每次训练的弱分类器的参数值, 格式如:{'dim': 0, 'thresh': 1.3, 'ineq': 'lt', 'alpha': 0.6931471805599453}
'''
weakClassArr = []
m = shape(dataArr)[0]
D = mat(ones((m, 1))/m)
aggClassEst = mat(zeros((m, 1)))
for i in range(numIt): # 循环次数
bestStump, error, classEst = buildStump(dataArr, classLabels, D) # 第一次默认D 为平均值
# print("D:", D.T)
alpha = float(0.5*log((1.0 - error)/max(error, 1e-16))) # alpha = 0.5*ln((1-error)/error) 除零防止溢出
bestStump['alpha'] = alpha
weakClassArr.append(bestStump)
# print('classEst:', classEst.T) # 打印预测值
expon = multiply(-1*alpha*mat(classLabels).T, classEst) # 若预测值与真实值相等 则权重为 -alpha 调整,否则权重值为 alpha 调整
D = multiply(D, exp(expon))
D = D/D.sum() # 调整后的 D
aggClassEst += alpha*classEst # 记录每个数据点的类别估计累计值
# print('aggClassEst:', aggClassEst.T)
aggErrors = multiply(sign(aggClassEst) != mat(classLabels).T, ones((m, 1))) # 最终时根据 aggClassEst 的值的正负号,来得出最终的预测结果
errRate = aggErrors.sum()/m
# print('total error:', errRate, '\n')
if errRate == 0.0:
break
return weakClassArr, aggClassEst
def adaClassify(datToClass, classifierArr):
'''
将弱分类器的训练过程从程序中抽出来,然后应用到某个具体的实列上去。
每个弱分类器的结果以其对应的alpha作为权重。
所有这些弱分类器的结果加权求和就得到了最后的结果。
:param datToClass: 训练数据
:param classifierArr: 格式为 {'dim': 0, 'thresh': 1.3, 'ineq': 'lt', 'alpha': 0.6931471805599453}
:return: 最终的预测值
'''
dataMarix = mat(datToClass)
m = shape(dataMarix)[0]
aggClassEst = mat(zeros((m, 1)))
for i in range(len(classifierArr)):
classEst = stumpClassify(dataMarix, classifierArr[i]['dim'], classifierArr[i]['thresh'], classifierArr[i]['ineq'])
aggClassEst += classifierArr[i]['alpha']*classEst
# print(aggClassEst)
return sign(aggClassEst)
def loadDataSet(fileName):
# numFeat = len(open(fileName).readline().split('\t'))
dataMat = []; lableMat = []
with open(fileName, mode= 'rt', encoding='utf-8') as fr:
for line in fr.readlines():
lineArr = line.strip().split('\t')
lineArr = list(map(float, lineArr))
dataMat.append(lineArr[:-1])
lableMat.append(lineArr[-1])
return dataMat, lableMat
def plotRoc(predStrenths, classLabels):
import matplotlib.pyplot as plt
cur = (1.0, 1.0) # 设置初始点为(1, 1)
ySum = 0.0
numProsCals = sum(array(classLabels) == 1.0)
yStep = 1/float(numProsCals) # 当标签为正例时,点在y 轴上减小一个量...y 代表的真正例
xStep = 1/float(len(classLabels) - numProsCals)# 当标签为不为正例时,点在x 轴上减小一个量...y 代表的假正例
sortIndex = predStrenths.argsort() # 从小到大排列
fig = plt.figure()
fig.clf()
ax = plt.subplot(1,1,1)
for index in sortIndex.tolist()[0]:
if classLabels[index] == 1.0:
delX = 0; delY = yStep
else:
delX = xStep; delY = 0
ySum += cur[1]
ax.plot([cur[0], cur[0] - delX], [cur[1], cur[1] - delY], c = 'b')
cur = (cur[0]- delX, cur[1] - delY)
ax.plot([0, 1], [0, 1], 'b--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC curve for Adaboost Horse Colic Detection System')
ax.axis([0,1,0,1])
plt.show()
print('the Area under the curve is :', ySum*xStep)
if __name__ == '__main__':
# datMat, classLabels = loadSimpData()
# D = mat(ones((5,1))/5)
# # buildStump(datMat, classLabels, D)
# # print(buildStump(datMat,classLabels, D))
# # print(stumpClassify(datMat,1, 1, 'lt'))
# classifierArr= adaBoostTrainDS(datMat, classLabels,9)
# # print(adaClassify(datMat, classifierArr)) # 最终的预测结果
# ################### 预测值
# print(adaClassify([0,0], classifierArr))
# print(adaClassify([[0, 0],[5,5]], classifierArr))
datArr,labelArr = loadDataSet('horseColicTraining2.txt')
classifierArray, aggClassEst = adaBoostTrainDS(datArr, labelArr, 10)
testArr, testLableArr = loadDataSet('horseColicTest2.txt')
pred = adaClassify(testArr, classifierArray)
# print(pred)
# print(labelArr)
print((pred.T != testLableArr).sum())
print(mean(pred.T != testLableArr))
plotRoc(aggClassEst.T, labelArr)
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法