【python实现卷积神经网络】开始训练
代码来源:https://github.com/eriklindernoren/ML-From-Scratch
卷积神经网络中卷积层Conv2D(带stride、padding)的具体实现:https://www.cnblogs.com/xiximayou/p/12706576.html
激活函数的实现(sigmoid、softmax、tanh、relu、leakyrelu、elu、selu、softplus):https://www.cnblogs.com/xiximayou/p/12713081.html
损失函数定义(均方误差、交叉熵损失):https://www.cnblogs.com/xiximayou/p/12713198.html
优化器的实现(SGD、Nesterov、Adagrad、Adadelta、RMSprop、Adam):https://www.cnblogs.com/xiximayou/p/12713594.html
卷积层反向传播过程:https://www.cnblogs.com/xiximayou/p/12713930.html
全连接层实现:https://www.cnblogs.com/xiximayou/p/12720017.html
批量归一化层实现:https://www.cnblogs.com/xiximayou/p/12720211.html
池化层实现:https://www.cnblogs.com/xiximayou/p/12720324.html
padding2D实现:https://www.cnblogs.com/xiximayou/p/12720454.html
Flatten层实现:https://www.cnblogs.com/xiximayou/p/12720518.html
上采样层UpSampling2D实现:https://www.cnblogs.com/xiximayou/p/12720558.html
Dropout层实现:https://www.cnblogs.com/xiximayou/p/12720589.html
激活层实现:https://www.cnblogs.com/xiximayou/p/12720622.html
定义训练和测试过程:https://www.cnblogs.com/xiximayou/p/12725873.html
代码在mlfromscratch/examples/convolutional_neural_network.py 中:
from __future__ import print_function
from sklearn import datasets
import matplotlib.pyplot as plt
import math
import numpy as np
# Import helper functions
from mlfromscratch.deep_learning import NeuralNetwork
from mlfromscratch.utils import train_test_split, to_categorical, normalize
from mlfromscratch.utils import get_random_subsets, shuffle_data, Plot
from mlfromscratch.utils.data_operation import accuracy_score
from mlfromscratch.deep_learning.optimizers import StochasticGradientDescent, Adam, RMSprop, Adagrad, Adadelta
from mlfromscratch.deep_learning.loss_functions import CrossEntropy
from mlfromscratch.utils.misc import bar_widgets
from mlfromscratch.deep_learning.layers import Dense, Dropout, Conv2D, Flatten, Activation, MaxPooling2D
from mlfromscratch.deep_learning.layers import AveragePooling2D, ZeroPadding2D, BatchNormalization, RNN
def main():
#----------
# Conv Net
#----------
optimizer = Adam()
data = datasets.load_digits()
X = data.data
y = data.target
# Convert to one-hot encoding
y = to_categorical(y.astype("int"))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, seed=1)
# Reshape X to (n_samples, channels, height, width)
X_train = X_train.reshape((-1,1,8,8))
X_test = X_test.reshape((-1,1,8,8))
clf = NeuralNetwork(optimizer=optimizer,
loss=CrossEntropy,
validation_data=(X_test, y_test))
clf.add(Conv2D(n_filters=16, filter_shape=(3,3), stride=1, input_shape=(1,8,8), padding='same'))
clf.add(Activation('relu'))
clf.add(Dropout(0.25))
clf.add(BatchNormalization())
clf.add(Conv2D(n_filters=32, filter_shape=(3,3), stride=1, padding='same'))
clf.add(Activation('relu'))
clf.add(Dropout(0.25))
clf.add(BatchNormalization())
clf.add(Flatten())
clf.add(Dense(256))
clf.add(Activation('relu'))
clf.add(Dropout(0.4))
clf.add(BatchNormalization())
clf.add(Dense(10))
clf.add(Activation('softmax'))
print ()
clf.summary(name="ConvNet")
train_err, val_err = clf.fit(X_train, y_train, n_epochs=50, batch_size=256)
# Training and validation error plot
n = len(train_err)
training, = plt.plot(range(n), train_err, label="Training Error")
validation, = plt.plot(range(n), val_err, label="Validation Error")
plt.legend(handles=[training, validation])
plt.title("Error Plot")
plt.ylabel('Error')
plt.xlabel('Iterations')
plt.show()
_, accuracy = clf.test_on_batch(X_test, y_test)
print ("Accuracy:", accuracy)
y_pred = np.argmax(clf.predict(X_test), axis=1)
X_test = X_test.reshape(-1, 8*8)
# Reduce dimension to 2D using PCA and plot the results
Plot().plot_in_2d(X_test, y_pred, title="Convolutional Neural Network", accuracy=accuracy, legend_labels=range(10))
if __name__ == "__main__":
main()
我们还是一步步进行分析:
1、优化器使用Adam()
2、数据集使用的是sklearn.datasets中的手写数字,其部分数据如下:
(1797, 64)
(1797,)
[[ 0. 0. 5. 13. 9. 1. 0. 0. 0. 0. 13. 15. 10. 15. 5. 0. 0. 3.
15. 2. 0. 11. 8. 0. 0. 4. 12. 0. 0. 8. 8. 0. 0. 5. 8. 0.
0. 9. 8. 0. 0. 4. 11. 0. 1. 12. 7. 0. 0. 2. 14. 5. 10. 12.
0. 0. 0. 0. 6. 13. 10. 0. 0. 0.]
[ 0. 0. 0. 12. 13. 5. 0. 0. 0. 0. 0. 11. 16. 9. 0. 0. 0. 0.
3. 15. 16. 6. 0. 0. 0. 7. 15. 16. 16. 2. 0. 0. 0. 0. 1. 16.
16. 3. 0. 0. 0. 0. 1. 16. 16. 6. 0. 0. 0. 0. 1. 16. 16. 6.
0. 0. 0. 0. 0. 11. 16. 10. 0. 0.]
[ 0. 0. 0. 4. 15. 12. 0. 0. 0. 0. 3. 16. 15. 14. 0. 0. 0. 0.
8. 13. 8. 16. 0. 0. 0. 0. 1. 6. 15. 11. 0. 0. 0. 1. 8. 13.
15. 1. 0. 0. 0. 9. 16. 16. 5. 0. 0. 0. 0. 3. 13. 16. 16. 11.
5. 0. 0. 0. 0. 3. 11. 16. 9. 0.]
[ 0. 0. 7. 15. 13. 1. 0. 0. 0. 8. 13. 6. 15. 4. 0. 0. 0. 2.
1. 13. 13. 0. 0. 0. 0. 0. 2. 15. 11. 1. 0. 0. 0. 0. 0. 1.
12. 12. 1. 0. 0. 0. 0. 0. 1. 10. 8. 0. 0. 0. 8. 4. 5. 14.
9. 0. 0. 0. 7. 13. 13. 9. 0. 0.]
[ 0. 0. 0. 1. 11. 0. 0. 0. 0. 0. 0. 7. 8. 0. 0. 0. 0. 0.
1. 13. 6. 2. 2. 0. 0. 0. 7. 15. 0. 9. 8. 0. 0. 5. 16. 10.
0. 16. 6. 0. 0. 4. 15. 16. 13. 16. 1. 0. 0. 0. 0. 3. 15. 10.
0. 0. 0. 0. 0. 2. 16. 4. 0. 0.]
[ 0. 0. 12. 10. 0. 0. 0. 0. 0. 0. 14. 16. 16. 14. 0. 0. 0. 0.
13. 16. 15. 10. 1. 0. 0. 0. 11. 16. 16. 7. 0. 0. 0. 0. 0. 4.
7. 16. 7. 0. 0. 0. 0. 0. 4. 16. 9. 0. 0. 0. 5. 4. 12. 16.
4. 0. 0. 0. 9. 16. 16. 10. 0. 0.]
[ 0. 0. 0. 12. 13. 0. 0. 0. 0. 0. 5. 16. 8. 0. 0. 0. 0. 0.
13. 16. 3. 0. 0. 0. 0. 0. 14. 13. 0. 0. 0. 0. 0. 0. 15. 12.
7. 2. 0. 0. 0. 0. 13. 16. 13. 16. 3. 0. 0. 0. 7. 16. 11. 15.
8. 0. 0. 0. 1. 9. 15. 11. 3. 0.]
[ 0. 0. 7. 8. 13. 16. 15. 1. 0. 0. 7. 7. 4. 11. 12. 0. 0. 0.
0. 0. 8. 13. 1. 0. 0. 4. 8. 8. 15. 15. 6. 0. 0. 2. 11. 15.
15. 4. 0. 0. 0. 0. 0. 16. 5. 0. 0. 0. 0. 0. 9. 15. 1. 0.
0. 0. 0. 0. 13. 5. 0. 0. 0. 0.]
[ 0. 0. 9. 14. 8. 1. 0. 0. 0. 0. 12. 14. 14. 12. 0. 0. 0. 0.
9. 10. 0. 15. 4. 0. 0. 0. 3. 16. 12. 14. 2. 0. 0. 0. 4. 16.
16. 2. 0. 0. 0. 3. 16. 8. 10. 13. 2. 0. 0. 1. 15. 1. 3. 16.
8. 0. 0. 0. 11. 16. 15. 11. 1. 0.]
[ 0. 0. 11. 12. 0. 0. 0. 0. 0. 2. 16. 16. 16. 13. 0. 0. 0. 3.
16. 12. 10. 14. 0. 0. 0. 1. 16. 1. 12. 15. 0. 0. 0. 0. 13. 16.
9. 15. 2. 0. 0. 0. 0. 3. 0. 9. 11. 0. 0. 0. 0. 0. 9. 15.
4. 0. 0. 0. 9. 12. 13. 3. 0. 0.]]
[0 1 2 3 4 5 6 7 8 9]
3、接着有一个to_categorical()函数,在mlfromscratch.utils下的data_manipulation.py中:
def to_categorical(x, n_col=None):
""" One-hot encoding of nominal values """
if not n_col:
n_col = np.amax(x) + 1
one_hot = np.zeros((x.shape[0], n_col))
one_hot[np.arange(x.shape[0]), x] = 1
return one_hot
用于将标签转换为one-hot编码。
4、划分训练集和测试集:train_test_split(),在mlfromscratch.utils下的data_manipulation.py中:
def train_test_split(X, y, test_size=0.5, shuffle=True, seed=None):
""" Split the data into train and test sets """
if shuffle:
X, y = shuffle_data(X, y, seed)
# Split the training data from test data in the ratio specified in
# test_size
split_i = len(y) - int(len(y) // (1 / test_size))
X_train, X_test = X[:split_i], X[split_i:]
y_train, y_test = y[:split_i], y[split_i:]
return X_train, X_test, y_train, y_test
5、由于卷积神经网络的输入是[batchsize,channel,wheight,width]的维度,因此要将原始数据进行转换,即将(1797,64)转换为(1797,1,8,8)格式的数据。这里batchsize就是样本的数量。
6、定义卷积神经网络的训练和测试过程:包括优化器、损失函数、测试数据
7、定义模型结构
8、输出模型每层的类型、参数数量以及输出大小
9、将数据输入到模型中,设置epochs的大小以及batch_size的大小
10、计算训练和测试的错误,并绘制成图
11、计算准确率
12、绘制测试集中每一类预测的结果,这里有一个plot_in_2d()函数,位于mlfromscratch.utils下的misc.py中
# Plot the dataset X and the corresponding labels y in 2D using PCA.
def plot_in_2d(self, X, y=None, title=None, accuracy=None, legend_labels=None):
X_transformed = self._transform(X, dim=2)
x1 = X_transformed[:, 0]
x2 = X_transformed[:, 1]
class_distr = []
y = np.array(y).astype(int)
colors = [self.cmap(i) for i in np.linspace(0, 1, len(np.unique(y)))]
# Plot the different class distributions
for i, l in enumerate(np.unique(y)):
_x1 = x1[y == l]
_x2 = x2[y == l]
_y = y[y == l]
class_distr.append(plt.scatter(_x1, _x2, color=colors[i]))
# Plot legend
if not legend_labels is None:
plt.legend(class_distr, legend_labels, loc=1)
# Plot title
if title:
if accuracy:
perc = 100 * accuracy
plt.suptitle(title)
plt.title("Accuracy: %.1f%%" % perc, fontsize=10)
else:
plt.title(title)
# Axis labels
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.show()
接下来就可以实际进行操作了,我是在谷歌colab中,首先使用:
!git clone https://github.com/eriklindernoren/ML-From-Scratch.git
将相关代码复制下来。
然后进行安装:在ML-From-Scratch目录下输入:
!python setup.py install
最后输入:
!python mlfromscratch/examples/convolutional_neural_network.py
最终结果:
+---------+
| ConvNet |
+---------+
Input Shape: (1, 8, 8)
+----------------------+------------+--------------+
| Layer Type | Parameters | Output Shape |
+----------------------+------------+--------------+
| Conv2D | 160 | (16, 8, 8) |
| Activation (ReLU) | 0 | (16, 8, 8) |
| Dropout | 0 | (16, 8, 8) |
| BatchNormalization | 2048 | (16, 8, 8) |
| Conv2D | 4640 | (32, 8, 8) |
| Activation (ReLU) | 0 | (32, 8, 8) |
| Dropout | 0 | (32, 8, 8) |
| BatchNormalization | 4096 | (32, 8, 8) |
| Flatten | 0 | (2048,) |
| Dense | 524544 | (256,) |
| Activation (ReLU) | 0 | (256,) |
| Dropout | 0 | (256,) |
| BatchNormalization | 512 | (256,) |
| Dense | 2570 | (10,) |
| Activation (Softmax) | 0 | (10,) |
+----------------------+------------+--------------+
Total Parameters: 538570
Training: 100% [------------------------------------------------] Time: 0:01:32
<Figure size 640x480 with 1 Axes>
Accuracy: 0.9846796657381616
<Figure size 640x480 with 1 Axes>
至此,结合代码一步一步看卷积神经网络的整个实现过程就完成了。通过结合代码的形式,可以加深对深度学习中卷积神经网络相关知识的理解。
- [认证授权] 3.基于OAuth2的认证(译)
- [Asp.Net Core] 1. IIS中的 Asp.Net Core 和 dotnet watch
- kafka数据迁移实践
- HDFS 2.x 磁盘间数据均衡的一种可行办法
- Batik渲染png图片异常的bug修复全程记录
- Web应用服务器安全:攻击、防护与检测
- 基于Go Packet实现网络数据包的捕获与分析
- 动态追踪技术(四):基于 Linux bcc/BPF 实现 Go 程序动态追踪
- Hive 时间转换函数使用心得
- Flume-Hbase-Sink针对不同版本flume与HBase的适配研究与经验总结
- 利用Flume 汇入数据到HBase:Flume-hbase-sink 使用方法详解
- 浅谈保证软件工程质量的一些心得体会
- 基于ELK的nginx-qps监控解决方案
- 2017年年度最烂密码排名
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- MySQL 案例:大表改列的新技巧(Generated Column)
- Spark 2.2 on K8S Dynamic Resource Allocation
- Java中异常处理的9个最佳实践
- Mongodb执行计划
- Spark 2.2/2.3/2.4 的 Dynamic Resource Allocation
- 04-操作文件与目录
- 05-命令的使用
- 缓存穿透、缓存击穿、缓存雪崩看这篇就够了,文末还送福利哦!
- 百万并发「零拷贝」技术系列之经典案例Netty
- 什么是大O表示法
- 06-1重定向
- sbt 项目导入问题
- Spark 程序优化建议
- Spark persist MEMORY_AND_DISK & DISK_ONLY
- jvm类加载机制,双亲委派机制,看这一篇就够了