机器学习笔记(6):多类逻辑回归-使用gluon
时间:2022-04-22
本文章向大家介绍机器学习笔记(6):多类逻辑回归-使用gluon,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
上一篇演示了纯手动添加隐藏层,这次使用gluon让代码更精减,代码来自:https://zh.gluon.ai/chapter_supervised-learning/mlp-gluon.html
from mxnet import gluon
from mxnet import ndarray as nd
import matplotlib.pyplot as plt
import mxnet as mx
from mxnet import autograd
def transform(data, label):
return data.astype('float32')/255, label.astype('float32')
mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)
def show_images(images):
n = images.shape[0]
_, figs = plt.subplots(1, n, figsize=(15, 15))
for i in range(n):
figs[i].imshow(images[i].reshape((28, 28)).asnumpy())
figs[i].axes.get_xaxis().set_visible(False)
figs[i].axes.get_yaxis().set_visible(False)
plt.show()
def get_text_labels(label):
text_labels = [
'T 恤', '长 裤', '套头衫', '裙 子', '外 套',
'凉 鞋', '衬 衣', '运动鞋', '包 包', '短 靴'
]
return [text_labels[int(i)] for i in label]
data, label = mnist_train[0:10]
print('example shape: ', data.shape, 'label:', label)
show_images(data)
print(get_text_labels(label))
batch_size = 256
train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)
test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)
#计算模型
net = gluon.nn.Sequential()
with net.name_scope():
net.add(gluon.nn.Flatten())
net.add(gluon.nn.Dense(256, activation="relu"))
net.add(gluon.nn.Dense(10))
net.initialize()
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
#定义训练器
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.5})
def accuracy(output, label):
return nd.mean(output.argmax(axis=1) == label).asscalar()
def _get_batch(batch):
if isinstance(batch, mx.io.DataBatch):
data = batch.data[0]
label = batch.label[0]
else:
data, label = batch
return data, label
def evaluate_accuracy(data_iterator, net):
acc = 0.
if isinstance(data_iterator, mx.io.MXDataIter):
data_iterator.reset()
for i, batch in enumerate(data_iterator):
data, label = _get_batch(batch)
output = net(data)
acc += accuracy(output, label)
return acc / (i+1)
for epoch in range(5):
train_loss = 0.
train_acc = 0.
for data, label in train_data:
with autograd.record():
output = net(data)
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(batch_size) #使用训练器,向"前"走一步
train_loss += nd.mean(loss).asscalar()
train_acc += accuracy(output, label)
test_acc = evaluate_accuracy(test_data, net)
print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
epoch, train_loss/len(train_data), train_acc/len(train_data), test_acc))
data, label = mnist_test[0:10]
show_images(data)
print('true labels')
print(get_text_labels(label))
predicted_labels = net(data).argmax(axis=1)
print('predicted labels')
print(get_text_labels(predicted_labels.asnumpy()))
有变化的地方,已经加上了注释。运行效果,跟一篇完全相同,就不重复贴图了
- 联系生活来简化sql(r3笔记第43天)
- [笔记]使用Python一步一步地来进行数据分析
- 使用 R 语言从拉勾网看数据挖掘岗位现状
- 使用strace分析exp的奇怪问题(r3笔记第41天)
- Python文本挖掘:知乎网友如何评价《人民的名义》
- 怎样做中文文本的情感分析?
- 由一条日志警告所做的调优分析(r3笔记第40天)
- 生产环境sql语句调优实战第十篇(r3笔记第39天)
- memory_target设置不当导致数据库无法启动的问题(r3笔记第38天)
- python利用结巴分词做新闻地图
- 数据库静默安装总结(r3笔记第58天)
- 用TensorFlow实现文本分析模型,做个聊天机器人
- 深度学习:用tensorflow建立线性回归模型
- 用python基于2015-2016年的NBA常规赛及季后赛的统计数据分析
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 【LeetCode每日一题】(8.9)复原IP地址(回溯)
- 【回溯算法】N叉树相关技巧
- 【回溯算法】回溯,从入门到入土,七道试题精选、精讲、精练
- 数据结构练手小项目(AVL树、哈希表、循环链表、MySQL数据库)
- 【LeetCode】每日一题(8.2)二叉树展开为链表
- 【小技巧】argc和argv的用法
- 全面分析redis持久化机制
- 【奇技淫巧】-- 接雨水
- 【奇技淫巧】-- 最长连续序列
- 【redis】跟我一起动手玩玩redis主从复制和哨兵模式
- 【C++】八大排序算法 :GIF + 亲测代码 +专项练习平台
- 【C++】勉强能看的线程池详解
- 国密SSL协议之Java编程
- 7. Jackson用树模型处理JSON是必备技能,不信你看
- epoll,求知者离我近点