手写三层神经网络完成手写体识别任务
时间:2021-10-11
本文章向大家介绍手写三层神经网络完成手写体识别任务,主要包括手写三层神经网络完成手写体识别任务使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
import numpy as np import numpy.random import scipy.special class NeuralNetwork: # initialise the neural network def __init__(self, input_nodes_num, hidden_nodes_num, output_nodes_num, lr): # 初始化神经元个数,可以直接修改 self.input_nodes = input_nodes_num self.hidden_nodes = hidden_nodes_num self.output_nodes = output_nodes_num self.learning_rate = lr # 初始化权重值,利用正态分布函数进行随机初始化,均值为0,方差为神经元个数开方 self.w_input_hidden = numpy.random.normal(0.0, pow(self.hidden_nodes, -0.5), (self.hidden_nodes, self.input_nodes)) self.w_hidden_output = numpy.random.normal(0.0, pow(self.output_nodes, -0.5), (self.output_nodes, self.hidden_nodes)) # 初始化激活函数,激活函数选用Sigmoid函数,更加平滑,接近自然界的神经元行为模式 # lambda定义了一个匿名函数 self.activation_function = lambda x: scipy.special.expit(x) pass # train the neural network def train(self, inputs_list, targets_list): # 将训练集和测试集中的数据转化为列向量 inputs = np.array(inputs_list, ndmin=2).T targets = np.array(targets_list, ndmin=2).T # 隐藏层的输入为训练集与权重值的点乘,输出为激活函数的输出 hidden_inputs = np.dot(self.w_input_hidden, inputs) hidden_outputs = self.activation_function(hidden_inputs) # 输出层的输入为隐藏层的输出,输出为最终结果 final_inputs = np.dot(self.w_hidden_output, hidden_outputs) final_outputs = self.activation_function(final_inputs) # 损失函数 output_errors = targets - final_outputs # 隐藏层的误差为权值矩阵的转置与输出误差的点乘 hidden_errors = np.dot(self.w_hidden_output.T, output_errors) # 对权值进行更新 self.w_hidden_output += self.learning_rate * np.dot((output_errors * final_outputs * (1.0 - final_outputs)), np.transpose(hidden_outputs)) self.w_input_hidden += self.learning_rate * np.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), np.transpose(inputs)) pass # query the neural network def query(self, inputs_list): # 转置将行向量转成列向量,将每组数据更好的分隔开来,方便后续矩阵点乘操作 inputs = np.array(inputs_list, ndmin=2).T # 加权求和后经过sigmoid函数得到隐藏层输出 hidden_inputs = np.dot(self.w_input_hidden, inputs) hidden_outputs = self.activation_function(hidden_inputs) # 加权求和后经过sigmoid函数得到最终输出 final_inputs = np.dot(self.w_hidden_output, hidden_outputs) final_outputs = self.activation_function(final_inputs) # 得到输出数据列 return final_outputs # 初始化各层神经元个数,期中输入神经元个数取决于读入的因变量,而输出神经元个数取决于分类的可能性个数 input_nodes = 784 hidden_nodes = 500 output_nodes = 10 # 学习率,每次调整步幅大小 learning_rate = 0.2 n = NeuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate) # 获取训练集信息 training_data_file = open('data/mnist_train.csv', 'r') training_data_list = training_data_file.readlines() training_data_file.close() for record in training_data_list: all_values = record.split(',') inputs = (numpy.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01 targets = numpy.zeros(output_nodes) + 0.01 targets[int(all_values[0])] = 0.99 n.train(inputs, targets) pass print('train successful!') test_file = open('data/mnist_test.csv', 'r') test_list = test_file.readlines() test_file.close() m = np.size(test_list) j = 0.0 for record in test_list: test_values = record.split(',') np.asfarray(test_values) results = n.query(np.asfarray(test_values[1:])) if results[int(test_values[0])] == max(results): j += 1 pass print("正确率为;" + str(j/m))
原文地址:https://www.cnblogs.com/sevent/p/15394345.html
- 一条细小的报警短信的处理(r6笔记第96天)
- 1.react的基础知识
- 防火墙设置的小问题(r6笔记第94天)
- 有没有必要把机器学习算法自己实现一遍?
- python中从str中提取元素到list以及将list转换为str
- 简单易学的机器学习算法——线性回归(2)
- Java基础-26(01)总结网络编程
- undo retention的思考(一)
- 优化算法——人工蜂群算法(ABC)
- 用GPU加速深度学习: Windows安装CUDA+TensorFlow教程
- 由报警邮件分析发现的备库oracle bug(r7笔记第12天)
- Python中的__init__()方法整理中(两种解释)
- 如何找到最优学习率?
- 简单易学的机器学习算法——Rosenblatt感知机
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Linux中把用户添加到组的4个方法总结
- Linux下配置jdk环境的方法
- Ubuntu 16.04/18.04 安装Pycharm及Ipython的教程
- linux系统对外开放3306、8080等端口,防火墙设置详解
- Linux中selinux基础配置教程详解
- Linux中如何查看已挂载的文件系统类型详解
- 在 Linux 命令行中使用 tcpdump 抓包的一些功能
- CentOS平台快速搭建LAMP环境的方法
- Linux系统中时间的获取和使用
- 基于Linux搭建Apache网站服务配置详解
- CentOs下手动升级node版本的方法
- 详述Linux中Firewalld高级配置的使用
- CentOS7安装PHP7 Redis扩展的方法步骤
- centos7下rsync+crontab定期同步备份
- 你可能不知道的一些linux文件权限管理方法