[MXNet逐梦之旅]练习二·使用MXNet拟合直线简洁实现
时间:2022-06-24
本文章向大家介绍[MXNet逐梦之旅]练习二·使用MXNet拟合直线简洁实现,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
[MXNet逐梦之旅]练习二·使用MXNet拟合直线简洁实现
- code
#%%
#%matplotlib inline
from matplotlib import pyplot as plt
from mxnet import autograd, nd
import random
#%%
num_inputs = 1
num_examples = 100
true_w = 1.56
true_b = 1.24
features = nd.arange(0,10,0.1).reshape((-1, 1))
labels = true_w * features + true_b
labels += nd.random.normal(scale=0.2, shape=labels.shape)
features[0], labels[0]
#%%
# 本函数已保存在d2lzh包中方便以后使用
from mxnet import gluon as gl
from mxnet.gluon import data as gdata
batch_size = 10
# 将训练数据的特征和标签组合
dataset = gl.data.ArrayDataset(features, labels)
# 随机读取小批量
data_iter = gl.data.DataLoader(dataset, batch_size, shuffle=True)
for X, y in data_iter:
print(X, y)
break
#%%
model = gl.nn.Sequential()
#%%
model.add(gl.nn.Dense(1))
model
#%%
import mxnet as mx
model.initialize(mx.init.Normal(sigma=0.01))
#%%
loss = gl.loss.L2Loss() # 平方损失又称L2范数损失
#%%
trainer = gl.Trainer(model.collect_params(), 'adam', {'learning_rate': 0.5})
#%%
num_epochs = 10
for epoch in range(1, num_epochs + 1):
for X, y in data_iter:
with autograd.record():
l = loss(model(X), y)
l.backward()
trainer.step(batch_size)
l = loss(model(features), labels)
print('epoch %d, loss: %f' % (epoch, l.mean().asnumpy()))
#%%
pre = model(features)
pre
plt.scatter(features.asnumpy(), labels.asnumpy(), 1)
plt.scatter(features.asnumpy(), pre.asnumpy(), 1)
plt.show()
#%%
print(model)
print("w:",model.collect_params()["dense0_weight"].data())
print("b:",model.collect_params()["dense0_bias"].data())
- out
<NDArray 10x1 @cpu(0)>
epoch 1, loss: 5.570210
epoch 2, loss: 2.831637
epoch 3, loss: 0.995476
epoch 4, loss: 0.332262
epoch 5, loss: 0.060224
epoch 6, loss: 0.027413
epoch 7, loss: 0.031316
epoch 8, loss: 0.030222
epoch 9, loss: 0.027907
epoch 10, loss: 0.032840
Sequential(
(0): Dense(1 -> 1, linear)
)
w:
[[1.5745053]]
<NDArray 1x1 @cpu(0)>
b:
[1.2476798]
<NDArray 1 @cpu(0)>
蓝色是原始数据
黄色为拟合数据
- [Go 语言社区]Golang架构--服务器与客户端自定义传输规则--原创
- Go语言 -浮点数
- android开发列表界面
- Java中Queue和BlockingQueue的区别
- android使用Activity
- Golang入门-- 2D的图形库学习
- Go语言--简单聊天室程序
- Go语言编程中判断文件是否存在是创建目录的方法
- jquery clone()表格之后查找里边的元素
- 必读:再讲Spark与kafka 0.8.2.1+整合
- windows下如何下载android源码
- Go语言的os包中常用函数初步归纳
- socket inet_pton
- 1,StructuredStreaming简介
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 8成以上的java线程状态图都画错了,看看这个-图解java并发第二篇
- 特殊数据格式处理-JSON框架Jackson精解第2篇
- 序列化与反序列化核心用法-JSON框架Jackson精解第一篇
- 属性序列化自定义排序与字母表排序-JSON框架Jackson精解第3篇
- 【我在拉勾训练营学技术】mysql 索引面试再也不怕啦
- 智能合约中常见的漏洞总结复现#技术创作101训练营#
- JS根据列表排列对象数组
- git提取两次提交或者版本的差异文件并打包成zip压缩包
- 博客通用版Live2d伊斯特瓦尔发布
- 一个小需求,自动重启k8s集群中日志不刷新的POD
- 多图,一文了解 8 种常见的数据结构
- Jenkins--pipline 流水线部署Java后端项目
- 微信小程序修炼之路LV1—工具介绍篇
- CentOS 7 部署OpenLDAP+FreeRadius
- 手把手教你使用yolo进行对象检测