EM算法的python实现的方法步骤
前言:前一篇文章大概说了EM算法的整个理解以及一些相关的公式神马的,那些数学公式啥的看完真的是忘完了,那就来用代码记忆记忆吧!接下来将会对python版本的EM算法进行一些分析。
EM的python实现和解析
引入问题(双硬币问题)
假设有两枚硬币A、B,以相同的概率随机选择一个硬币,进行如下的抛硬币实验:共做5次实验,每次实验独立的抛十次,结果如图中a所示,例如某次实验产生了H、T、T、T、H、H、T、H、T、H,H代表正面朝上。
假设试验数据记录员可能是实习生,业务不一定熟悉,造成a和b两种情况
a表示实习生记录了详细的试验数据,我们可以观测到试验数据中每次选择的是A还是B
b表示实习生忘了记录每次试验选择的是A还是B,我们无法观测实验数据中选择的硬币是哪个
问在两种情况下分别如何估计两个硬币正面出现的概率?
以上的针对于b实习生的问题其实和三硬币问题类似,只是这里把三硬币中第一个抛硬币的选择换成了实习生的选择。
对于已知是A硬币还是B硬币抛出的结果的时候,可以直接采用概率的求法来进行求解。对于含有隐变量的情况,也就是不知道到底是A硬币抛出的结果还是B硬币抛出的结果的时候,就需要采用EM算法进行求解了。如下图:
其中的EM算法的第一步就是初始化的过程,然后根据这个参数得出应该产生的结果。
构建观测数据集
针对这个问题,首先采集数据,用1表示H(正面),0表示T(反面):
#硬币投掷结果 observations = numpy.array([[1,0,0,0,1,1,0,1,0,1], [1,1,1,1,0,1,1,1,0,1], [1,0,1,1,1,1,1,0,1,1], [1,0,1,0,0,0,1,1,0,0], [0,1,1,1,0,1,1,1,0,1]])
第一步:参数的初始化
参数赋初值
第一个迭代的E步
抛硬币是一个二项分布,可以用scipy中的binom来计算。对于第一行数据,正反面各有5次,所以:
#二项分布求解公式 contribution_A = scipy.stats.binom.pmf(num_heads,len_observation,theta_A) contribution_B = scipy.stats.binom.pmf(num_heads,len_observation,theta_B)
将两个概率正规化,得到数据来自硬币A,B的概率:
weight_A = contribution_A / (contribution_A + contribution_B) weight_B = contribution_B / (contribution_A + contribution_B)
这个值类似于三硬币模型中的μ,只不过多了一个下标,代表是第几行数据(数据集由5行构成)。同理,可以算出剩下的4行数据的μ。
有了μ,就可以估计数据中AB分别产生正反面的次数了。μ代表数据来自硬币A的概率的估计,将它乘上正面的总数,得到正面来自硬币A的总数,同理有反面,同理有B的正反面。
#更新在当前参数下A,B硬币产生的正反面次数 counts['A']['H'] += weight_A * num_heads counts['A']['T'] += weight_A * num_tails counts['B']['H'] += weight_B * num_heads counts['B']['T'] += weight_B * num_tails
第一个迭代的M步
当前模型参数下,AB分别产生正反面的次数估计出来了,就可以计算新的模型参数了:
new_theta_A = counts['A']['H']/(counts['A']['H'] + counts['A']['T']) new_theta_B = counts['B']['H']/(counts['B']['H'] + counts['B']['T'])
于是就可以整理一下,给出EM算法单个迭代的代码:
def em_single(priors,observations): """ EM算法的单次迭代 Arguments ------------ priors:[theta_A,theta_B] observation:[m X n matrix] Returns --------------- new_priors:[new_theta_A,new_theta_B] :param priors: :param observations: :return: """ counts = {'A': {'H': 0, 'T': 0}, 'B': {'H': 0, 'T': 0}} theta_A = priors[0] theta_B = priors[1] #E step for observation in observations: len_observation = len(observation) num_heads = observation.sum() num_tails = len_observation-num_heads #二项分布求解公式 contribution_A = scipy.stats.binom.pmf(num_heads,len_observation,theta_A) contribution_B = scipy.stats.binom.pmf(num_heads,len_observation,theta_B) weight_A = contribution_A / (contribution_A + contribution_B) weight_B = contribution_B / (contribution_A + contribution_B) #更新在当前参数下A,B硬币产生的正反面次数 counts['A']['H'] += weight_A * num_heads counts['A']['T'] += weight_A * num_tails counts['B']['H'] += weight_B * num_heads counts['B']['T'] += weight_B * num_tails # M step new_theta_A = counts['A']['H'] / (counts['A']['H'] + counts['A']['T']) new_theta_B = counts['B']['H'] / (counts['B']['H'] + counts['B']['T']) return [new_theta_A,new_theta_B]
EM算法主循环
给定循环的两个终止条件:模型参数变化小于阈值;循环达到最大次数,就可以写出EM算法的主循环了
def em(observations,prior,tol = 1e-6,iterations=10000): """ EM算法 :param observations :观测数据 :param prior:模型初值 :param tol:迭代结束阈值 :param iterations:最大迭代次数 :return:局部最优的模型参数 """ iteration = 0; while iteration < iterations: new_prior = em_single(prior,observations) delta_change = numpy.abs(prior[0]-new_prior[0]) if delta_change < tol: break else: prior = new_prior iteration +=1 return [new_prior,iteration]
调用
给定数据集和初值,就可以调用EM算法了:
print em(observations,[0.6,0.5])
得到
[[0.72225028549925996, 0.55543808993848298], 36]
我们可以改变初值,试验初值对EM算法的影响。
print em(observations,[0.5,0.6])
结果:
[[0.55543727869042425, 0.72225099139214621], 37]
看来EM算法还是很健壮的。如果把初值设为相等会怎样?
print em(observations,[0.3,0.3])
输出:[[0.64000000000000001, 0.64000000000000001], 1]
显然,两个值相加不为1的时候就会破坏这个EM函数。
换一下初值:
print em(observations,[0.99999,0.00001])
输出:[[0.72225606292866507, 0.55543145006184214], 33]
EM算法对于参数的改变还是有一定的健壮性的。
以上是根据前人写的博客进行学习的~可以自己动手实现以下,对于python练习还是有作用的。希望对大家的学习有所帮助,也希望大家多多支持脚本之家。
- 声音分类的迁移学习
- 【死磕Java并发】—– J.U.C之AQS:CLH同步队列
- 使用Python完成你的第一个学习项目
- CA,给了数据库,给了机器,为啥也扩不了容?
- 如何使用Anaconda设置机器学习和深度学习的Python环境
- MQ,互联网架构解耦神器
- 预测随机机器学习算法实验的重复次数
- 服务化了,没想到耦合更加严重?
- 如何在Python中扩展LSTM网络的数据
- 使用Keras的Python深度学习模型的学习率方案
- 全球电脑手机无一幸免,英特尔CPU“漏洞事件”到底多严重?
- 评估Keras深度学习模型的性能
- Python机器学习的练习二:多元线性回归
- 熔断器 Hystrix 源码解析 —— 命令合并执行
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- [剑指Offer]面试题25: 合并两个排序的链表
- 环形链表
- 二十年前做科研你只需要检测一些基因在一些癌症细胞系表达量情况即可
- [剑指Offfer]面试题22: 链表中倒数第k个节点
- Django开发快速入门
- 三阴性乳腺癌表达数据分析笔记之PAM50
- Celery入门
- TO-do api
- 炫酷,Spring Boot + ECharts 实现用户访问地图可视化(附源码)
- 互联网大厂常考算法及套路深度解析
- 2020--Python语法常考知识点
- 为什么你画的Seurat包PCA图与别人的方向不一致?
- 用Python程序模拟300位观众,为5位嘉宾随机投票,最后按照降序排列结果
- Python知识点
- 上盘硬菜,@Transaction源码深度解析 | Spring系列第48篇