使用sklearn构建含有标量属性的决策树

时间:2022-05-08
本文章向大家介绍使用sklearn构建含有标量属性的决策树,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

网络上使用sklearn生成决策树的资料很多,这里主要说明遇见标量数据的处理。

经查验参考资料,sklearn并非使用了课上以及书上讲的ID3算法,而是选择了CART,该算法生成二叉树;scikit-learn使用了一种优化的CART算法,要求元数据为数值型(要能转换为np.float32类型的矩阵),因为该实现同时可以做回归分析。然而,题目数据中有天气等标量数据,所以还要进行转化,这里采用了sklearn中的LabelEncoder来将n个标量转化为1至n-1的整数。将数据训练完毕后,安装并使用了Graphviz(一个图形显示库)和pydotplus(方便使用Graphviz的Python编程接口)来进行结果图形化显示;查阅资料说的配置好像比较复杂,其实下载下来Graphviz后解压缩,并把bin文件夹加入环境变量就可以用pydotplus来访问了。使用信息熵作为度量,结果如图所示,其中value表示目标两类各包含多少实例。

结果:

为展示训练结果如何,将原数据再次使用score函数输入,发现正确率100%。应该是由于没有限制树的深度结果比较精确,并且发现“湿度”这个属性根本没有使用!但是一旦数据比较多,就需要限制树的深度了和每个叶子的实例个数了,由max_depth、min_samples_split、min_samples_leaf来设置。

最后还有一些疑问,就是把标量当做数值属性来处理,会影响最后分类的结果吗?需要拿数据说话还是有一些已经存在的结论。。。?

 1 # -*- coding: utf-8 -*-
 2 """
 3 Created on Tue Nov 22 17:45:37 2016
 4 
 5 @author: Ascii0x03
 6 """
 7 
 8 from sklearn import tree
 9 from sklearn import preprocessing
10 import pydotplus
11 from IPython.display import Image
12 
13 #将数据集data中的字符串属性全部转化为对应的标签
14 #data为矩阵,同tree.DecisionTreeClassifier.fit方法中的数据
15 #返回值le_list是preprocessing.LabelEncoder()对象的列表
16 #str_index是属性中字符串类型的下标
17 def preprocess(data):
18     str_index = []
19     #temp_label = []
20     le_list = []
21     le_num = 0
22     for i in range(0,len(data[1])):
23         if (isinstance(data[1][i], str)):
24             str_index.append(i)
25     #整理出labelEncoder
26     for index in str_index:
27         temp_label = []
28         for i in data:
29             temp_label.append(i[index])
30         le_list.append(preprocessing.LabelEncoder())
31         le_list[le_num].fit(temp_label)
32         #根据labelEncoder修改原始数据
33         #print temp_label
34         for i in data:
35             i[index] = le_list[le_num].transform([(i[index])])[0]
36         
37         le_num += 1
38        
39     return (le_list, str_index)
40     
41     
42     
43 clf = tree.DecisionTreeClassifier(criterion = "entropy")
44 #每行是一个数据,分别为天气,温度,湿度风况
45 data = [["Sunny", 85, 85, "No"], 
46         ["Sunny", 80, 90, "Yes"],
47         ["Cloudy", 83, 78, "No"], 
48         ["Rainy", 70, 96, "No"], 
49         ["Rainy", 68, 80, "No"],
50         ["Rainy", 65, 70, "Yes"],
51         ["Cloudy", 64, 65, "Yes"],
52         ["Sunny", 72, 95, "No"],
53         ["Sunny", 69, 70, "No"],
54         ["Rainy", 75, 80, "No"],
55         ["Sunny", 75, 70, "Yes"],
56         ["Cloudy", 72, 90, "Yes"],
57         ["Cloudy", 81, 75, "No"],
58         ["Rainy", 71, 80, "Yes"]
59         ]
60 #针对每行数据,分类为适合运动与不适合运动
61 labels = ["unfit", "unfit", "fit", "fit", "fit",
62           "unfit", "fit", "unfit", "fit", "fit",
63           "fit", "fit", "fit","unfit"]
64 (le_list, str_index) = preprocess(data)
65 #print data
66 clf.fit(data, labels)
67 
68 print clf.feature_importances_
69 dot_data = tree.export_graphviz(clf, out_file=None) 
70 graph = pydotplus.graph_from_dot_data(dot_data)
71 Image(graph.create_png()) #这里貌似不能正确显示
72 graph.write_pdf("test1.pdf")
73 graph.write_png("test1.png") 
74 #print dot_data
75 
76 test = [["Rainy", 71, 80, "Yes"]]
77 #Preprocessing the test data
78 for index in range(0, len(str_index)):
79     for i in test:
80         i[str_index[index]] = le_list[index].transform([i[str_index[index]]])[0]
81 #print test
82 print clf.predict(test)
83 print clf.predict_proba(test)
84 print clf.score(data, labels)

参考:

0. ID3算法实现决策树可http://blog.csdn.net/u012822866/article/details/42419471

1. http://scikit-learn.org/stable/modules/tree.html#tree-classification

2. http://scikit-learn.org/dev/modules/generated/sklearn.preprocessing.LabelEncoder.html

3. http://pydotplus.readthedocs.io/

4. http://www.graphviz.org/