python SVM 案例,sklearn.svm.SVC 参数说明
sklearn.svm.SVC 参数说明
经常用到sklearn中的SVC函数,这里把文档中的参数翻译了一些,以备不时之需。
本身这个函数也是基于libsvm实现的,所以在参数设置上有很多相似的地方。(PS: libsvm中的二次规划问题的解决算法是SMO)。
sklearn.svm.SVC(C=1.0,kernel='rbf', degree=3, gamma='auto',coef0=0.0,shrinking=True,probability=False,tol=0.001,cache_size=200, class_weight=None,verbose=False,max_iter=-1,decision_function_shape=None,random_state=None)
参数:
l C:C-SVC的惩罚参数C?默认值是1.0
C越大,相当于惩罚松弛变量,希望松弛变量接近0,即对误分类的惩罚增大,趋向于对训练集全分对的情况,这样对训练集测试时准确率很高,但泛化能力弱。C值小,对误分类的惩罚减小,允许容错,将他们当成噪声点,泛化能力较强。
l kernel :核函数,默认是rbf,可以是‘linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’
0 – 线性:u’v
1 – 多项式:(gamma*u’*v + coef0)^degree
2 – RBF函数:exp(-gamma|u-v|^2)
3 –sigmoid:tanh(gamma*u’*v + coef0)
l degree :多项式poly函数的维度,默认是3,选择其他核函数时会被忽略。
l gamma : ‘rbf’,‘poly’ 和‘sigmoid’的核函数参数。默认是’auto’,则会选择1/n_features
l coef0 :核函数的常数项。对于‘poly’和 ‘sigmoid’有用。
l probability :是否采用概率估计?.默认为False
l shrinking :是否采用shrinking heuristic方法,默认为true
l tol :停止训练的误差值大小,默认为1e-3
l cache_size :核函数cache缓存大小,默认为200
l class_weight :类别的权重,字典形式传递。设置第几类的参数C为weight*C(C-SVC中的C)
l verbose :允许冗余输出?
l max_iter :最大迭代次数。-1为无限制。
l decision_function_shape :‘ovo’, ‘ovr’ or None, default=None3
l random_state :数据洗牌时的种子值,int值
主要调节的参数有:C、kernel、degree、gamma、coef0。
案例代码:
#!/usr/bin/python
# -*- coding:utf-8 -*-
import numpy as np
from sklearn import svm
from scipy import stats
from sklearn.metrics import accuracy_score
import matplotlib as mpl
import matplotlib.pyplot as plt
def extend(a, b, r):
x = a - b
m = (a + b) / 2
return m-r*x/2, m+r*x/2
if __name__ == "__main__":
np.random.seed(0)
N = 20
x = np.empty((4*N, 2))
print("{}n{}".format(x.shape,x))
means = [(-1, 1), (1, 1), (1, -1), (-1, -1)]
print(means)
sigmas = [np.eye(2), 2*np.eye(2), np.diag((1,2)), np.array(((2,1),(1,2)))]
print(sigmas)
for i in range(4):
mn = stats.multivariate_normal(means[i], sigmas[i]*0.3)
# print(mn)
x[i*N:(i+1)*N, :] = mn.rvs(N)
# print(mn.rvs(N))
a = np.array((0,1,2,3)).reshape((-1, 1))
print(a)
y = np.tile(a, N).flatten()
print(np.tile(a, N) )
print(y)
clf = svm.SVC(C=1, kernel='rbf', gamma=1, decision_function_shape='ovo')
# clf = svm.SVC(C=1, kernel='linear', decision_function_shape='ovr')
clf.fit(x, y)
y_hat = clf.predict(x)
acc = accuracy_score(y, y_hat)
np.set_printoptions(suppress=True)
print (u'预测正确的样本个数:%d,正确率:%.2f%%' % (round(acc*4*N), 100*acc))
# decision_function
print (clf.decision_function(x))
print (y_hat)
x1_min, x2_min = np.min(x, axis=0)
x1_max, x2_max = np.max(x, axis=0)
x1_min, x1_max = extend(x1_min, x1_max, 1.05)
x2_min, x2_max = extend(x2_min, x2_max, 1.05)
x1, x2 = np.mgrid[x1_min:x1_max:500j, x2_min:x2_max:500j]
x_test = np.stack((x1.flat, x2.flat), axis=1)
y_test = clf.predict(x_test)
y_test = y_test.reshape(x1.shape)
cm_light = mpl.colors.ListedColormap(['#FF8080', '#A0FFA0', '#6060FF', '#F080F0'])
cm_dark = mpl.colors.ListedColormap(['r', 'g', 'b', 'm'])
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False
plt.figure(facecolor='w')
plt.pcolormesh(x1, x2, y_test, cmap=cm_light)
plt.scatter(x[:, 0], x[:, 1], s=40, c=y, cmap=cm_dark, alpha=0.7)
plt.xlim((x1_min, x1_max))
plt.ylim((x2_min, x2_max))
plt.grid(b=True)
plt.tight_layout(pad=2.5)
plt.title(u'SVM多分类方法:One/One or One/Other', fontsize=18)
plt.show()
分类结果:
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Spring的学习与实战(续)
- mycat数据库集群系列之数据库多实例安装
- 自已动手作图搞清楚AVL树
- 自己动手作图深入理解二叉树、满二叉树及完全二叉树
- AsyncTask记录
- spring cloud gateway跨域冲突功能的开发
- Spring同时集成JPA与Mybatis
- Qt音视频开发10-ffmpeg控制播放
- 拿好了!Linux 运维必备的 13 款实用工具!
- 自制CA证书设置ssl证书
- MySQL数据迁移TcaplusDB实践
- TKE之初识容器探测器
- 2.3.2 JDK动态代理 -《SSM深入解析与项目实战》
- mac设备安装nginx注意事项
- 《研发运营安全白皮书(2020年)》深度解读:全生命周期安全体系将是未来趋势