预测随机机器学习算法实验的重复次数
许多随机机器学习算法的一个问题是同一数据上相同算法的不同运行会返回不同的结果。 这意味着,当进行实验来配置随机算法或比较算法时,必须收集多个结果,并使用平均表现来总结模型的技能。 这就提出了一个问题,即一个实验的重复次数是否足以充分描述一个给定问题的随机机器学习算法的技巧。 通常建议使用30个或更多个重复,甚至100个。一些从业者使用数千个重复,似乎超出了收益递减的想法。 在本教程中,您将探索统计方法,您可以使用它们来估计正确的重复次数,以有效地表征随机机器学习算法的性能。 本教程假定您有一个工作的Python 2或3 SciPy环境安装NumPy,熊猫和Matplotlib。
1.生成数据
第一步是生成数据。 我们将假设我们将一个神经网络或其他随机算法放入一个训练数据集1000次,并在数据集上收集了最终的RMSE分数。我们将进一步假设数据是正态分布的,这是我们将在本教程中使用的分析类型的要求。 检查您的结果分布; 结果往往是高斯分布。 我们将会分析产生的结果。这是有用的,因为我们将知道真正的人口平均数和标准误差,这是我们在真实的情况下不知道的。 我们将使用60为平均分,标准偏差是10。 以下代码生成1000个随机结果的样本,并将其保存到名为results.csv的CSV文件中。 我们使用seed()函数来生成随机数生成程序,以确保每次运行这个代码时总是得到相同的结果。然后我们使用normal()函数生成高斯随机数,并使用savetxt()函数保存ASCII格式的数组。
from numpy.random import seed
from numpy.random import normal
from numpy import savetxt
# define underlying distribution of results
mean = 60
stev = 10
# generate samples from ideal distribution
seed(1)
results = normal(mean, stev, 1000)
# save to ASCII file
savetxt('results.csv', results)
您现在应该有一个名为results.csv的文件,其中包含我们假装随机算法测试工具的1000个最终结果。 以下是文件的最后10行。
...
6.160564991742511864e+01
5.879850024371251038e+01
6.385602292344325548e+01
6.718290735754342791e+01
7.291188902850875309e+01
5.883555851728335995e+01
3.722702003339634302e+01
5.930375460544870947e+01
6.353870426882840405e+01
5.813044983467250404e+01
2.基础分析
当我们有一个成果的人口后第一步是做一些基本的统计分析,看看我们有什么。 三个基本分析的有用工具包括: 1.计算汇总统计,如平均值,标准偏差和百分位数。 2.使用框须图来查看数据的传播。 3.使用直方图查看数据的分布。 下面的代码执行这个基本的分析。首先加载results.csv,计算汇总统计量,并显示图表。
from pandas import DataFrame
from pandas import read_csv
from numpy import mean
from numpy import std
from matplotlib import pyplot
# load results file
results = read_csv('results.csv', header=None)
# descriptive stats
print(results.describe())
# box and whisker plot
results.boxplot()
pyplot.show()
# histogram
results.hist()
pyplot.show()
运行示例:首先打印统计信息。 我们可以看到,该算法的平均成绩约为60.3个单位,标准偏差约为9.8。 如果我们假设分数是最小化分数,如RMSE,我们可以看到最差的成绩是99.5,最好的成绩是大约29.4。
count 1000.000000
mean 60.388125
std 9.814950
min 29.462356
25% 53.998396
50% 60.412926
75% 67.039989
max 99.586027
创建框须图来总结数据的传播,显示中间的50%(框),离群值(点)和中位数(绿线)。 我们可以看到,即使在中位数附近,结果的散布也是合理的。
最后,创建结果的直方图。我们可以看到高斯分布的贝尔曲线形状,这是一个好兆头,因为它意味着我们可以使用标准的统计工具。 我们看不到任何明显的分配偏差; 它似乎以60左右为中心。
3.重复次数的影响
我们有很多的结果,准确的说有1000个。 这可能远远超过我们需要的结果,但是还是不够的。 我们怎么知道? 我们可以通过将实验的重复次数与这些重复的平均分数进行比较来获得一个初步的想法。 我们预计随着实验重复次数的增加,平均得分将迅速稳定。它应该经历一个最初混乱但最后趋于稳定的过程。 以下是代码。
from pandas import DataFrame
from pandas import read_csv
from numpy import mean
from matplotlib import pyplot
import numpy
# load results file
results = read_csv('results.csv', header=None)
values = results.values
# collect cumulative stats
means = list()
for i in range(1,len(values)+1):
data = values[0:i, 0]
mean_rmse = mean(data)
means.append(mean_rmse)
# line plot of cumulative values
pyplot.plot(means)
pyplot.show()
这个段期间会得到一些混乱没有规律的平均结果,经过前200次重复,它变得稳定。在600次重复之后,它似乎变得更加稳定。
我们可以放大图表中前500次重复,看看能否更好地了解发生了什么。 我们还可以叠加最终的平均分数(来自所有1000次运行的平均值),并尝试找到收益递减点。
from pandas import DataFrame
from pandas import read_csv
from numpy import mean
from matplotlib import pyplot
import numpy
# load results file
results = read_csv('results.csv', header=None)
values = results.values
final_mean = mean(values)
# collect cumulative stats
means = list()
for i in range(1,501):
data = values[0:i, 0]
mean_rmse = mean(data)
means.append(mean_rmse)
# line plot of cumulative values
pyplot.plot(means)
pyplot.plot([final_mean for x in range(len(means))])
pyplot.show()
橙色线显示所有1000次运行的平均值。 我们可以看到,100次运行可能是停止的一个好点,在400次可能会有一个更精致的结果,但只更精确一点点。
4.计算标准误差
标准误差是计算“样本平均值”与“总体均值”的差异。 这与描述样本中观察值的平均变化量的标准偏差不同。 标准误差可以计算如下:
standard_error = sample_standard_deviation / sqrt(number of repeats)
在这种情况下,模型得分的样本的标准偏差除以总重复次数的平方根。 我们期望标准误差随着实验的重复次数减少。 给出结果,我们可以从每个重复序列的总体平均值计算样本平均值的标准误差。以下提供完整的代码清单。
from pandas import read_csv
from numpy import std
from numpy import mean
from matplotlib import pyplot
from math import sqrt
# load results file
results = read_csv('results.csv', header=None)
values = results.values
# collect cumulative stats
std_errors = list()
for i in range(1,len(values)+1):
data = values[0:i, 0]
stderr = std(data) / sqrt(len(data))
std_errors.append(stderr)
# line plot of cumulative values
pyplot.plot(std_errors)
pyplot.show()
创建标准误差与重复次数的折线图。 我们可以看到,正如预期的那样,随着重复次数的增加,标准误差降低。我们也可以看到有一个可以接受的错误点,比如说一两个单位。 标准误差的单位与模型技能的单位相同。
我们可以重新创建上面的图表,并绘制0.5和1个单位作为指导,可以用来找到一个可以接受的错误级别。
from pandas import read_csv
from numpy import std
from numpy import mean
from matplotlib import pyplot
from math import sqrt
# load results file
results = read_csv('results.csv', header=None)
values = results.values
# collect cumulative stats
std_errors = list()
for i in range(1,len(values)+1):
data = values[0:i, 0]
stderr = std(data) / sqrt(len(data))
std_errors.append(stderr)
# line plot of cumulative values
pyplot.plot(std_errors)
pyplot.plot([0.5 for x in range(len(std_errors))], color='red')
pyplot.plot([1 for x in range(len(std_errors))], color='red')
pyplot.show()
我们可以看到,如果1的标准误差是可以接受的,那么大约100次重复就足够了。如果0.5的标准误差是可以接受的,则可能有300-350次重复就足够了。
我们也可以使用标准误差作为平均模型技能的置信区间。 例如,未知人口平均模型的性能有95%的可能性在上限和下限之间。 请注意,此方法仅适用于适度和大量的重复,例如20或更多。 置信区间可以定义为:
sample mean +/- (standard error * 1.96)
我们可以计算该置信区间,并将其添加到每个重复序列的样本平均值作为误差线。 以下提供完整的代码清单。
from pandas import read_csv
from numpy import std
from numpy import mean
from matplotlib import pyplot
from math import sqrt
# load results file
results = read_csv('results.csv', header=None)
values = results.values
# collect cumulative stats
means, confidence = list(), list()
n = len(values) + 1
for i in range(20,n):
data = values[0:i, 0]
mean_rmse = mean(data)
stderr = std(data) / sqrt(len(data))
conf = stderr * 1.96
means.append(mean_rmse)
confidence.append(conf)
# line plot of cumulative values
pyplot.errorbar(range(20, n), means, yerr=confidence)
pyplot.plot(range(20, n), [60 for x in range(len(means))], color='red')
pyplot.show()
被创建的线条图显示每个重复次数的平均样本值,并显示每个平均值的置信区间,以收集未知的底层人口平均值。 一条读线显示实际的人口平均值(仅因为我们在本教程开始时设计了模型技巧得分)。 作为总体均值的代理,你可以在1000次重复或更多的情况下添加最后一个样本均值。 误差条模糊了平均分数的线。我们可以看到平均值高估了总体均值,但95%置信区间掌握了总体均值。 请注意,95%置信区间意味着,在100个样本中,95%的时间间隔将会捕获总体均值,而5个样本均值和置信区间则不会。 我们可以看到,随着标准误差的减小,95%置信区间确实会随着重复的增加而增加,但可能会有超过500次重复的收益递减。
我们可以通过放大此图形来更清楚地了解发生了什么,突出显示从20到200的重复。
from pandas import read_csv
from numpy import std
from numpy import mean
from matplotlib import pyplot
from math import sqrt
# load results file
results = read_csv('results.csv', header=None)
values = results.values
# collect cumulative stats
means, confidence = list(), list()
n = 200 + 1
for i in range(20,n):
data = values[0:i, 0]
mean_rmse = mean(data)
stderr = std(data) / sqrt(len(data))
conf = stderr * 1.96
means.append(mean_rmse)
confidence.append(conf)
# line plot of cumulative values
pyplot.errorbar(range(20, n), means, yerr=confidence)
pyplot.plot(range(20, n), [60 for x in range(len(means))], color='red')
pyplot.show()
在创建的线条图中,我们可以清楚地看到样本平均值和周围的对称误差线。该图确实能够更好地显示样本平均值的偏差。
进一步阅读
没有多少资源将所需的统计数据与使用随机算法的计算实验方法联系起来。 关于我发现的主题的最好的书是: Empirical Methods for Artificial Intelligence,Cohen,1995, 如果对这篇文章感兴趣,我强烈推荐这本书。
下面是一些额外的文章,您可能会发现有用的: Standard Error Confidence Interval 68–95–99.7 rule 此文为编译作品,作者 Jason Brownlee,原网站http://machinelearningmastery.com
- 二分查找
- 译文 | Android 开发中利用异步来优化运行速度和性能
- 算法基础6:二叉树查找
- 通过UDP广播实现Android局域网Peer Discovering
- tensorflow读取数据-tfrecord格式
- 用Python使用C语言程序(Windows平台)
- 译文 | 在使用过采样或欠采样处理类别不均衡数据后,如何正确做交叉验证?
- 花式解释AutoEncoder与VAE
- 用CNN做句子分类:CNN Sentence Classification (with Theano code)
- MySQL与Python的交互
- 实时Android语音对讲系统架构
- ElasticSearch优化系列二:机器设置(内存)
- Tensorflow之 CNN卷积神经网络的MNIST手写数字识别
- 你听过算法也是可以贪心的吗?
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 利用python读取WORD文档中的创建者信息
- LeetCode-2.两数相加 使用链表加法实现
- Spring学习(2):Spring Bean管理(上)
- 聊聊dubbo-go的TpsLimitFilter
- spring,springBoot事件
- LeetCode-3.无重复字符的最长子串 利用一个整形数组+ASCII码实现滑动窗口
- 算法不想学(二): 堆排序和top k
- 利用python读取EXCEL文档中的创建者信息
- R语言工具变量与两阶段最小二乘法
- 贼好用的Java工具类库,GitHub星标13k+,很是厉害!
- CPU密集型任务会阻塞 Node.js 吗
- Let's Encrypt 配置 HTTPS 免费泛域名证书
- 如何删除重复数据(二)
- 如何删除重复数据
- SQL 生成斐波那契数列