机器学习之决策树熵&信息增益求解算法实现

时间:2022-05-06
本文章向大家介绍机器学习之决策树熵&信息增益求解算法实现,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

此文不对理论做相关阐述,仅涉及代码实现:

1.熵计算公式:

其中 P 为正例在样本集 S 中所占的比例,Q 为反例所占的比例(P + Q = 1),并约定 0·log₂0 = 0:

$$\mathrm{Entropy}(S) = -P\log_2 P - Q\log_2 Q$$

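2.信息增益计算(A 为某一属性,Values(A) 为 A 的所有取值,S_v 为 S 中属性 A 取值为 v 的样本子集):

$$\mathrm{Gain}(S,A) = \mathrm{Entropy}(S) - \sum_{v \in \mathrm{Values}(A)} \frac{|S_v|}{|S|}\,\mathrm{Entropy}(S_v)$$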

举例:

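输入数据格式如下。第一行 `5 14` 表示共 5 列(4 个属性列加 1 个结果列 PlayTennis),以及 14 个样本。随后每行给出一列的名称及其 14 个取值。最后一行依次给出各属性名,末尾为结果列名: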

 5  14
 Outlook       Sunny  Sunny  Overcast  Rain  Rain    Rain    Overcast  Sunny  Sunny    Rain    Sunny   Overcast   Overcast    Rain
 Temperature   Hot    Hot    Hot       Mild  Cool    Cool        Cool   Mild  Cool     Mild    Mild    Mild       Hot         Mild
 Humidity      High   High   High      High  Normal  Normal  Normal     High  Normal   Normal  Normal  High       Normal      High
 Wind          Weak   Strong Weak      Weak  Weak    Strong  Strong    Weak   Weak     Weak    Strong  Strong     Weak        Strong
 PlayTennis    No     No     Yes       Yes   Yes     No      Yes       No     Yes      Yes     Yes     Yes        Yes         No
 Outlook Temperature Humidity Wind PlayTennis
 1 package com.qunar.data.tree;
 2 
 3 /**
 4  * *********************************************************
 5  * <p/>
 6  * Author:     XiJun.Gong
 7  * Date:       2016-09-02 15:28
 8  * Version:    default 1.0.0
 9  * Class description:
10  * <p>统计该类型出现的次数</p>
11  * <p/>
12  * *********************************************************
13  */
/**
 * Mutable holder pairing a value of type {@code T} with an occurrence count.
 *
 * <p>Despite the name this is not a map: each instance is a single (key, count) entry.
 *
 * @param <T> type of the value being counted
 */
public class CountMap<T> {

    /** The value being counted; may be {@code null}. */
    private T element;

    /** Number of times {@link #element} has been seen. */
    private int occurrences;

    /** Creates an entry with a {@code null} key and a count of zero. */
    public CountMap() {
        // Field defaults (null / 0) are exactly the desired initial state.
    }

    /**
     * Creates an entry with an explicit key and count.
     *
     * @param key   the value being counted
     * @param value how many times it has been seen
     */
    public CountMap(T key, int value) {
        element = key;
        occurrences = value;
    }

    /** @return the value being counted */
    public T getKey() {
        return element;
    }

    /** @param key the new value to associate with this count */
    public void setKey(T key) {
        element = key;
    }

    /** @return the current count */
    public int getValue() {
        return occurrences;
    }

    /** @param value the new count */
    public void setValue(int value) {
        occurrences = value;
    }
}
  1 package com.qunar.data.tree;
  2 
  3 import com.google.common.collect.ArrayListMultimap;
  4 import com.google.common.collect.Maps;
  5 import com.google.common.collect.Multimap;
  6 import com.google.common.collect.Sets;
  7 
  8 import java.util.*;
  9 
 10 /**
 11  * *********************************************************
 12  * <p/>
 13  * Author:     XiJun.Gong
 14  * Date:       2016-09-02 14:24
 15  * Version:    default 1.0.0
 16  * Class description:
 17  * <p>决策树</p>
 18  * <p/>
 19  * *********************************************************
 20  */
 21 
 22 public class DecisionTree<T, K> {
 23 
 24     private static String positiveExampleType = "Yes";
 25     private static String counterExampleType = "No";
 26 
 27 
 28     public double pLog2(final double p) {
 29         if (0 == p) return 0;
 30         return p * (Math.log(p) / Math.log(2));
 31     }
 32 
 33     /**
 34      * 熵计算
 35      *
 36      * @param positiveExample 正例个数
 37      * @param counterExample  反例个数
 38      * @return 熵值
 39      */
 40     public double entropy(final double positiveExample, final double counterExample) {
 41 
 42         double total = positiveExample + counterExample;
 43         double positiveP = positiveExample / total;
 44         double counterP = counterExample / total;
 45         return -1d * (pLog2(positiveP) + pLog2(counterP));
 46     }
 47 
 48     /**
 49      * @param features 特征列表
 50      * @param results  对应结果
 51      * @return 将信息整合成新的格式
 52      */
 53     public Multimap<T, CountMap<K>> merge(final List<T> features, final List<T> results) {
 54         //数据转化
 55         Multimap<T, CountMap<K>> InfoMap = ArrayListMultimap.create();
 56         Iterator result = results.iterator();
 57         for (T feature : features) {
 58             K res = (K) result.next();
 59             boolean tag = false;
 60             Collection<CountMap<K>> countMaps = InfoMap.get(feature);
 61             for (CountMap countMap : countMaps) {
 62                 if (countMap.getKey().equals(res)) {
 63                     /*修改值*/
 64                     int num = countMap.getValue() + 1;
 65                     InfoMap.remove(feature, countMap);
 66                     InfoMap.put(feature, new CountMap<K>(res, num));
 67                     tag = true;
 68                     break;
 69                 }
 70             }
 71             if (!tag)
 72                 InfoMap.put(feature, new CountMap<K>(res, 1));
 73         }
 74 
 75         return InfoMap;
 76     }
 77 
 78     /**
 79      * 信息增益
 80      *
 81      * @param infoMap   因素(Outlook,Temperature,Humidity,Wind)对应的结果
 82      * @param dataTable 输入的数据表
 83      * @param type      因素中的类型(Outlook{Sunny,Overcast,Rain})
 84      * @param entropyS  总的熵值
 85      * @param totalSize 总的样本数
 86      * @return 信息增益
 87      */
 88     public double gain(Multimap<T, CountMap<K>> infoMap,
 89                        Map<K, List<T>> dataTable,
 90                        final String type,
 91                        double entropyS,
 92                        final int totalSize) {
 93         //去重
 94         Set<T> subTypes = Sets.newHashSet();
 95         subTypes.addAll(dataTable.get(type));
 96         /*计算*/
 97         for (T subType : subTypes) {
 98             Collection<CountMap<K>> countMaps = infoMap.get(subType);
 99             double subSize = 0;
100             double positiveExample = 0;
101             double counterExample = 0;
102             for (CountMap<K> countMap : countMaps) {
103                 subSize += countMap.getValue();
104                 if (positiveExampleType.equals(countMap.getKey()))
105                     positiveExample = countMap.getValue();
106                 else
107                     counterExample = countMap.getValue();
108             }
109             entropyS -= (subSize / totalSize) * entropy(positiveExample, counterExample);
110         }
111         return entropyS;
112     }
113 
114     /**
115      * 计算
116      *
117      * @param dataTable  数据表
118      * @param types      因素列表{Outlook,Temperature,Humidity,Wind}
119      * @param resultType 结果(PlayTennis)
120      * @return 返回信息增益集合
121      */
122     public Map<String, Double> calculate(Map<K, List<T>> dataTable, List<K> types, K resultType) {
123 
124         Map<String, Double> answer = Maps.newHashMap();
125         List<T> results = dataTable.get(resultType);
126         int totalSize = results.size();
127         int positiveExample = 0;
128         int counterExample = 0;
129         double entropyS = 0d;
130         for (T ExampleType : results) {
131             if (positiveExampleType.equals(ExampleType)) {
132                 ++positiveExample;
133                 continue;
134             }
135             ++counterExample;
136         }
137         /*计算总的熵*/
138         entropyS = entropy(positiveExample, counterExample);
139 
140         Multimap<T, CountMap<K>> infoMap;
141         for (K type : types) {
142             infoMap = merge(dataTable.get(type), results);
143             double _gain = gain(infoMap, dataTable, (String) type, entropyS, totalSize);
144             answer.put((String) type, _gain);
145         }
146         return answer;
147     }
148 
149 }   1package com.qunar.data.tree;
 2 
 3 import com.google.common.collect.Lists;
 4 import com.google.common.collect.Maps;
 5 
 6 import java.util.*;
 7 
 8 /**
 9  * *********************************************************
10  * <p/>
11  * Author:     XiJun.Gong
12  * Date:       2016-09-02 16:43
13  * Version:    default 1.0.0
14  * Class description:
15  * <p/>
16  * *********************************************************
17  */
18 public class Main {
19 
20     public static void main(String args[]) {
21 
22         Scanner scanner = new Scanner(System.in);
23         while (scanner.hasNext()) {
24             DecisionTree<String, String> dt = new DecisionTree();
25             Map<String, List<String>> dataTable = Maps.newHashMap();
26             /*Map<String, List<String>> dataTable = Maps.newHashMap();*/
27             List<String> types = Lists.newArrayList();
28             String resultType;
29             int factorSize = scanner.nextInt();
30             int demoSize = scanner.nextInt();
31             String type;
32 
33             for (int i = 0; i < factorSize; i++) {
34                 List<String> demos = Lists.newArrayList();
35                 type = scanner.next();
36                 for (int j = 0; j < demoSize; j++) {
37                     demos.add(scanner.next());
38                 }
39                 dataTable.put(type, demos);
40             }
41             for (int i = 1; i < factorSize; i++) {
42                 types.add(scanner.next());
43             }
44             resultType = scanner.next();
45             Map<String, Double> ans = dt.calculate(dataTable, types, resultType);
46             List<Map.Entry<String, Double>> list = new ArrayList<Map.Entry<String, Double>>(ans.entrySet());
47             Collections.sort(list, new Comparator<Map.Entry<String, Double>>() {
48 
49 
50                 @Override
51                 public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
52                     return (o2.getValue() > o1.getValue() ? 1 : -1);
53                 }
54             });
55 
56             for (Map.Entry<String, Double> iterator : list) {
57                 System.out.println(iterator.getKey() + "= " + iterator.getValue());
58             }
59         }
60     }
61 
62 }
63 /**
64  *使用举例:*
65  5  14
66  Outlook       Sunny  Sunny  Overcast  Rain  Rain    Rain    Overcast  Sunny  Sunny    Rain    Sunny   Overcast   Overcast    Rain
67  Temperature   Hot    Hot    Hot       Mild  Cool    Cool        Cool   Mild  Cool     Mild    Mild    Mild       Hot         Mild
68  Humidity      High   High   High      High  Normal  Normal  Normal     High  Normal   Normal  Normal  High       Normal      High
69  Wind          Weak   Strong Weak      Weak  Weak    Strong  Strong    Weak   Weak     Weak    Strong  Strong     Weak        Strong
70  PlayTennis    No     No     Yes       Yes   Yes     No      Yes       No     Yes      Yes     Yes     Yes        Yes         No
71  Outlook Temperature Humidity Wind PlayTennis
72  */

运行结果(按信息增益从大到小排序):

Outlook= 0.2467498197744391
Humidity= 0.15183550136234136
Wind= 0.04812703040826927
Temperature= 0.029222565658954647