从零开始学人工智能-Python·决策树(三)·节点
作者:射命丸咲Python 与 机器学习 爱好者
知乎专栏:https://zhuanlan.zhihu.com/carefree0910-pyml
个人网站:http://www.carefree0910.com
本章用到的 GitHub 地址:
https://github.com/carefree0910/MachineLearning/blob/master/Zhihu/CvDTree/one/CvDTree.py
本章用到的数学相关知识:
https://zhuanlan.zhihu.com/p/24501172
上一章我们把 node 的结构搭好了,这一章要做的就是塞东西进去。为此,我们不妨先看看我们需要实现什么:
一个 fit 函数,它能够根据输入的数据递归生成一颗决策树
一个 handle_terminate 函数,它在 node 成为 leaf 时调用,用于更新它爸爸和它爸爸的爸爸和……等等的信息
一个 prune 函数,用于剪枝
一个 view 函数,用于可视化
大提上来说就是这两点,剩下的就是一些细节。我们分开来说说怎么去实现它们
fit 函数
先说一下大概的流程:
接受数据和相应的标签
判断该 node 是否应该被当做 leaf;若是,则 return,否则继续往下走算出各维度的条件熵,记录下最好的条件熵、信息增益和此时关注的数据维度
比如说,如果输入的数据和标签为:
那么通过计算各个维度的条件熵和信息增益可知,此时该 node 关注的数据维度应该是第一维、也就是 A 对应的那一维。直观来说,这意味着 A 提供的信息量最大(事实上在这个栗子中,A 和 Label 是一样的)
根据信息增益判断是否终止(比如在 ID3 中,如果信息增益小于阈值的话就直接终止。这种判断方法会有比较严重的缺陷,观众老爷们可以想一想为什么 ( σ'ω')σ 【提示:异或数据集】)
根据所选的数据维度的各个特征把数据集切分成几份,分别喂给新的 node、递归,同时把这些 node 记录在自己的 children 里面
由于利用了递归,感觉还是一个比较干净利落的实现。下面就贴一些核心的代码,完整的实现可以参见这里
计算各个维度的条件熵和信息增益,这里就要用到准则章节的东西了
递归
handle_terminate 函数
如果童鞋们还记得我们 node 的结构的话,大概就会知道当一个 node 成为 leaf 后、需要做的事情有两个:
只要分别实现它们就好了:
其中
计算该 leaf 属于哪一类
更新它列祖列宗的 leafs 变量
prune 函数
需要指出的是,node 的 prune 函数不是决策树的剪枝算法、而是会在决策树的剪枝算法中被调用。它仅仅是为了该 node 的所有子孙都切了而已(喂
先说说流程,核心思想其实就是把该 node 变成一个 leaf:
判断该 node 应该属于哪一类
把该 node 的 leafs 中的 leaf 从该 node 的列祖列宗中的 leafs 中删除
把该 node 存进列祖列宗的 leafs 中
把该 node 自身及其所有子孙打上“已被剪枝”的标签
接下来是实现:
其中打标签用的 self.mark_pruned 函数的定义如下:
view 函数
基本思路很简单:如果自己是 leaf、就直接输出相关信息,否则在输出自己相关信息的同时、还要调用自己所有 children 的 view 函数。以下是实现:
这一章有点长,稍微总结一下:
决策树的生长关键是靠递归。当 node 接收一个数据和标签时,它会选出数据的某个维度、记录下来,然后会根据该维度的各个特征将数据、标签进行划分,分别喂给新的 node、从而能够递归下去
在 node 被判定应该是 leaf 时,要判断它属于哪一类并更新它列祖列宗的 leafs 变量
node 的 prune 函数是用来把 node 变成 leaf 并更新结构的,它本身不是决策树的剪枝算法、但它会在决策树的剪枝算法中被调用
下一章我们就要说说怎么建立一个框架以利用这些 node 来搭建一颗真正的决策树了。可能有童鞋已经敏锐地发现:不就只剩一个剪枝算法没有实现了吗?
事实上正是如此。下一章的框架确实只额外地实现了剪枝算法,剩下的都是封装的活儿
希望观众老爷们能够喜欢~
- SQLite 带你入门
- Windows下Nginx+Mysql+Php(wnmp)环境搭建
- LNMP源码编译安装(centos7+nginx1.9+mysql5.6+php7)
- MySQL SHOW PROFILE(剖析报告)的查看
- PHP连接MySQL数据库的三种方式(mysql、mysqli、pdo)
- 如何查看已经安装的nginx、apache、mysql和php的编译参数
- 连仕彤博客Centos7安装Mysql数据库
- sql server 2008 操作数据表
- sql server 使用函数辅助查询
- sql server存储过程编程
- sql server 2008 数据库的完整性约束
- sql server T-SQL 基础
- sql server 触发器
- T-SQL 查询、修改数据表
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- pytest 测试框架学习(11):pytest.raises
- Hibernate第二天:Hibernate的一级缓存、其他的API
- pytest 测试框架学习(12):pytest.deprecated_call
- Pinstaller(Python打包为exe文件
- pytest 测试框架学习(14):pytest.warns
- ImportError: /lib64/libm.so.6: version `CXXAB_1.3.8.' not found (required by /usr/local/python37/lib
- pytest 测试框架学习(15):pytest.freeze_includes
- Linux: scp文件,目录上传下载标准版
- Hibernate第三天:Hibernate的一对多配置、Hibernate的多对多的配置
- Git: 掉坑记 -- git reset 杀手
- ModuleNotFoundError: No module named 'phkit.pinyin'
- Hibernate第四天:Hibernate的查询方式、抓取策略
- 爬虫抓取博客园前10页标题带有Python关键字(不区分大小写)的文章
- Python爬虫抓取唐诗宋词
- ImportError: /lib64/libm.so.6: version `GLIBC_2.23' not found (required by /usr/local/python37/lib/p