如何用TensorFlow实现线性回归

时间:2019-08-16
本文章向大家介绍如何用TensorFlow实现线性回归,主要包括如何用TensorFlow实现线性回归使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

环境Anaconda

废话不多说,关键看代码

import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

tf.app.flags.DEFINE_integer("max_step", 300, "训练模型的步数")
FLAGS = tf.app.flags.FLAGS

def linear_regression():
    '''
    自实现线性回归
    :return: 
    '''
    #1.准备100个样本 特征值X,目标值y_true

    with tf.variable_scope("original_data"):
        #mean是平均值
        #stddev代表方差
        X = tf.random_normal(shape=(100,1),mean=0,stddev=1)

        y_true = tf.matmul(X,[[0.8]])+0.7

    #2.建立线性模型:
    with tf.variable_scope("linear_model"):
        weigh = tf.Variable(initial_value=tf.random_normal(shape=(1,1)))
        bias = tf.Variable(initial_value=tf.random_normal(shape=(1,1)))

        y_predict = tf.matmul(X,weigh)+bias

    # 3 确定损失函数
    #均方误差((y-y_repdict)^2)/m = 平均每一个样本的误差
    with tf.variable_scope("loss"):
        error = tf.reduce_mean(tf.square(y_predict-y_true))

    #4梯度下降优化损失:需要指定学习率
    with tf.variable_scope("gd_optimizer"):
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(error)

    #收集变量
    tf.summary.scalar("error",error)
    tf.summary.histogram("weights",weigh)
    tf.summary.histogram("bias",bias)

    #合并变量
    merge = tf.summary.merge_all()

    #初始化变量
    init = tf.global_variables_initializer()

    #创建一个saver
    saver = tf.train.Saver()
    #开启会话进行训练
    with tf.Session() as sess:
        #初始化变量op
        sess.run(init)
        print("随机初始化的权重为{},偏执为{}".format(weigh.eval(),bias.eval()))

        # print(weigh.eval(), bias.eval())
        # saver.restore(sess,"./checkpoint/linearregression")
        # print(weigh.eval(),bias.eval())
        #创建文件事件
        file_writer = tf.summary.FileWriter(logdir="./",graph=sess.graph)
        #训练模型

        for i in range(FLAGS.max_step):
            sess.run(optimizer)
            summary = sess.run(merge)
            file_writer.add_summary(summary,i)
            print("第{}步的误差为{},权重为{},偏执为{}".format(i,error.eval(),weigh.eval(),bias.eval()))
            #checkpoint:检查点文件
            #tf.keras:h5
            # saver.save(sess,"./checkpoint/linearregression")

if __name__ == '__main__':
    linear_regression()

  部分结果输出:

第294步的误差为7.031372661003843e-06,权重为[[0.7978232]],偏执为[[0.69850117]]
第295步的误差为5.66376502320054e-06,权重为[[0.7978593]],偏执为[[0.6985256]]
第296步的误差为5.646746103593614e-06,权重为[[0.7978932]],偏执为[[0.698556]]
第297步的误差为5.33674938196782e-06,权重为[[0.7979515]],偏执为[[0.69858944]]
第298步的误差为5.233380761637818e-06,权重为[[0.79799336]],偏执为[[0.6986183]]
第299步的误差为5.024347956350539e-06,权重为[[0.7980382]],偏执为[[0.6986382]]

  

原文地址:https://www.cnblogs.com/LiuXinyu12378/p/11366803.html