tensorflow2实现线性回归例子
时间:2020-04-07
本文章向大家介绍tensorflow2实现线性回归例子,主要包括tensorflow2实现线性回归例子使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
%tensorflow_version 2.x import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers from tensorflow import initializers as init from tensorflow import losses from tensorflow.keras import optimizers from tensorflow import data as tfdata #1.生成数据 num_inputs = 2#数据有两个特征 num_examples = 1000#共有1000条数据 true_w = [2, -3.4]#两个特征的权重 true_b = 4.2#偏置 features = tf.random.normal(shape=(num_examples, num_inputs), stddev=1)#随机生成一个1000*2的矩阵,每行代表一条数据 labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b#计算y值 labels += tf.random.normal(labels.shape, stddev=0.01)#加上一个偏差 #2.组合数据 batch_size = 10 # 将训练数据的特征和标签组合 dataset = tfdata.Dataset.from_tensor_slices((features, labels))#按第0维进行切分,和标签组合 # 随机读取小批量 dataset = dataset.shuffle(buffer_size=num_examples)#随机打乱1000 dataset = dataset.batch(batch_size) data_iter = iter(dataset)#生成一个迭代器
输出一个batch看一下:
这里是其中一个batch,它包含10条原数据。
model = keras.Sequential()#定义模型 model.add(layers.Dense(1, kernel_initializer=init.RandomNormal(stddev=0.01)))#定义网络层 loss = losses.MeanSquaredError()#定义损失 trainer = optimizers.SGD(learning_rate=0.03)#定义优化器为随机梯度下降 loss_history = [] num_epochs = 3 for epoch in range(1, num_epochs + 1):#全体数据循环三次 for (batch, (X, y)) in enumerate(dataset):#对每一个batch循环 with tf.GradientTape() as tape:#定义梯度 l = loss(model(X, training=True), y) loss_history.append(l.numpy().mean())#记录该batch的损失 grads = tape.gradient(l, model.trainable_variables)#tape.gradient找到变量的梯度 trainer.apply_gradients(zip(grads, model.trainable_variables))#更新权重 l = loss(model(features), labels)#遍历完一次全体数据后的损失 print('epoch %d, loss: %f' % (epoch, l))
因为我们要求循环所有数据3次,而每一次循环都是小批量循环,每个小批量里都有10条数据,所以首先写出两个for循环,最里层的循环是每次循环10条数据。
我们通过调用tensorflow.GradientTape
记录动态图梯度,之前定义的损失函数是均方误差,需要真实值和模型值,于是把model(X)和y输入loss里。
我们可以记录每个batch的损失,添加到loss_history中。
通过 model.trainable_variables
找到需要更新的变量,并用 trainer.apply_gradients
更新权重,完成一步训练。
查看训练出来的参数和原参数的对比:
原文地址:https://www.cnblogs.com/liuxiangyan/p/12655875.html
- c#:Reflector+Reflexil 修改编译后的dll/exe文件
- testNG java.net.SocketException: Software caused connection abort: socket write error
- MyBatis.Net 学习手记
- 基于JavaScript 声明全局变量的三种方式详解
- 网页基础篇之如何制作简单的静态网页
- Mybatis.Net 整合 ODP.NET Managed
- 通过maven test 报org.apache.ibatis.binding.BindingException: Invalid bound statement
- 知道这几点,用微信小程序留住海量客户不是问题
- C#:DataTable映射成Model
- jenkins 多选框
- Oracle:ODP.NET Managed 小试牛刀
- C#:Func的同步、异步调用
- Python之路-day6
- hadoop1.2.1伪分布模式配置
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- JavaScript同步、异步及事件循环
- Node.js开发人员都应该知道的12个有用的包
- 欧拉函数的几条重要性质
- 迷宫问题的简单栈实现
- xmuC语言程序实践week 3 大作业
- xmuC语言程序实践week 4 大作业
- R语言预测人口死亡率:用李·卡特(Lee-Carter)模型、非线性模型进行平滑估计
- 前端的发展历程
- R语言蒙特卡洛计算和快速傅立叶变换计算矩生成函数
- Visual Studio 中万能头文件编译不了的解决方案
- Difference in two ways of using lower_bound [C++]std::set::lower_bound与std::lower_bound
- 迷你版Vue--学习如何造一个Vue轮子
- 如何用R语言绘制生成正态分布图表
- hdu 5143 NPY and arithmetic progression(暴力+思维)
- 正则表达式之简易markdown文件解析器