[编程经验] TensorFlow实现线性支持向量机SVM
[点击蓝字,一键关注~]
今天要说的是线性可分情况下的支持向量机的实现,如果对于平面内的点,支持向量机的目的是找到一条直线,把训练样本分开,使得直线到两个样本的距离相等,如果是高维空间,就是一个超平面。
然后我们简单看下对于线性可分的svm原理是啥,对于线性模型:
训练样本为
标签为:
如果
那么样本就归为正类, 否则归为负类。
这样svm的目标是找到W(向量)和b,然后假设我们找到了这样的一条直线,可以把数据分开,那么这些数据到这条直线的距离为:
然后我们把超平面两边到超平面的距离叫做间隔(margin),优化目标是使得这个margin最大,使得这样得到的超平面具有良好的泛化能力(用别的数据也能正确分类),
SVM的优化目标是:
条件是:
注意这里,因为tn可以取+1和-1,当取-1的时候,不等式两边都会乘以-1,所以不等号的方向会变。求解这个优化问题(二次规划),可以用拉格朗日乘子法,其中alpha是拉格朗日乘子。
对w和b求导,可以得到:
然后把这个求解的结果代到上面的L里面,可以得到L的对偶形式,得L~:
对偶形式的条件是:
然后将开始的那个线性模型中的参数W用核函数代替得到:
上面L的对偶形式,就是一个简单的二次规划问题,可以利用KKT条件求解:
然后把上面的y(xn)带入到这个等式里面,就得到下面这个式子:
求解上式,得b为:
其中Ns表示的就是支持向量,K(Xn,Xm)表示核函数。
下面举个核函数的栗子,对于二维平面内的点,
花了两个多小时,终于算是把代码调通了,虽然不难,但是还是觉得自己水平有限,实现起来还是会有很多问题
import numpy as np
import tensorflow as tf
from sklearn import datasets
x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
# 获得batch大小的数据
def gen_data(batch_size):
iris = datasets.load_iris()
iris_X = np.array([[x[0], x[3]] for x in iris.data])
iris_y = np.array([1 if y == 0 else -1 for y in iris.target])
train_indices = np.random.choice(len(iris_X),
int(round(len(iris_X) * 0.8)), replace=False)
train_x = iris_X[train_indices]
train_y = iris_y[train_indices]
rand_index = np.random.choice(len(train_x), size=batch_size)
batch_train_x = train_x[rand_index]
batch_train_y = np.transpose([train_y[rand_index]])
test_indices = np.array(
list(set(range(len(iris_X))) - set(train_indices)))
test_x = iris_X[test_indices]
test_y = iris_y[test_indices]
return batch_train_x, batch_train_y, test_x, test_y
# 定义模型
def svm():
A = tf.Variable(tf.random_normal(shape=[2, 1]))
b = tf.Variable(tf.random_normal(shape=[1, 1]))
model_output = tf.subtract(tf.matmul(x_data, A), b)
l2_norm = tf.reduce_sum(tf.square(A))
alpha = tf.constant([0.01])
classification_term = tf.reduce_mean(tf.maximum(0.,
tf.subtract(1., tf.multiply(model_output, y_target))))
loss = tf.add(classification_term, tf.multiply(alpha, l2_norm))
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)
return model_output, loss, train_step
def train(sess, batch_size):
print("# Training loop")
for i in range(100):
x_vals_train, y_vals_train,
x_vals_test, y_vals_test = gen_data(batch_size)
model_output, loss, train_step = svm()
init = tf.global_variables_initializer()
sess.run(init)
prediction = tf.sign(model_output)
accuracy = tf.reduce_mean(tf.cast(
tf.equal(prediction, y_target), tf.float32))
sess.run(train_step, feed_dict=
{
x_data: x_vals_train,
y_target: y_vals_train
})
train_loss = sess.run(loss, feed_dict=
{
x_data: x_vals_train,
y_target: y_vals_train
})
train_acc = sess.run(accuracy, feed_dict=
{
x_data: x_vals_train,
y_target: y_vals_train
})
test_acc = sess.run(accuracy, feed_dict=
{
x_data: x_vals_test,
y_target: np.transpose([y_vals_test])
})
if i % 10 == 1:
print("train loss: {:.6f}, train accuracy : {:.6f}".
format(train_loss[0], train_acc))
print
print("test accuracy : {:.6f}".format(test_acc))
print("- * - "*15)
def main(_):
with tf.Session() as sess:
train(sess, batch_size=16)
if __name__ == "__main__":
tf.app.run()
总结一下,SVM里面的坑,首先要知道SVM的目的找到一条线或者超平面,然后会计算点到超平面的距离,然后把这个距离转化为一个二次规划问题,然后就是使用拉格朗日方法求解这个优化问题,最后会涉及核函数方法。对于线性可分的SVM主要就是这些,是不是很简单呢?
今天就这样了,妈呀都12点了,好久不看SVM了,虽然之前花了很多时间看这个,但是还是好多都忘记了。好了,下次我们说非线性可分的情况是什么样子的,还有什么是松弛变量这些东东。。。
- SpringMVC注解@RequestMapping之produces属性导致的406错误
- SpringBoot集成MyBatis的分页插件PageHelper(回头草)
- SpringBoot整合Mybatis之进门篇
- Tomcat和Java Virtual Machine的性能调优总结
- 一次浴火重生的MySQL优化(EXPLAIN命令详解)
- 简单聊聊不可或缺的Nginx反向代理服务器--实现负载均衡【上篇】
- Java设计模式之适配器设计模式(项目升级案例)
- Java设计模式之模板方法设计模式(银行计息案例)
- 多线程之策略模式
- 文件上传的动作不能太俗,必须页面无刷新上传
- 这次真的忽略了一些ActiveMQ内心的娇艳
- 多线程编程:阻塞、并发队列的使用总结
- 多线程编程:多线程并发制单的开发记录【一】
- 如何使用线程锁来提高多线程并发效率
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法