统计学习方法之感知机1.感知机模型2.学习策略3.学习算法4.源代码
时间:2022-05-08
本文章向大家介绍统计学习方法之感知机1.感知机模型2.学习策略3.学习算法4.源代码,主要内容包括1.感知机模型、2.学习策略、3.学习算法、4.源代码、基本概念、基础应用、原理机制和需要注意的事项等,并结合实例形式分析了其使用技巧,希望通过本文能帮助到大家理解应用这部分内容。
1.感知机模型
- 在机器学习中,感知机(perceptron)是二分类的线性分类模型,属于监督学习算法。输入为实例的特征向量,输出为实例的类别(取+1和-1)。感知机对应于输入空间中将实例划分为两类的分离超平面。感知机旨在求出该超平面,为求得超平面导入了基于误分类的损失函数,利用梯度下降法 对损失函数进行最优化(最优化)。感知机的学习算法具有简单而易于实现的优点,分为原始形式和对偶形式。感知机预测是用学习得到的感知机模型对新的实例进行预测的,因此属于判别模型。感知机由Rosenblatt于1957年提出的,是神经网络和支持向量机的基础。
- 假设输入空间(特征向量)为X⊆Rn,输出空间为Y={-1, +1}。输入x∈X表示实例的特征向量,对应于输入空间的点;输出y∈Y表示示例的类别。由输入空间到输出空间的函数为
称为感知机。其中,参数w叫做权值向量weight,b称为偏置bias。w⋅x表示w和x的点积
sign为符号函数,即
- 在二分类问题中,f(x)的值(+1或-1)用于分类x为正样本(+1)还是负样本(-1)。感知机是一种线性分类模型,属于判别模型。我们需要做的就是找到一个最佳的满足w⋅x+b=0的w和b值,即分离超平面(separating hyperplane)。如下图,一个线性可分的感知机模型
中间的直线即w⋅x+b=0这条直线。
2.学习策略
- 感知机学习算法本身是误分类驱动的,因此我们采用随机梯度下降法。首先,任选一个超平面w0和b0,然后使用梯度下降法不断地极小化目标函数
3.学习算法
感知机学习算法={原始形式和对偶形式}
3.1原始形式
- 输入:T={(x1,y1),(x2,y2)...(xN,yN)}(其中xi∈X=Rn,yi∈Y={-1, +1},i=1,2...N,学习速率为η) 输出:w, b;感知机模型f(x)=sign(w·x+b) (1) 初始化w0,b0,权值可以初始化为0或一个很小的随机数 (2) 在训练数据集中选取(x_i, y_i) (3) 如果yi(w xi+b)≤0 w = w + ηy_ix_i b = b + ηy_i (4) 转至(2),直至训练集中没有误分类点
4.源代码
- 问题描述
- data.csv
1,1,-1
0,1,-1
3,3,1
4,3,1
2,0.5,-1
3,2,1
4,4,1
1,2,-1
- percetron.py
"""
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 8 14:24:10 2017
@author: jasonhaven
"""
import numpy as np
import matplotlib.pyplot as plt
def sign(x,w,b):
'''
x:特征向量
y:标记
w:权值向量
b:偏置
功能:计算 y=w*x+b
'''
y=np.dot(x,w)+b#返回只有一个元素的ndarray
return int(y)
def train(X,Y,w,b,yita):
'''
X:n维特征向量
y:n维标记向量
w:权值向量初始
b:偏置初始
yita:学习率
功能:训练模型,计算优化模型参数(感知机算法的原始形式)
'''
flag = False #损失函数是否最小
while not flag:
count_error=0
for i,xi in enumerate(X):
yi=Y[i]
signy=sign(xi,w,b)
if signy*yi<=0:
w+=(yita*yi*xi).reshape(w.shape)
b+=yita*yi
count_error+=1
if count_error==0:
flag=True
return w,b
def draw(train_datas,sign_of_train_datas,w,b):
plt.figure('percetron')
#设置横坐标
x=np.linspace(0,6,100)
#w[0]*x[0]+w[1]*x[1]+b=0
#计算函数值
y=-(w[0]*x+b)/w[1]
#绘制函数
plt.plot(x,y,color='r')
#绘制数据集
for i in range(len(train_datas)):
if(sign_of_train_datas[i]==1):
plt.scatter(train_datas[i][0],train_datas[i][1],s=100)
else:
plt.scatter(train_datas[i][0],train_datas[i][1],marker='x',s=100)
plt.title('percetron')
plt.xlabel('x')
plt.ylabel('y')
plt.show()
def read_from_csv():
file='./data.csv'
content=np.loadtxt(file,delimiter=',')
train_datas=[]
sign_of_train_datas=[]
for line in content:
train_datas.append(line[0:2])
sign_of_train_datas.append(line[-1:])
return train_datas,sign_of_train_datas
if __name__=='__main__':
#数据集
#train_datas=[[3,3],[4,3],[1,1],[2,3],[4,5],[1,2]]
#ign_of_train_datas=[1,1,-1,1,1,-1]
train_datas,sign_of_train_datas=read_from_csv()
#构造输入实例
X=np.array(train_datas)
Y=np.array(sign_of_train_datas)
#设置初始w0,b0
w0=np.zeros((X.shape[1],1))
b0=0
#设置学习率
yita=1
#训练模型
w,b=train(X,Y,w0,b0,yita)
print(w,b)
draw(train_datas,sign_of_train_datas,w,b)
- 运行结果
- 【重磅】微软Facebook联手发布AI生态系统,CNTK+Caffe2+PyTorch挑战TensorFlow
- hduoj-----(1068)Girls and Boys(二分匹配)
- 使用Django suit或Bootstrap美化admin模板
- hdu---------(1026)Ignatius and the Princess I(bfs+dfs)
- hdu-----(1113)Word Amalgamation(字符串排序)
- HDUoj-------(1128)Self Numbers
- cf------(round 2)A. Winner
- cf------(round)#1 C. Ancient Berland Circus(几何)
- MySQL配置TokuDB的简单总结
- cf------(round)#1 B. Spreadsheets(模拟)
- sysbench压测MyCAT的shell脚本
- qemu-kvm中vcpu虚拟化到底是咋整的?
- 【给 iOS 开发者】人工智能在 iOS 开发上的应用和机会
- 【Python】Selenium辅助海量基金数据获取
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 详解android 人脸检测你一定会遇到的坑
- Android实战RecyclerView头部尾部添加方法示例
- android实现多线程断点续传功能
- Android 8.0 中如何实现视频通话的画中画模式的示例
- Android7.0开发实现Launcher3去掉应用抽屉的方法详解
- Android利用Paint自定义View实现进度条控件方法示例
- 前端科普系列(5):ESLint - 守住优雅的护城河
- 图的储存方式,链式前向星最简单实现方式 (边集数组)
- 技术前刊:PostgreSQL12 COPY和bulkloading提升
- 疯子的算法总结(八) 最短路算法+模板
- POJ - 2387 Til the Cows Come Home (最短路入门)
- POJ - 3074 Sudoku (搜索)剪枝+位运算优化
- C语言rand随机函数问题
- HDU - 1253 胜利大逃亡(搜索)
- Android7.0版本影响开发的改进分析