[机智的机器在学习] 利用TensorFlow实现多元线性回归分类器
时间:2022-05-08
本文章向大家介绍[机智的机器在学习] 利用TensorFlow实现多元线性回归分类器,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
从今天的推文开始,我打算把经典的机器学习算法,都用tf实现一遍。这样一来可以熟悉一下机器学习算法,二来可以对tf有比较好的掌握,如果你是新手,那就跟着我的节奏,一起学习吧。讲的不好,大神轻拍~。
为了节省时间,有兴趣的童鞋可以直接去Github上clone,使用~,欢迎来点star~。
Github 地址:
https://github.com/Alvin2580du/machine_learning_with_tensorflow.git
# 导入需要的模块
# - * - coding: utf-8 - * -
import tensorflow as tf
import numpy as np
import pandas as pd
from sklearn import datasets
import os
# 这个函数是为了利用sklearn获取iris数据,然后保存到本地后面用。
def make_iris():
iris = datasets.load_iris()
x = pd.DataFrame(iris.data)
y = pd.DataFrame(iris.target).values
y_onehot = tf.one_hot(y, 3)
sess = tf.InteractiveSession()
y_onehot_value = sess.run(y_onehot).reshape((150, 3))
y_onehot_value = pd.DataFrame(y_onehot_value)
x.to_csv("iris_x.csv", sep=',', header=None, index=None)
y_onehot_value.to_csv("iris_y.csv", sep=',', header=None, index=None)
# 定义模型,这里要分清楚,in_size,out_size分别代表什么的大小,比如对于iris数据集,有4个自变量,1个因变量,但是我们把label经过one_hot编码之后,label就变成了3维。所以这里In_size就是训练数据的维度,也就是变量的个数。而out_size是输出的维度,就是因变量的维度,所以是3.
一般对于多元线性回归模型,可以写成矩阵的形式就是,Y=WX+b,这里W是4x3的,x是150x4的,b是150x3的,所以Y的维度就是(150x4)x(4x3)+(150x3)=150x3(属于某个类别的概率),模型最后输出是softmax多分类函数,所以最后每个样本都会有一个属于不同类别的概率值。
def model(inputs, in_size, out_size):
Weights = tf.Variable(tf.random_normal([in_size, out_size]))
biases = tf.Variable(tf.zeros([1, out_size]))
outputs = tf.nn.softmax(tf.matmul(inputs, Weights) + biases)
return outputs
# 定义模型训练函数
def train():
# 首先是读取数据,用上面那个函数保存的数据,
# 把训练数据读进来,因为pandas读取的是
# DataFrame对象,通过values属性转换为numpy.ndarry类型。
x_data = pd.read_csv("iris_x.csv", header=None).values
y_data = pd.read_csv("iris_y.csv", header=None).values
# 接下来是把数据分为训练集和测试集。
train_x = x_data[0:120, :]
train_y = y_data[0:120, :]
test_x = x_data[120:151, :]
test_y = y_data[120:151, :]
print train_x.shape
print test_x.shape
print train_y.shape
print test_y.shape
#定义placeholder,这也可以不定义,后面就不
# 用显示的feed了,直接run优化目标就行。这
# 里还是要注意holder的维度代表的含义,别稀里糊涂的。
x_data_holder = tf.placeholder(tf.float32, [None, 4])
y_data_holder = tf.placeholder(tf.float32, [None, 3])
# 调用模型,输出预测结果
y_prediction = model(x_data_holder, 4, 3)
#定义交叉熵损失函数
cross_entropy = tf.reduce_mean(
-tf.reduce_sum(y_data_holder *
tf.log(y_prediction), reduction_indices=[1]))
# 用梯度下降法求解,使得损失函数最小。
train_step = tf.train.GradientDescentOptimizer(0.1)
.minimize(cross_entropy)
# 启动session。
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
epoch = 2000
for e in range(epoch):
sess.run(train_step,
feed_dict={x_data_holder: train_x,
y_data_holder: train_y})
# 每隔50次,计算一下损失,注意这里的损失是训
# 练数据的损失,而且这个损失是单步的损失,
# 不是全部数据的损失。
if e % 50 == 0:
train_loss = sess.run(cross_entropy,
feed_dict={x_data_holder: train_x,
y_data_holder: train_y})
y_pre = sess.run(y_prediction,
feed_dict={x_data_holder: test_x})
correct_prediction = tf.equal(tf.argmax(y_pre, 1),
tf.argmax(test_y, 1))
# eval函数可以将tensor类型转换为具体的值,也可以不运行。
# print correct_prediction.eval(session=sess)
accuracy = tf.reduce_mean(tf.cast(correct_prediction,
tf.float32))
# 最后用测试数据,计算一下测试数据的预测精度。
test_acc = sess.run(accuracy,
feed_dict={x_data_holder: test_x,
y_data_holder: test_y})
print "acc: {}; loss: {}".format(test_acc, train_loss)
# 要计算全部数据的损失,需要在最后再run一下损失。
training_cost = sess.run(cross_entropy,
feed_dict={x_data_holder: train_x,
y_data_holder: train_y})
print "Training cost={}".format(training_cost)
if __name__ == "__main__":
train()
============End============
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Laravel中9个不经常用的小技巧汇总
- PHP simplexml_load_string()函数实例讲解
- php文件操作之文件写入字符串、数组的方法分析
- php xhprof使用实例详解
- PHP获取远程http或ftp文件的md5值的方法
- PHP addslashes()函数讲解
- PHP+swoole+linux实现系统监控和性能优化操作示例
- PHP中PCRE正则解析代码详解
- tensorflow 2.1.0 安装与实战教程(CASIA FACE v5)
- 使用Tensorflow-GPU禁用GPU设置(CPU与GPU速度对比)
- python 抓取知乎指定回答下视频的方法
- 基于python实现计算两组数据P值
- PHP getNamespaces()函数讲解
- OpenCV 使用imread()函数读取图片的六种正确姿势
- PHP simplexml_import_dom()函数讲解