感知机的股票预测算例及python代码实现 | 山人聊算法 | 5th
前言
本文是感知机入门系列的最后一篇算例及代码,前序文章列表:
- 模型构建:入门感知机:一种二分类模型
- 学习算法:感知机的两种典型学习算法
股价预测的栗子
我们有一只股票(市盈率,每股净收益)的4个样本点:x1(20,2),x2(50,1),x3(10,3),x4(60,0.5)。其所属类别(一年后上涨,一年后下跌)为:{1,-1,1,-1},中1代表上涨,-1代表下跌。如下图所示。我们用感知机模型来做分类,对不同股票的上涨和下跌趋势进行分类预测。
最优化问题的构建
按照原始形式算法(算法的过程见上一篇文章)求解w,b,学习率设为1
(1)取初值
w0=0,b0=0
(2)更新w,b
对x1(20,2),y1(w0x1+b0)=0,未能被正确分类,更新w,b
w1=w0+y1x1=(20,2),b1=b0+y1=1
得到线性模型
w1.x+b1=20x(1)+2x(2)+1
(3)迭代更新w,b
对x1,显然,yi(w1.xi+b)>0,被正确分类,不修改w,b
对x2=(50,1),y2(w1.x3+b1)<0,被误分类,更新w,b
w2=w1+y2.x2=(-30,1),b2=b1+y2=0
得到线性模型
w2.x+b2=-30x(1)+x(2)
对有所点正确分类的模型
如此继续下去,直到
对所有数据点yi(w.xi+b)>0,没有误分类点,损失函数达到极小。
分离超平面为 -10x(1)+91x(2)+34 = 0
感知机模型为 f(x) = sign(-10x(1)+91x(2)+34)
迭代过程
迭代过程如下[a,b]c,其中a为w,b为b,c为(w.xi+b):
[0, 0] 0
[20.0, 2.0] 1
[-30.0, 1.0] 0
[-10.0, 3.0] 1
[10.0, 5.0] 2
[-40.0, 4.0] 1
[-20.0, 6.0] 2
[0.0, 8.0] 3
[-50.0, 7.0] 2
[-30.0, 9.0] 3
[-10.0, 11.0] 4
[10.0, 13.0] 5
[-40.0, 12.0] 4
[-20.0, 14.0] 5
[0.0, 16.0] 6
[-50.0, 15.0] 5
[-30.0, 17.0] 6
[-10.0, 19.0] 7
[10.0, 21.0] 8
[-40.0, 20.0] 7
[-20.0, 22.0] 8
[0.0, 24.0] 9
[-50.0, 23.0] 8
[-30.0, 25.0] 9
[-10.0, 27.0] 10
[10.0, 29.0] 11
[-40.0, 28.0] 10
[-20.0, 30.0] 11
[0.0, 32.0] 12
[-50.0, 31.0] 11
[-30.0, 33.0] 12
[-10.0, 35.0] 13
[10.0, 37.0] 14
[-40.0, 36.0] 13
[-20.0, 38.0] 14
[0.0, 40.0] 15
[-50.0, 39.0] 14
[-30.0, 41.0] 15
[-10.0, 43.0] 16
[10.0, 45.0] 17
[-40.0, 44.0] 16
[-20.0, 46.0] 17
[0.0, 48.0] 18
[-50.0, 47.0] 17
[-30.0, 49.0] 18
[-10.0, 51.0] 19
[10.0, 53.0] 20
[-40.0, 52.0] 19
[-20.0, 54.0] 20
[0.0, 56.0] 21
[-50.0, 55.0] 20
[-30.0, 57.0] 21
[-10.0, 59.0] 22
[10.0, 61.0] 23
[-40.0, 60.0] 22
[-20.0, 62.0] 23
[0.0, 64.0] 24
[-50.0, 63.0] 23
[-30.0, 65.0] 24
[-10.0, 67.0] 25
[10.0, 69.0] 26
[-40.0, 68.0] 25
[-20.0, 70.0] 26
[0.0, 72.0] 27
[-50.0, 71.0] 26
[-30.0, 73.0] 27
[-10.0, 75.0] 28
[10.0, 77.0] 29
[-40.0, 76.0] 28
[-20.0, 78.0] 29
[0.0, 80.0] 30
[-50.0, 79.0] 29
[-30.0, 81.0] 30
[-10.0, 83.0] 31
[10.0, 85.0] 32
[-40.0, 84.0] 31
[-20.0, 86.0] 32
[0.0, 88.0] 33
[-50.0, 87.0] 32
[-30.0, 89.0] 33
[-10.0, 91.0] 34
讨论
聪明的你一定发现了,最终的结果与取点迭代过程非常有关系,即如果我们计算中取的第一个点是x3,而不是x1的话,最终的感知机模型会不同。
也就是说有不止一个模型可以很好的将这四个点分类,可以说是有无限多个解。感知机学习算法由于采用不同的初值或选取不同的误分类点,解可以不同。
这个问题在支持向量机中将有更好的解决方案,支持向量机会在这些无穷多个解之中寻求一个分类效果最为明显的一个超平面。
上代码
首先,定义感知机模型,每行代码的意义已做注释,详见注释说明。
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 11 19:53:37 2018
@author: cz
"""
import numpy as np
import operator
import os
# create a dataset which contains 3 samples with 2 classes
def createDataSet():
# create a matrix: each row as a sample
group = np.array([[20,2], [50,1], [10,3],[60,0.5]])
labels = [1, -1, 1,-1] # four samples and two classes
return group, labels
#classify using perceptron
def perceptronClassify(trainGroup,trainLabels):
global w, b
isFind = False #the flag of find the best w and b
numSamples = trainGroup.shape[0] #计算矩阵的行数
mLenth = trainGroup.shape[1] #计算矩阵的列数
w = [0]*mLenth #初始化w
b = 0 #初始化b
while(not isFind): #定义迭代计算w和b的循环
for i in range(numSamples):
if cal(trainGroup[i],trainLabels[i]) <= 0: #计算损失函数,y(wx+b)<=0时更新参数
print (w,b)
update(trainGroup[i],trainLabels[i]) #更新计算w和b
break #end for loop
elif i == numSamples-1:
print (w,b)
isFind = True #end while loop
def cal(row,trainLabel): #定义损失函数
global w, b
res = 0
for i in range(len(row)):
res += row[i] * w[i]
res += b
res *= trainLabel
return res
def update(row,trainLabel): #学习率为1的更新计算
global w, b
for i in range(len(row)):
w[i] += trainLabel * row[i]
b += trainLabel
然后,导入定义好的感知机模型,并执行算法
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 11 19:59:04 2018
@author: cz
"""
import perceptron1 #导入定义好的感知机模型
g,l = perceptron1.createDataSet() #生成数据集
perceptron1.perceptronClassify(g,l) #训练分类器
- 《深入理解C# 3.x的新特性》博文系列汇总
- 十一国庆节 之 “变量与函数同名时,会输出谁?”
- 挖坑无止境,来看看这个《this的指向》
- T-SQL Enhancement in SQL Server 2005[上篇]
- 初学js钻太深,不太好
- Linux shell 程序设计3——命令行程序
- Linux shell 程序设计2——bash的内置命令
- T-SQL Enhancement in SQL Server 2005[下篇]
- JS原型,a和b是不是失散多年的兄弟?
- Linux shell 程序设计1——安装及入门
- 偶遇--《坑新人--前端专用面试题》
- 简单的说下,(function(){...})() 与 (function(){...}()) 有什么区别?
- ASP.NET Process Model之二:ASP.NET Http Runtime Pipeline[上篇]
- Shell常用命令小结
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法