特征列feature_column
时间:2022-07-22
本文章向大家介绍特征列feature_column,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
TensorFlow的中阶API主要包括:
- 数据管道(tf.data)
- 特征列(tf.feature_column)
- 激活函数(tf.nn)
- 模型层(tf.keras.layers)
- 损失函数(tf.keras.losses)
- 评估函数(tf.keras.metrics)
- 优化器(tf.keras.optimizers)
- 回调函数(tf.keras.callbacks)
如果把模型比作一个房子,那么中阶API就是【模型之墙】。
本篇我们介绍特征列。
特征列 通常用于对结构化数据实施特征工程时候使用,图像或者文本数据一般不会用到特征列。
一,特征列用法概述
使用特征列可以将类别特征转换为one-hot编码特征,将连续特征构建分桶特征,以及对多个特征生成交叉特征等等。
要创建特征列,请调用 tf.feature_column 模块的函数。该模块中常用的九个函数如下图所示,所有九个函数都会返回一个 Categorical Column 或一个 Dense Column 对象,但却不会返回 bucketized_column,后者继承自这两个类。
注意:所有的Catogorical Column类型最终都要通过indicator_column转换成Dense Column类型才能传入模型!
- numeric_column 数值列,最常用。
- bucketized_column 分桶列,由数值列生成,可以由一个数值列出多个特征,one-hot编码。
- categorical_column_with_identity 分类标识列,one-hot编码,相当于分桶列每个桶为1个整数的情况。
- categorical_column_with_vocabulary_list 分类词汇列,one-hot编码,由list指定词典。
- categorical_column_with_vocabulary_file 分类词汇列,由文件file指定词典。
- categorical_column_with_hash_bucket 哈希列,整数或词典较大时采用。
- indicator_column 指标列,由Categorical Column生成,one-hot编码
- embedding_column 嵌入列,由Categorical Column生成,嵌入矢量分布参数需要学习。嵌入矢量维数建议取类别数量的 4 次方根。
- crossed_column 交叉列,可以由除categorical_column_with_hash_bucket的任意分类列构成。
二,特征列使用范例
以下是一个使用特征列解决Titanic生存问题的完整范例。
import datetime
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers,models
#打印日志
def printlog(info):
nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print("n"+"=========="*8 + "%s"%nowtime)
print(info+'...nn')
tf.keras.backend.clear_session() #清空会话
#================================================================================
# 一,构建数据管道
#================================================================================
printlog("step1: prepare dataset...")
dftrain_raw = pd.read_csv("./data/titanic/train.csv")
dftest_raw = pd.read_csv("./data/titanic/test.csv")
dfraw = pd.concat([dftrain_raw,dftest_raw])
def prepare_dfdata(dfraw):
dfdata = dfraw.copy()
dfdata.columns = [x.lower() for x in dfdata.columns]
dfdata = dfdata.rename(columns={'survived':'label'})
dfdata = dfdata.drop(['passengerid','name'],axis = 1)
for col,dtype in dict(dfdata.dtypes).items():
# 判断是否包含缺失值
if dfdata[col].hasnans:
# 添加标识是否缺失列
dfdata[col + '_nan'] = pd.isna(dfdata[col]).astype('int32')
# 填充
if dtype not in [np.object,np.str,np.unicode]:
dfdata[col].fillna(dfdata[col].mean(),inplace = True)
else:
dfdata[col].fillna('',inplace = True)
return(dfdata)
dfdata = prepare_dfdata(dfraw)
dftrain = dfdata.iloc[0:len(dftrain_raw),:]
dftest = dfdata.iloc[len(dftrain_raw):,:]
# 从 dataframe 导入数据
def df_to_dataset(df, shuffle=True, batch_size=32):
dfdata = df.copy()
if 'label' not in dfdata.columns:
ds = tf.data.Dataset.from_tensor_slices(dfdata.to_dict(orient = 'list'))
else:
labels = dfdata.pop('label')
ds = tf.data.Dataset.from_tensor_slices((dfdata.to_dict(orient = 'list'), labels))
if shuffle:
ds = ds.shuffle(buffer_size=len(dfdata))
ds = ds.batch(batch_size)
return ds
ds_train = df_to_dataset(dftrain)
ds_test = df_to_dataset(dftest)
#================================================================================
# 二,定义特征列
#================================================================================
printlog("step2: make feature columns...")
feature_columns = []
# 数值列
for col in ['age','fare','parch','sibsp'] + [
c for c in dfdata.columns if c.endswith('_nan')]:
feature_columns.append(tf.feature_column.numeric_column(col))
# 分桶列
age = tf.feature_column.numeric_column('age')
age_buckets = tf.feature_column.bucketized_column(age,
boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
feature_columns.append(age_buckets)
# 类别列
# 注意:所有的Catogorical Column类型最终都要通过indicator_column转换成Dense Column类型才能传入模型!
sex = tf.feature_column.indicator_column(
tf.feature_column.categorical_column_with_vocabulary_list(
key='sex',vocabulary_list=["male", "female"]))
feature_columns.append(sex)
pclass = tf.feature_column.indicator_column(
tf.feature_column.categorical_column_with_vocabulary_list(
key='pclass',vocabulary_list=[1,2,3]))
feature_columns.append(pclass)
ticket = tf.feature_column.indicator_column(
tf.feature_column.categorical_column_with_hash_bucket('ticket',3))
feature_columns.append(ticket)
embarked = tf.feature_column.indicator_column(
tf.feature_column.categorical_column_with_vocabulary_list(
key='embarked',vocabulary_list=['S','C','B']))
feature_columns.append(embarked)
# 嵌入列
cabin = tf.feature_column.embedding_column(
tf.feature_column.categorical_column_with_hash_bucket('cabin',32),2)
feature_columns.append(cabin)
# 交叉列
pclass_cate = tf.feature_column.categorical_column_with_vocabulary_list(
key='pclass',vocabulary_list=[1,2,3])
crossed_feature = tf.feature_column.indicator_column(
tf.feature_column.crossed_column([age_buckets, pclass_cate],hash_bucket_size=15))
feature_columns.append(crossed_feature)
#================================================================================
# 三,定义模型
#================================================================================
printlog("step3: define model...")
model = tf.keras.Sequential([
layers.DenseFeatures(feature_columns), #将特征列放入到tf.keras.layers.DenseFeatures中!!!
layers.Dense(64, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'])
#================================================================================
# 四,训练模型
#================================================================================
printlog("step4: train model...")
history = model.fit(ds_train,
validation_data=ds_test,
epochs=20)
#================================================================================
# 五,评估模型
#================================================================================
printlog("step5: eval model...")
model.summary()
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
def plot_metric(history, metric):
train_metrics = history.history[metric]
val_metrics = history.history['val_'+metric]
epochs = range(1, len(train_metrics) + 1)
plt.plot(epochs, train_metrics, 'bo--')
plt.plot(epochs, val_metrics, 'ro-')
plt.title('Training and validation '+ metric)
plt.xlabel("Epochs")
plt.ylabel(metric)
plt.legend(["train_"+metric, 'val_'+metric])
plt.show()
plot_metric(history,"accuracy")
- NYOJ-------表达式求值
- HDUOJ----1181 变形课
- 正确的Win主机网站伪静态设置方法
- HDUOJ----(1084)What Is Your Grade?
- HDUOJ------(1272)小希的迷宫
- HDUOJ ---1269迷宫城堡
- HDUOJ---1213How Many Tables
- hduoj----(1033)Edge
- HDUOJ----(1031)Design T-Shirt
- HDUOJ----(1030)Delta-wave
- 身份切换脚本,免登入切换权限的利器
- HDUOJ---What Are You Talking About
- HDUOJ-----(1251)统计难题
- HDUOJ-----1541 Stars
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- ubuntu16.04下qt5.14报错:/home/zhangfakai/Qt5.14.1/5.14.1/gcc_64/include/QtGui/qopengl.h:141: error: GL/
- 每天手撕一道算法-64. 最小路径和
- 每日手撕一道算法题-322.零钱兑换
- 每天手撕一道算法题-130. 被围绕的区域
- TKE上部署metrics-server
- Docker-Compose搭建mysql、redis、zookeeper、rabbitmq、consul、elasticsearch环境
- MDK更改配色方案
- Apache通过多端口配置多站点
- FatFs-目录下文件扫描
- Python之Bilibili自动更新邮件提醒并任务栏图标「完整代码」
- STC15频率产生器(粗调+微调+数码管显示)完整代码
- PID算法原理、调整规律及代码
- GIT——分布式版本控制系统
- 如何在 PHP 中使用和管理 Cookie
- 玩转 PhpStorm 系列(九):代码调试篇(上)