朴素贝叶斯分类器(离散型)算法实现(一)

时间:2022-05-06
本文章向大家介绍朴素贝叶斯分类器(离散型)算法实现(一),主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

1. 贝叶斯定理:    

   (1)   P(A^B) = P(A|B)P(B) = P(B|A)P(A) 

 由(1)得

   P(A|B) = P(B|A)*P(A)/[p(B)]

贝叶斯在最基本题型:

假定一个场景,在一所高中男女比例为4:6, 留长头发的有男学生有女学生, 我们设定女生都留长发 , 而男生中有10%的留长发,90%留短发.那么如果我们看到远处一个长发背影?请问是一只男学生的概率?

  分析:

    P(男|长发) = P(长发|男)*P(男)/[p(长发)] 

        = (1/10)*(4/10)/[(6+4*(1/10))/10]

        =1/16 =0.0625

   P(女|长发) =P(长发|女)*P(女)/[p(长发)]

                  =1*(6/10)/[(6+4*(1/10))/10]

                 =30/32 =15/16

再举一个列子:

某个医院早上收了六个门诊病人,如下表。

  症状  职业   疾病   打喷嚏 护士   感冒    打喷嚏 农夫   过敏    头痛  建筑工人 脑震荡    头痛  建筑工人 感冒    打喷嚏 教师   感冒    头痛  教师   脑震荡

现在又来了第七个病人,是一个打喷嚏的建筑工人。请问他患上感冒的概率有多大?(来源: http://www.ruanyifeng.com/blog/2013/12/naive_bayes_classifier.html)

Java代码实现:

 1 /**
 2  * *********************************************************
 3  * <p/>
 4  * Author:     XiJun.Gong
 5  * Date:       2016-08-31 20:36
 6  * Version:    default 1.0.0
 7  * Class description:
 8  * <p>特征库</p>
 9  * <p/>
10  * *********************************************************
11  */
12 
13 public class FeaturePoint {
14 
15     private String key;
16     private double p;
17 
18     public FeaturePoint(String key) {
19         this(key, 1);
20     }
21 
22     public FeaturePoint(String key, double p) {
23         this.key = key;
24         this.p = p;
25     }
26 
27     public String getKey() {
28         return key;
29     }
30 
31     public void setKey(String key) {
32         this.key = key;
33     }
34 
35     public double getP() {
36         return p;
37     }
38 
39     public void setP(double p) {
40         this.p = p;
41     }
42 }
 1 import com.google.common.collect.ArrayListMultimap;
 2 import com.google.common.collect.Multimap;
 3 
 4 import java.util.Collection;
 5 import java.util.List;
 6 
 7 /**
 8  * *********************************************************
 9  * <p/>
10  * Author:     XiJun.Gong
11  * Date:       2016-08-31 15:48
12  * Version:    default 1.0.0
13  * Class description:
14  * <p/>
15  * *********************************************************
16  */
17 
18 public class Bayes {
19     private static Multimap<String, FeaturePoint> map = ArrayListMultimap.create();
20 
21     /*喂数据*/
22     public void input(List<String> labels) {
23 
24         for (String key : labels) {
25             Collection<FeaturePoint> features = map.get(key);
26             for (String value : labels) {
27                 if (features == null || features.size() < 1) {
28                     map.put(key, new FeaturePoint(value));
29                     continue;
30                 }
31                 boolean tag = false;
32                 for (FeaturePoint feature : features) {
33                     if (feature.getKey().equals(value)) {
34                         Double num = feature.getP() + 1;
35                         map.remove(key, feature);
36                         map.put(key, new FeaturePoint(value, num));
37                         tag = true;
38                         break;
39                     }
40                 }
41                 if (!tag)
42                     map.put(key, new FeaturePoint(value));
43             }
44         }
45     }
46 
47     /*构造模型*/
48     public void excute(List<String> labels) {
49         //   excute(labels, null);
50     }
51 
52     /*构造模型*/
53     public Double excute(final List<String> labels, final String judge, Integer dataSize) {
54 
55         Double denominator = 1d;    //分母
56         Double numerator = 1d;      //分子
57         Double coughNum = 0d;
58        /*选择相关性分子*/
59         Collection<FeaturePoint> featurePoints = map.get(judge);
60         for (FeaturePoint featurePoint : featurePoints) {
61             if (judge.equals(featurePoint.getKey())) {
62                 coughNum = featurePoint.getP();
63                 denominator *= (featurePoint.getP() / dataSize);
64                 break;
65             }
66         }
67 
68         Integer size = featurePoints.size() - 1; //容量
69         for (String label : labels) {
70             for (FeaturePoint featurePoint : featurePoints) {
71                 if (label.equals(featurePoint.getKey())) {
72                     denominator *= (featurePoint.getP() / coughNum);
73                     for (FeaturePoint feature : map.get(label)) {
74                         if (label.equals(feature.getKey())) {
75                             numerator *= (feature.getP() / dataSize);
76                         }
77                     }
78                 }
79             }
80         }
81 
82         return denominator / numerator;
83     }
84 
85 }
 1 import com.google.common.collect.Lists;
 2 
 3 import java.util.List;
 4 import java.util.Scanner;
 5 
 6 /**
 7  * *********************************************************
 8  * <p/>
 9  * Author:     XiJun.Gong
10  * Date:       2016-09-01 14:58
11  * Version:    default 1.0.0
12  * Class description:
13  * <p/>
14  * *********************************************************
15  */
16 public class Main {
17 
18     public static void main(String args[]) {
19 
20         Scanner scanner = new Scanner(System.in);
21         Integer size = scanner.nextInt();
22         Integer row = scanner.nextInt();
23         Bayes bayes = new Bayes();
24         while (scanner.hasNext()) {
25 
26             for (int ro = 0; ro < row; ro++) {
27                 List<String> list = Lists.newArrayList();
28                 for (int i = 0; i < size; i++) {
29                     list.add(scanner.next());
30                 }
31                 bayes.input(list);
32             }
33             List<String> list = Lists.newArrayList();
34             for (int i = 0; i < size - 1; i++) {
35                 list.add(scanner.next());
36             }
37             String judge = scanner.next();
38             System.out.println(bayes.excute(list, judge,row));
39             ;
40         }
41 
42     }
43 }

pom.xml包

    <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>3.8.1</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>18.0</version>
        </dependency>

结果:

1 3 6
2 打喷嚏 护士   感冒 
3   打喷嚏 农夫   过敏 
4   头痛  建筑工人 脑震荡 
5   头痛  建筑工人 感冒 
6   打喷嚏 教师   感冒 
7   头痛  教师   脑震荡
8 打喷嚏  建筑工人 感冒
9 0.6666666666666666 
1 3 6
2   打喷嚏 护士   感冒 
3   打喷嚏 农夫   过敏 
4   头痛  建筑工人 脑震荡 
5   头痛  建筑工人 感冒 
6   打喷嚏 教师   感冒 
7   头痛  教师   脑震荡
8 打喷嚏 护士   感冒 
9 1.3333333333333333
 1 2 50
 2 男  长发
 3 男  短发
 4 男  短发
 5 男  短发
 6 男  短发
 7 男  短发
 8 男  短发
 9 男  短发
10 男  短发
11 男  短发
12 男  短发
13 男  短发
14 男  短发
15 男  短发
16 男  短发
17 男  短发
18 男  短发
19 男  短发
20 男  短发
21 男  长发
22 女  长发
23 女  长发
24 女  长发
25 女  长发
26 女  长发
27 女  长发
28 女  长发
29 女  长发
30 女  长发
31 女  长发
32 女  长发
33 女  长发
34 女  长发
35 女  长发
36 女  长发
37 女  长发
38 女  长发
39 女  长发
40 女  长发
41 女  长发
42 女  长发
43 女  长发
44 女  长发
45 女  长发
46 女  长发
47 女  长发
48 女  长发
49 女  长发
50 女  长发
51 女  长发
52             
53 长发 男
54 0.06250000000000001
 1 2 50
 2 男  长发
 3 男  短发
 4 男  短发
 5 男  短发
 6 男  短发
 7 男  短发
 8 男  短发
 9 男  短发
10 男  短发
11 男  短发
12 男  短发
13 男  短发
14 男  短发
15 男  短发
16 男  短发
17 男  短发
18 男  短发
19 男  短发
20 男  短发
21 男  长发
22 女  长发
23 女  长发
24 女  长发
25 女  长发
26 女  长发
27 女  长发
28 女  长发
29 女  长发
30 女  长发
31 女  长发
32 女  长发
33 女  长发
34 女  长发
35 女  长发
36 女  长发
37 女  长发
38 女  长发
39 女  长发
40 女  长发
41 女  长发
42 女  长发
43 女  长发
44 女  长发
45 女  长发
46 女  长发
47 女  长发
48 女  长发
49 女  长发
50 女  长发
51 女  长发
52 长发 女
53 0.9375

 利用贝叶斯进行分类?

  1 import com.google.common.collect.ArrayListMultimap;
  2 import com.google.common.collect.Lists;
  3 import com.google.common.collect.Multimap;
  4 
  5 import java.util.Collection;
  6 import java.util.List;
  7 
  8 /**
  9  * *********************************************************
 10  * <p/>
 11  * Author:     XiJun.Gong
 12  * Date:       2016-08-31 15:48
 13  * Version:    default 1.0.0
 14  * Class description:
 15  * <p/>
 16  * *********************************************************
 17  */
 18 
 19 public class Bayes {
 20     private Multimap<String, FeaturePoint> map = null;
 21     private List<String> featurePool = null;
 22 
 23     public Bayes() {
 24         map = ArrayListMultimap.create();
 25         featurePool = Lists.newArrayList();
 26     }
 27 
 28     public void add(String label) {
 29         featurePool.add(label);
 30     }
 31 
 32     /*喂数据*/
 33     public void input(List<String> labels) {
 34 
 35         for (String key : labels) {
 36             Collection<FeaturePoint> features = map.get(key);
 37             for (String value : labels) {
 38                 if (features == null || features.size() < 1) {
 39                     map.put(key, new FeaturePoint(value));
 40                     continue;
 41                 }
 42                 boolean tag = false;
 43                 for (FeaturePoint feature : features) {
 44                     if (feature.getKey().equals(value)) {
 45                         Double num = feature.getP() + 1;
 46                         map.remove(key, feature);
 47                         map.put(key, new FeaturePoint(value, num));
 48                         tag = true;
 49                         break;
 50                     }
 51                 }
 52                 if (!tag)
 53                     map.put(key, new FeaturePoint(value));
 54             }
 55         }
 56     }
 57 
 58     /*最符合那个分类*/
 59     public String excute(List<String> labels, Integer dataSize) {
 60 
 61         Double max = -999999999d;
 62         String max_obj = null;
 63         List<Double> ans = Lists.newArrayList();
 64         for (String label : featurePool) {
 65             Double p = excute(labels, label, dataSize);
 66             ans.add(p);
 67             if (max < p) {
 68                 max_obj = label;
 69                 max = p;
 70             }
 71         }
 72         return max_obj;
 73     }
 74 
 75     /*构造模型*/
 76     public Double excute(final List<String> labels, final String judge, Integer dataSize) {
 77 
 78         Double denominator = 1d;    //分母
 79         Double numerator = 1d;      //分子
 80         Double coughNum = 0d;
 81        /*选择相关性分子*/
 82         Collection<FeaturePoint> featurePoints = map.get(judge);
 83         for (FeaturePoint featurePoint : featurePoints) {
 84             if (judge.equals(featurePoint.getKey())) {
 85                 coughNum = featurePoint.getP();
 86                 denominator *= (featurePoint.getP() / dataSize);
 87                 break;
 88             }
 89         }
 90        /*O(n^3)*/
 91         Integer size = featurePoints.size() - 1; //容量
 92         for (String label : labels) {
 93             for (FeaturePoint featurePoint : featurePoints) {
 94                 if (label.equals(featurePoint.getKey())) {
 95                     denominator *= (featurePoint.getP() / coughNum);
 96                     for (FeaturePoint feature : map.get(label)) {
 97                         if (label.equals(feature.getKey())) {
 98                             numerator *= (feature.getP() / dataSize);
 99                         }
100                     }
101                 }
102             }
103         }
104 
105         return denominator / numerator;
106     }
107 
108 }
 1 import com.google.common.collect.Lists;
 2 
 3 import java.util.List;
 4 import java.util.Scanner;
 5 
 6 /**
 7  * *********************************************************
 8  * <p/>
 9  * Author:     XiJun.Gong
10  * Date:       2016-09-01 14:58
11  * Version:    default 1.0.0
12  * Class description:
13  * <p/>
14  * *********************************************************
15  */
16 public class Main {
17 
18     public static void main(String args[]) {
19 
20         Scanner scanner = new Scanner(System.in);
21         Integer size = scanner.nextInt();
22         Integer row = scanner.nextInt();
23         Integer category = scanner.nextInt();
24         while (scanner.hasNext()) {
25             Bayes bayes = new Bayes();
26             for (int ro = 0; ro < row; ro++) {
27                 List<String> list = Lists.newArrayList();
28                 for (int i = 0; i < size; i++) {
29                     list.add(scanner.next());
30                 }
31                 bayes.input(list);
32             }
33             List<String> list = Lists.newArrayList();
34             for (int i = 0; i < size - 1; i++) {
35                 list.add(scanner.next());
36             }
37             for (int i = 0; i < category; i++) {
38                 bayes.add(scanner.next());
39             }
40             System.out.println(bayes.excute(list, row));
41         }
42 
43     }
44 }

结果:

 1 2 50 2
 2 男  长发
 3 男  短发
 4 男  短发
 5 男  短发
 6 男  短发
 7 男  短发
 8 男  短发
 9 男  短发
10 男  短发
11 男  短发
12 男  短发
13 男  短发
14 男  短发
15 男  短发
16 男  短发
17 男  短发
18 男  短发
19 男  短发
20 男  短发
21 男  长发
22 女  长发
23 女  长发
24 女  长发
25 女  长发
26 女  长发
27 女  长发
28 女  长发
29 女  长发
30 女  长发
31 女  长发
32 女  长发
33 女  长发
34 女  长发
35 女  长发
36 女  长发
37 女  长发
38 女  长发
39 女  长发
40 女  长发
41 女  长发
42 女  长发
43 女  长发
44 女  长发
45 女  长发
46 女  长发
47 女  长发
48 女  长发
49 女  长发
50 女  长发
51 女  长发
52 长发
53 男 女
54 女
 1 2 50 2
 2 男  长发
 3 男  短发
 4 男  短发
 5 男  短发
 6 男  短发
 7 男  短发
 8 男  短发
 9 男  短发
10 男  短发
11 男  短发
12 男  短发
13 男  短发
14 男  短发
15 男  短发
16 男  短发
17 男  短发
18 男  短发
19 男  短发
20 男  短发
21 男  长发
22 女  长发
23 女  长发
24 女  长发
25 女  长发
26 女  长发
27 女  长发
28 女  长发
29 女  长发
30 女  长发
31 女  长发
32 女  长发
33 女  长发
34 女  长发
35 女  长发
36 女  长发
37 女  长发
38 女  长发
39 女  长发
40 女  长发
41 女  长发
42 女  长发
43 女  长发
44 女  长发
45 女  长发
46 女  长发
47 女  长发
48 女  长发
49 女  长发
50 女  长发
51 女  长发
52 短发
53 男 女
54 男