Python中基于网格搜索算法优化的深度学习模型分析糖尿病数据
原文链接:http://tecdat.cn/?p=12693
介绍
在本教程中,我们将讨论一种非常强大的优化(或自动化)算法,即网格搜索算法。它最常用于机器学习模型中的超参数调整。我们将学习如何使用Python来实现它,以及如何将其应用到实际应用程序中,以了解它如何帮助我们为模型选择最佳参数并提高其准确性。
先决条件
要遵循本教程,您应该对Python或其他某种编程语言有基本的了解。您最好也具有机器学习的基本知识,但这不是必需的。除此之外,本文是初学者友好的,任何人都可以关注。
安装
要完成本教程,您需要在系统中安装以下库/框架:
它们的安装都非常简单-您可以单击它们各自的网站,以获取各自的详细安装说明。通常,可以使用pip安装软件包:
$ pip install numpy pandas tensorflow keras scikit-learn
如果遇到任何问题,请参考每个软件包的官方文档。
什么是网格搜索?
网格搜索本质上是一种优化算法,可让你从提供的参数选项列表中选择最适合优化问题的参数,从而使“试验和错误”方法自动化。尽管它可以应用于许多优化问题,但是由于其在机器学习中的使用而获得最广为人知的参数,该参数可以使模型获得最佳精度。
假设您的模型采用以下三个参数作为输入:
- 隐藏层数[2,4]
- 每层中的神经元数量[5,10]
- 神经元数[10,50]
如果对于每个参数输入,我们希望尝试两个选项(如上面的方括号中所述),则总计总共2 ^3 = 8个不同的组合(例如,一个可能的组合为[2,5,10])。手动执行此操作会很麻烦。
现在,假设我们有10个不同的输入参数,并且想为每个参数尝试5个可能的值。每当我们希望更改参数值,重新运行代码并跟踪所有参数组合的结果时,都需要从我们这边进行手动输入。网格搜索可自动执行该过程,因为它仅获取每个参数的可能值并运行代码以尝试所有可能的组合,输出每个组合的结果,并输出可提供最佳准确性的组合。
网格搜索实施
让我们将网格搜索应用于实际应用程序。讨论机器学习和数据预处理这一部分不在本教程的讨论范围之内,因此我们只需要运行其代码并深入讨论Grid Search的引入部分即可。
我们将使用Pima印度糖尿病数据集,该数据集包含有关患者是否基于不同属性(例如血糖,葡萄糖浓度,血压等)的糖尿病信息。使用Pandas read_csv()
方法,您可以直接从在线资源中导入数据集。
以下脚本导入所需的库:
from sklearn.model_selection import GridSearchCV, KFoldfrom keras.models import Sequentialfrom keras.layers import Dense, Dropoutfrom keras.wrappers.scikit_learn import KerasClassifierfrom keras.optimizers import Adamimport sysimport pandas as pdimport numpy as np
以下脚本导入数据集并设置数据集的列标题。
df = pd.read_csv(data_path, names=columns)
让我们看一下数据集的前5行:
df.head()
输出:
如你所见,这5行都是用来描述每一列的标签,因此它们对我们没有用。我们将从删除这些非数据行开始,然后将所有NaN
值替换为0:
for col in columns: df[col].replace(0, np.NaN, inplace=True)df.dropna(inplace=True) # Drop all rows with missing values
以下脚本将数据分为变量和标签集,并将标准化应用于数据集:
# Transform and display the training dataX_standardized = scaler.transform(X)
以下方法创建了我们简单的深度学习模型:
def create_model(learn_rate, dropout_rate): # Create model model = Sequential() model.add(Dense(8, input_dim=8, kernel_initializer='normal', activation='relu')) model.add(Dropout(dropout_rate)) model.add(Dense(4, input_dim=8, kernel_initializer='normal', activation='relu')) model.add(Dropout(dropout_rate)) model.add(Dense(1, activation='sigmoid')) # Compile the model adam = Adam(lr=learn_rate) model.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy']) return model
这是加载数据集,对其进行预处理并创建您的机器学习模型所需的所有代码。因为我们只对看到Grid Search的功能感兴趣,所以我没有进行训练/测试拆分,我们将模型拟合到整个数据集。
在下一节中,我们将开始了解Grid Search如何通过优化参数使生活变得更轻松。
在没有网格搜索的情况下训练模型
在下面的代码中,我们将随机决定或根据直觉决定的参数值创建模型,并查看模型的性能:
model = create_model(learn_rate, dropout_rate)
输出:
Epoch 1/1130/130 [==============================] - 0s 2ms/step - loss: 0.6934 - accuracy: 0.6000
正如看到的,我们得到的精度是60.00%
。这是相当低的。
使用网格搜索优化超参数
如果不使用Grid Search,则可以直接fit()
在上面创建的模型上调用方法。但是,要使用网格搜索,我们需要将一些参数传递给create_model()
函数。此外,我们需要使用不同的选项声明我们的网格,我们希望为每个参数尝试这些选项。让我们分部分进行。
首先,我们修改create_model()
函数以接受调用函数的参数:
# Create the modelmodel = KerasClassifier(build_fn=create_model, verbose=1)
现在,我们准备实现网格搜索算法并在其上拟合数据集:
# Build and fit the GridSearchCVgrid = GridSearchCV(estimator=model, param_grid=param_grid, cv=KFold(random_state=seed), verbose=10)
输出:
Best: 0.7959183612648322, using {'batch_size': 10, 'dropout_rate': 0.2, 'epochs': 10, 'learn_rate': 0.02}
在输出中,我们可以看到它为我们提供了最佳精度的参数组合。
可以肯定地说,网格搜索在Python中非常容易实现,并且在人工方面节省了很多时间。您可以列出所有您想要调整的参数,声明要测试的值,运行您的代码,而不必理会。您无需再输入任何信息。找到最佳参数组合后,您只需将其用于最终模型即可。
结论
总结起来,我们了解了什么是Grid Search,它如何帮助我们优化模型以及它带来的诸如自动化的好处。此外,我们学习了如何使用Python语言在几行代码中实现它。为了了解其有效性,我们还训练了带有和不带有Grid Search的机器学习模型,使用Grid Search的准确性提高了19%。
- 3555: [Ctsc2014]企鹅QQ
- 3381: [Usaco2004 Open]Cave Cows 2 洞穴里的牛之二
- 3097: Hash Killer I
- 3390: [Usaco2004 Dec]Bad Cowtractors牛的报复
- 1684: [Usaco2005 Oct]Close Encounter
- 算法模板——Dinic最小费用最大流
- 算法模板——Dinic网络最大流 1
- SQL Server 使用全文索引进行页面搜索
- 2764: [JLOI2011]基因补全
- 1000: A+B Problem(NetWork Flow)
- 博弈论进阶之Multi-SG
- 2929: [Poi1999]洞穴攀行
- SQL Server 执行计划缓存
- 1081: [SCOI2005]超级格雷码
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 1.2 ribbon-客户端负载均衡
- 1. feign的使用及原理
- hadoop案例实现之WordCount (计算单词出现的频数)
- ribbon源码
- feign源码
- DAO层配置绑定weblogic应用服务器的JNDI导致单元测试失败
- 1.操作系统底层工作的基本原理
- ibatis 日常问题总结
- 2.1 并发编程之java内存模型JMM & synchronize & volatile详解
- 使用intellij idea 查看Java字节码
- 2.2 指令重排&happens-before 原则 & 内存屏障
- 设计模式之代理模式(由浅入深)
- jquery param 数据 数组参数序列化
- 3 CPU缓存一致性协议MESi
- 4. synchronized详解