Keras-多输入多输出实例(多任务)
时间:2022-07-27
本文章向大家介绍Keras-多输入多输出实例(多任务),主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
1、模型结果设计
2、代码
from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd
samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
y_1.append(np.sum(x11) + np.sum(x22))
y_2.append(np.max([np.max(x11), np.max(x22)]))
y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)
# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
output_02,
output_03
])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
# 0.8,
# 0.8
# ])
# # 训练
# model.fit([x1, x2], [y_1,
# y_2,
# y_3
# ], epochs=50, batch_size=32, validation_split=0.1)
# 以下的方法可灵活设置
model.compile(optimizer='adam',
loss={'output011': 'mean_squared_error',
'output02': 'mean_squared_error',
'output03': 'mean_squared_error'},
loss_weights={'output011': 1,
'output02': 0.8,
'output03': 0.8})
model.fit({'input_1': x1,
'input_2': x2},
{'output011': y_1,
'output02': y_2,
'output03': y_3},
epochs=50, batch_size=32, validation_split=0.1)
# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))
补充知识:Keras多输出(多任务)如何设置fit_generator
在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出
# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
optimizer=optimizers.Adam(lr=learning_rate),
loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
loss_weights={"main": 0.5, "auxiliary": 0.5},
metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
train_gen, epochs=num_epochs, verbose=0, shuffle=True
)
看Keras官方文档:
generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either a tuple (inputs, targets) a tuple (inputs, targets, sample_weights).
Keras设计多输出(多任务)使用fit_generator的步骤如下:
根据官方文档,定义一个generator或者一个class继承Sequence
class Batch_generator(Sequence):
"""
用于产生batch_1, batch_2(记住是numpy.array格式转换)
"""
y_batch = {'main':batch_1,'auxiliary':batch_2}
return X_batch, y_batch
# or in another way
def batch_generator():
"""
用于产生batch_1, batch_2(记住是numpy.array格式转换)
"""
yield X_batch, {'main': batch_1,'auxiliary':batch_2}
重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):
如果是多输出(多任务)的时候,这里的target是字典类型
如果是多输出(多任务)的时候,这里的target是字典类型
如果是多输出(多任务)的时候,这里的target是字典类型
以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考。
- mongoDB报错Cannot find module '../build/Release/bson'
- 计算机视觉处理三大任务:分类、定位和检测
- Windows下RabbitMQ安装及入门
- 计算机视觉任务:图像梯度和图像完成
- Yarn【label-based scheduling】实战总结(一)
- 配置sonarqube+maven
- Yarn【label-based scheduling】实战总结(二)
- HDFS学习:HDFS机架感知与副本放置策略
- spring cloud 报错Error creating bean with name 'hystrixCommandAspect' ,解决方案
- Spring Security OAuth2 Demo
- SpringBoot学习:整合shiro(身份认证和权限认证),使用EhCache缓存
- 线性回归与最小二乘法 | 机器学习笔记
- 添加sqljdbc4的maven依赖
- MyBatis 实现关联表查询
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- PHP递归统计系统中代码行数
- PHP切割整数工具类似微信红包金额分配的思路详解
- php写入文件不覆盖的实例讲解
- php解决crontab定时任务不能写入文件问题的方法分析
- Laravel项目中timeAgo字段语言转换的改善方法示例
- php生成微信红包数组的方法
- 解决php写入数据库乱码的问题
- php写入txt乱码的解决方法
- PHP实现的AES 128位加密算法示例
- php写入mysql中文乱码的实例解决方法
- php实现的支付宝网页支付功能示例【基于TP5框架】
- php校验公钥是否可用的实例方法
- PHP实现的微信APP支付功能示例【基于TP5框架】
- php创建多级目录与级联删除文件的方法示例
- Linux VPS定时备份服务器/网站数据到Github私人仓库