梯度下降法
时间:2019-11-20
本文章向大家介绍梯度下降法,主要包括梯度下降法使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
简介
梯度下降法是迭代法的一种,可以用于求解最小二乘问题(线性和非线性都可以),在求解机器学习算法的模型参数,梯度下降是最常采用的方法之一,在求解损失函数的最小值时,可以通过梯度下降法来一步步的迭代求解
- 不是一个机器学习算法
- 是一种基于搜索的最优化方法
- 最小化损失函数
- 最大化一个效用函数(梯度上升法)
模型
$J=\theta ^{2}+b$
定义了一个损失函数以后,参数 $\theta $ 对应的损失函数 $J$ 的值对应的示例图,需要找到使得损失函数值 $J$ 取得最小值对应的 $\theta $
- 首先随机取一个 $\theta $,对 $\theta $求导乘 $\eta $,得到一个导数gradient
- 将之前的 $\theta $ 存为last_theta
- 将 $\theta $减去$\eta $*gradient得到的值存入 $\theta $
- 将 $\theta $与last_theta分别代入公式后得到两个函数值相减,如果小于指定的一个极小值,则说明已找到了最小的 $\theta $,否则重复第1个步聚,对 $\theta $ 求导,依次完成,直到差值小于极小值
$\eta $ 超参数的作用
- $\eta $ 称为学习率也称为步长
- $\eta $ 的取值影响获得最优解的速度
- $\eta $ 取值不合适,可能得不到最优解
- $\eta $ 是梯度下降的一个超参数
η太小,会减慢收敛学习速度
η太大,会导致不收敛
局部最优与全局最优
最优化问题
就是在复杂环境中遇到的许多可能的决策中,挑选“最好”的决策的科学机器学习中选择最小的参数满足分类与预测的要求
全局最优:若一项决策和所有解决该问题的决策相比是最优的,对目标函数去最大值还是最小值,损失函数只有一个最优解
局部最优:不要求在所有决策中是最好的
解决方案
- 多次运行,随机化初始点
- 梯度下降法的初始点也是一个超参数
目标
使 $\sum_{i=1}^{m}(y^{(i)}-\widehat{y}^{(i)})^{2}$ 尽可能小
代码实现
python代码实现
mport numpy as np import matplotlib.pyplot as plt plot_x = np.linspace(-1.0,6.0,141) plot_y = (plot_x-2.5)**2+1 plt.plot(plot_x,plot_y,c='b') plt.show() #定一个极小值 epsilon = 1e-8 eta = 0.1 def J(thata): return (thata-2.5)**2+1.0 def DJ(thata): return 2*(thata-2.5) thata = 0.0 while True: g = DJ(thata) last_thata = thata thata = thata-g if(abs(J(thata)-J(last_thata))<epsilon): #注意这里不能小于0,如果两个损失函数的值相减小于一个极值,说明已找到 break; print(thata) print(J(thata))#最后一次的值
查看学习率
theta = 0.0 theta_history = [theta] while True: gradient = dJ(theta) last_theta = theta theta = theta - eta * gradient theta_history.append(theta) if(abs(J(theta) - J(last_theta)) < epsilon): break plt.plot(plot_x, J(plot_x)) plt.plot(np.array(theta_history), J(np.array(theta_history)), color="r", marker='+') plt.show()
函数封装
def gradient_descent(inital_theta,eta,epslion=1e-8): theta = inital_theta theta_history.append(theta) while True: g = DJ(theta) last_thata = theta theta = theta - eta * g theta_history.append(theta) if (abs(J(theta) - J(last_thata)) < epsilon): # 注意这里不能小于0,如果两个损失函数的值相减小于一个极值,说明已找到 break; def plot_theta_history(): plt.plot(plot_x, J(plot_x), c='b') # 将x数据传入J函数取得y的值 plt.plot(np.array(theta_history), J(np.array(theta_history)), c='r', marker='+') plt.show() gradient_descent(0.0,eta) plot_theta_history()
调整学习参数
eta = 0.001 theta_history = [] gradient_descent(0, eta) plot_theta_history() eta = 0.8 theta_history = [] gradient_descent(0, eta) plot_theta_history() eta = 1.1 theta_history = [] gradient_descent(0, eta)
迭代次数的调整
def gradient_descent(initial_theta, eta, n_iters = 1e4, epsilon=1e-8): theta = initial_theta i_iter = 0 theta_history.append(initial_theta) while i_iter < n_iters: gradient = dJ(theta) last_theta = theta theta = theta - eta * gradient theta_history.append(theta) if(abs(J(theta) - J(last_theta)) < epsilon): break i_iter += 1 return
多元线性回归中的梯度下降法
原文地址:https://www.cnblogs.com/zry-yt/p/11900542.html
- 轻量级交互数据json格式初探
- Golang语言社区--【基础知识】语言数组
- HDUOJ1086You can Solve a Geometry Problem too
- Golang语言社区--【基础知识】常量
- HDUOJ------1058 Humble Numbers
- MySQL偏移量的一点分析
- HDUOJ------------1051Wooden Sticks
- HDUOJ-----2068RPG的错排
- MySQL创建表失败的问题
- HDUOJ-----1066Last non-zero Digit in N!
- Golang语言社区-【基础知识】切片
- Oracle和MySQL的高可用方案对比(一)
- golang取两个数字之间的随机数
- MySQL 5.5复制升级到5.7的一点简单尝试
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 链接脚本linker script的妙用
- 【TBase开源版测评】轻松愉快去O选项:TBase
- Tungsten Fabric知识库丨更多组件内部探秘
- TRTC Android端开发接入学习之环境快速搭建(三)
- 为了满足UI小姐姐要求,自己动手实现了Android面包屑效果(支持Fragment联动)
- Jenkins持续集成「编译打包、代码检查、单元测试、环境部署、软件测试」
- 斗鱼直播带你实现:你主播最爱的Android音视频开发
- 深度解析Redis线程模型设计原理
- 聊聊claudb的hash command
- Exceptionless 5.x 无法正常发送邮件的问题解决
- 详解 Linux 中的硬链接与软链接
- 程序员进阶之算法练习(四十七)
- Git 合并多个 commit,保持历史简洁
- Vue开源项目使用探索
- 自定义View | 仿QQ运动步数进度效果