一看就懂的Tensorflow实战(K-Means模型)
时间:2022-07-22
本文章向大家介绍一看就懂的Tensorflow实战(K-Means模型),主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
K-Means算法简介
K-MEANS
算法是输入聚类个数k
,以及包含 n
个数据对象的数据库,输出满足方差最小标准k
个聚类的一种算法。属于一种经典的无监督学习算法。
示意图如下所示:
K-Means算法示意图
k-means
算法接受输入量 k
;然后将n
个数据对象划分为 k
个聚类以便使得所获得的聚类满足:同一聚类中的对象相似度较高;而不同聚类中的对象相似度较小。聚类相似度是利用各聚类中对象的均值所获得一个“中心对象”(引力中心)来进行计算的。
基本步骤:
(1) 从 n个数据对象任意选择 k 个对象作为初始聚类中心;
(2) 根据每个聚类对象的均值(中心对象),计算每个对象与这些中心对象的距离;并根据最小距离重新对相应对象进行划分;
(3) 重新计算每个(有变化)聚类的均值(中心对象);
(4) 计算标准测度函数,当满足一定条件,如函数收敛时,则算法终止;如果条件不满足则回到步骤(2)。
TensorFlow的K-Means实现
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.contrib.factorization import KMeans#导入KMeans函数
# Ignore all GPUs, tf random forest does not benefit from it.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
导入数据集
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./data/", one_hot=True)
full_data_x = mnist.train.images
Extracting ./data/train-images-idx3-ubyte.gz
Extracting ./data/train-labels-idx1-ubyte.gz
Extracting ./data/t10k-images-idx3-ubyte.gz
Extracting ./data/t10k-labels-idx1-ubyte.gz
设置参数
# Parameters
num_steps = 50 # Total steps to train
batch_size = 1024 # The number of samples per batch
k = 25 # The number of clusters
num_classes = 10 # The 10 digits
num_features = 784 # Each image is 28x28 pixels
# Input images
X = tf.placeholder(tf.float32, shape=[None, num_features])
# Labels (for assigning a label to a centroid and testing)
Y = tf.placeholder(tf.float32, shape=[None, num_classes])
# K-Means Parameters
# 距离度量的方式采用余弦距离(余弦相似度)
kmeans = KMeans(inputs=X, num_clusters=k, distance_metric='cosine',
use_mini_batch=True)
构建K-means图模型
# Build KMeans graph
(all_scores, cluster_idx, scores, cluster_centers_initialized,init_op,train_op) = kmeans.training_graph()
cluster_idx = cluster_idx[0] # fix for cluster_idx being a tuple
avg_distance = tf.reduce_mean(scores)
# Initialize the variables (i.e. assign their default value)
init_vars = tf.global_variables_initializer()
训练模型
# Start TensorFlow session
sess = tf.Session()
# Run the initializer
sess.run(init_vars, feed_dict={X: full_data_x})
sess.run(init_op, feed_dict={X: full_data_x})
# Training
for i in range(1, num_steps + 1):
_, d, idx = sess.run([train_op, avg_distance, cluster_idx],
feed_dict={X: full_data_x})
if i % 10 == 0 or i == 1:
print("Step %i, Avg Distance: %f" % (i, d))
Step 1, Avg Distance: 0.341471
Step 10, Avg Distance: 0.221609
Step 20, Avg Distance: 0.220328
Step 30, Avg Distance: 0.219776
Step 40, Avg Distance: 0.219419
Step 50, Avg Distance: 0.219154
测试
# Assign a label to each centroid
# Count total number of labels per centroid, using the label of each training
# sample to their closest centroid (given by 'idx')
counts = np.zeros(shape=(k, num_classes))
for i in range(len(idx)):
counts[idx[i]] += mnist.train.labels[i]
# Assign the most frequent label to the centroid
labels_map = [np.argmax(c) for c in counts]
labels_map = tf.convert_to_tensor(labels_map)
# Evaluation ops
# Lookup: centroid_id -> label
cluster_label = tf.nn.embedding_lookup(labels_map, cluster_idx)
# Compute accuracy
correct_prediction = tf.equal(cluster_label, tf.cast(tf.argmax(Y, 1), tf.int32))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Test Model
test_x, test_y = mnist.test.images, mnist.test.labels
print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))
Test Accuracy: 0.7127
参考
[百度百科——K-MEANS算法]https://baike.baidu.com/item/K-MEANS算法/594631?fr=aladdin
[TensorFlow-Examples]https://github.com/aymericdamien/TensorFlow-Examples
- 人脸智慧时尚店落地广深,微信支付赋能智慧零售
- 基层医疗破局关键:从医疗SaaS三大未来趋势说起
- 无数据库权限下载文献攻略大全
- 学 Python 就是为了当程序员?不止一种可能性
- 2018程序员必备碎片化学习工具
- 深入理解php底层:php生命周期
- 网站性能测试指标详解
- 在.NET Core类库中使用EF Core迁移数据库到SQL Server
- 人类的未来:儿童都能驾驶的汽车
- 论循证新闻的方法与意义——一种媒体融合背景下新闻生产方式创新
- 域名资讯:域名jiuhuang.com已搭建成“韭黄答题助手”网站
- Servlet开篇
- 浅谈中国域名的名与利
- 加密货币的火爆,tokens.com域名已50万美元成交
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Django JSONField SQL注入漏洞(CVE-2019-14234)分析与影响
- 扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积、激活、初始化、正则
- 持续代码质量管理-SonarQube-7.3部署
- 攻击Scrapyd爬虫
- 【webpack】从vue-cli 2x 到 3x 迁移与实践
- 前端单元测试那些事
- 前端Nginx那些事
- 前端运维部署那些事
- 《前端那些事》从0到1开发简单脚手架
- CDH7.1.1启用Kerberos
- 持续代码质量管理-SonarQube Scanner部署 2.1. 软件安装2.2. 配置修改
- 《前端那些事》聊聊前端的按需加载
- 直播带货系统,滚动视图,上滑隐藏,下滑显示
- 持续代码质量管理-SonarQube-7.3简单使用 2.1. 查看配置2.2. 质量检测2.3. 浏览器查看
- 安装指定版本的docker服务