Loading weights in Keras
The Keras source file engine/topology.py defines the function for loading weights:
load_weights(self, filepath, by_name=False)
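By default, by_name is False, and weights are loaded according to the network's topology. This suits the models that ship with Keras, such as VGG16, VGG19, and ResNet50, when they are used unmodified. The source describes it as follows: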
If `by_name` is False (default) weights are loaded based on the network’s topology, meaning the architecture should be the same as when the weights were saved. Note that layers that don’t have weights are not taken into account in the topological ordering, so adding or removing layers is fine as long as they don’t have weights.
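To make the topological rule concrete, here is a minimal pure-Python sketch of the matching logic. It is a simplified model, not Keras's actual implementation: the function name match_by_topology and the (name, weights) layer representation are assumptions made purely for illustration.

```python
def match_by_topology(saved_layers, model_layers):
    """Pair saved and target layers in order, ignoring layers without weights.

    Each layer is a hypothetical (name, weights) tuple, where weights is a
    list of values (empty for weightless layers such as pooling or dropout).
    """
    saved = [(name, w) for name, w in saved_layers if w]
    target = [(name, w) for name, w in model_layers if w]
    if len(saved) != len(target):
        raise ValueError(
            "model has %d weighted layers, file has %d" % (len(target), len(saved))
        )
    # Names play no role here: only the order of the weighted layers matters.
    return [(t_name, s_w) for (t_name, _), (_, s_w) in zip(target, saved)]


saved = [("conv1", [1.0]), ("pool1", []), ("fc1", [2.0])]
# Layers were renamed and a weightless layer was added: matching still works.
model = [("c1", [0.0]), ("pool", []), ("drop", []), ("dense", [0.0])]
pairs = match_by_topology(saved, model)
print(pairs)  # [('c1', [1.0]), ('dense', [2.0])]
```

Adding a layer that does carry weights would change the count of weighted layers, and the sketch would raise an error; that is the situation in which loading by topology stops being suitable.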
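If by_name is set to True, weights are loaded by layer name: a layer receives weights only when its name matches a layer stored in the weights file. This suits cases where you have changed part of the model's structure or added layers but still reuse the backbone of the original network. The source describes it as follows: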
If `by_name` is True, weights are loaded into layers only if they share the same name. This is useful for fine-tuning or transfer-learning models where some of the layers have changed.
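For example, an edge detection model might reuse the backbone of a VGG network and add deconvolution layers to it. In that case the weights should be loaded with: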
model.load_weights(filepath,by_name=True)
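The name-based rule can be sketched in a few lines of plain Python. This is a simplified model of the behavior, not Keras's own code: the function match_by_name, the (name, weights) tuples, and the layer name deconv1 are hypothetical and chosen only to mirror the edge detection scenario.

```python
def match_by_name(saved_layers, model_layers):
    """Return (loaded, skipped) for a name-based weight lookup.

    Each layer is a hypothetical (name, weights) tuple; weights is an empty
    list for layers that have no parameters.
    """
    saved = {name: w for name, w in saved_layers if w}
    loaded, skipped = {}, []
    for name, w in model_layers:
        if not w:
            continue  # weightless layers need nothing loaded
        if name in saved:
            loaded[name] = saved[name]  # same name: take the saved weights
        else:
            skipped.append(name)  # new layer: keeps its initial weights
    return loaded, skipped


# Saved weights from a VGG-like network.
saved = [("block1_conv1", [1.0]), ("block1_pool", []), ("fc1", [2.0])]
# New model: same backbone, classifier head replaced by a deconvolution layer.
model = [("block1_conv1", [0.0]), ("block1_pool", []), ("deconv1", [0.0])]
loaded, skipped = match_by_name(saved, model)
print(loaded)   # {'block1_conv1': [1.0]}
print(skipped)  # ['deconv1']
```

In this model, layers that exist only in the file (fc1 here) are simply ignored, and layers that exist only in the new network keep whatever initial weights they had.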
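Additional notes: MNIST handwritten digit recognition in Keras
I had been using TensorFlow until a classmate recommended Keras, so I took the MNIST handwritten digit dataset from my earlier notes to practice with. The code is as follows.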
import os
import struct

import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD


def load_mnist(path, kind):
    """Load MNIST images and labels from the IDX files under `path`."""
    labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
    with open(labels_path, 'rb') as lbpath:
        # IDX headers are big-endian, hence the '>' prefix
        magic, n = struct.unpack('>II', lbpath.read(8))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)  # 28*28=784
    return images, labels


# load training and test data
X_train, Y_train = load_mnist('.data', kind='train')
X_test, Y_test = load_mnist('.data', kind='t10k')

# convert labels to one-hot encoding
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)

# define the model (units/kernel_initializer are the Keras 2 names for output_dim/init)
model = Sequential()
model.add(Dense(units=50, input_dim=X_train.shape[1], kernel_initializer='uniform', activation='tanh'))
model.add(Dense(units=50, kernel_initializer='uniform', activation='tanh'))
model.add(Dense(units=Y_train_ohe.shape[1], kernel_initializer='uniform', activation='softmax'))
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])

# start training
model.fit(X_train, Y_train_ohe, epochs=50, batch_size=300, shuffle=True, verbose=1, validation_split=0.3)

# compute accuracy; argmax over the softmax output gives the predicted class
y_train_pred = np.argmax(model.predict(X_train, verbose=0), axis=1)
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0]
print('Training accuracy: %.2f%%' % (train_acc * 100))
y_test_pred = np.argmax(model.predict(X_test, verbose=0), axis=1)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0]
print('Test accuracy: %.2f%%' % (test_acc * 100))
The training output is as follows:
Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%
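The load_mnist function in the code above relies on the IDX file layout: a big-endian header (a magic number followed by the dimensions) and then the raw unsigned bytes. The sketch below parses that layout from small synthetic in-memory buffers, so it runs without the dataset files. The helper names read_idx_labels and read_idx_images are assumptions for illustration, and the magic-number checks are an addition not present in the article's code.

```python
import io
import struct

import numpy as np


def read_idx_labels(stream):
    """Read an IDX1 label file: 8-byte big-endian header, then uint8 labels."""
    magic, n = struct.unpack(">II", stream.read(8))
    if magic != 2049:  # 0x00000801 marks an MNIST label file
        raise ValueError("unexpected magic number %d" % magic)
    return np.frombuffer(stream.read(n), dtype=np.uint8)


def read_idx_images(stream):
    """Read an IDX3 image file: 16-byte big-endian header, then uint8 pixels."""
    magic, num, rows, cols = struct.unpack(">IIII", stream.read(16))
    if magic != 2051:  # 0x00000803 marks an MNIST image file
        raise ValueError("unexpected magic number %d" % magic)
    data = np.frombuffer(stream.read(num * rows * cols), dtype=np.uint8)
    return data.reshape(num, rows * cols)  # one flattened image per row


# Synthetic stand-ins for the real files: 3 labels and 3 images of 2x2 pixels.
label_bytes = struct.pack(">II", 2049, 3) + bytes([7, 2, 1])
image_bytes = struct.pack(">IIII", 2051, 3, 2, 2) + bytes(range(12))

labels = read_idx_labels(io.BytesIO(label_bytes))
images = read_idx_images(io.BytesIO(image_bytes))
print(labels.tolist())  # [7, 2, 1]
print(images.shape)     # (3, 4)
```

With the real MNIST files the same header logic yields 28 rows and 28 columns, which is where the 784 in the reshape call comes from.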
That is everything I have to share about loading weights in Keras. I hope it serves as a useful reference.