详解自动识别验证码,LSTM大显身手
这是去年博主心血来潮实现的一个小模型,现在把它总结一下。由于楼主比较懒,网上许多方法都需要切割图片,但是楼主思索了一下感觉让模型有多个输出就可以了呀,没必要一定要切割的吧?切不好还需要损失信息啊!本文比较简单,只基于传统的验证码。
Part 0 模型概览
从图片到序列实际上就是 Image2text 也就是 seq2seq 的一种。encoder 是 Image, decoder 是验证码序列。由于 keras 不支持传统的在 decoder 部分每个 cell 输出需要作为下一个 rnn 的 cell 的输入 (见下图),所以我们这里把 decoder 部分的输入用 encoder(image)的最后一层复制 N 份作为 decoder 部分的每个 cell 的输入。
典型的 seq2seq
keras 可以直接实现的 image2text
当然利用 recurrentshop 和 seq2seq,我们也可以实现标准的 seq2seq 的网络结构 (后文会写)。
Part I 收集数据
网上还是有一些数据集可以用的,包括 dataCastle 也举办过验证码识别的比赛,都有现成的标注好了的数据集。(然而难点是各种花式验证码啊,填字的,滑动的,还有那个基于语义的 reCaptcha~)。
因为我想弄出各种长度的验证码,所以我还是在 github 上下载了一个生成验证码的 python 包。
下载后,按照例子生成验证码 (包含 26 个小写英文字母):
#!/usr/bin/env python# -*- coding: utf-8from captcha.image import ImageCaptchafrom random import sample
image = ImageCaptcha() #fonts=[ "font/Xenotron.ttf"]characters = list("abcdefghijklmnopqrstuvwxyz")def generate_data(digits_num, output, total):
num = 0
while(num<total):
cur_cap = sample(characters, digits_num)
cur_cap =''.join(cur_cap)
_ = image.generate(cur_cap)
image.write(cur_cap, output+cur_cap+".png")
num += 1generate_data(4, "images/four_digit/", 10000) # 产生四个字符长度的验证码generate_data(5, "images/five_digit/", 10000) #产生五个字符长度的验证码generate_data(6, "images/six_digit/", 10000) #产生六个字符长度的验证码generate_data(7, "images/seven_digit/",10000) # 产生七个字符长度的验证码产生的验证码
(目测了一下生成验证码的包的代码,发现主要是在 x,y 轴上做一些变换,加入一些噪音)
Part II 预处理
由于生成的图片不是相同尺寸的,为了方便训练我们需要转换成相同尺寸的。另外由于验证码长度不同,我们需要在 label 上多加一个符号来表示这个序列的结束。
处理之后的结果就是图像 size 全部为 Height=60, Width=250, Channel=3。label 全部用字符 id 表示,并且末尾加上表示 <EOF> 的 id。比如假设 a-z 的 id 为 0-25,<EOF > 的 id 为 26,那么对于验证码 "abdf" 的 label 也就是 [0,1,3,5,26,26,26,26],"abcdefg" 的 label 为 [0,1,2,3,4,5,6,26]。
由于我们用的是 categorical_crossentropy 来判断每个输出的结果,所以对 label 我们还需要把其变成 one-hot 的形式,那么用 Keras 现成的工具 to_categorical 函数对上面的 label 做一下处理就可以了。比如 abdf 的 label 进一步转换成:
[[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
[0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1],
[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1],
[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1],
[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1]]
Part III 构建模型
不借助外部包可以实现的模型
def create_simpleCnnRnn(image_shape, max_caption_len,vocab_size):
image_model = Sequential() # image_shape : C,W,H
# input: 100x100 images with 3 channels -> (3, 100, 100) tensors.
# this applies 32 convolution filters of size 3x3 each.
image_model.add(Convolution2D(32, 3, 3, border_mode='valid', input_shape=image_shape))
image_model.add(BatchNormalization())
image_model.add(Activation('relu'))
image_model.add(Convolution2D(32, 3, 3))
image_model.add(BatchNormalization())
image_model.add(Activation('relu'))
image_model.add(MaxPooling2D(pool_size=(2, 2)))
image_model.add(Dropout(0.25))
image_model.add(Convolution2D(64, 3, 3, border_mode='valid'))
image_model.add(BatchNormalization())
image_model.add(Activation('relu'))
image_model.add(Convolution2D(64, 3, 3))
image_model.add(BatchNormalization())
image_model.add(Activation('relu'))
image_model.add(MaxPooling2D(pool_size=(2, 2)))
image_model.add(Dropout(0.25))
image_model.add(Flatten()) # Note: Keras does automatic shape inference.
image_model.add(Dense(128))
image_model.add(RepeatVector(max_caption_len)) # 复制 8 份
image_model.add(Bidirectional(GRU(output_dim=128, return_sequences=True)))
image_model.add(TimeDistributed(Dense(vocab_size)))
image_model.add(Activation('softmax'))
sgd = SGD(lr=0.002, decay=1e-6, momentum=0.9, nesterov=True)
image_model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy']) return image_model
借助 recurrentshop 和 seq2seq 可以实现的结构
def create_imgText(image_shape, max_caption_len,vocab_size):
image_model = Sequential() # image_shape : C,W,H
# input: 100x100 images with 3 channels -> (3, 100, 100) tensors.
# this applies 32 convolution filters of size 3x3 each.
image_model.add(Convolution2D(32, 3, 3, border_mode='valid', input_shape=image_shape))
image_model.add(BatchNormalization())
image_model.add(Activation('relu'))
image_model.add(Convolution2D(32, 3, 3))
image_model.add(BatchNormalization())
image_model.add(Activation('relu'))
image_model.add(MaxPooling2D(pool_size=(2, 2)))
image_model.add(Dropout(0.25))
image_model.add(Convolution2D(64, 3, 3, border_mode='valid'))
image_model.add(BatchNormalization())
image_model.add(Activation('relu'))
image_model.add(Convolution2D(64, 3, 3))
image_model.add(BatchNormalization())
image_model.add(Activation('relu'))
image_model.add(MaxPooling2D(pool_size=(2, 2)))
image_model.add(Dropout(0.25))
image_model.add(Flatten()) # Note: Keras does automatic shape inference.
image_model.add(Dense(128))
image_model.add(RepeatVector(1)) # 为了兼容 seq2seq,要多包一个 []
#model = AttentionSeq2Seq(input_dim=128, input_length=1, hidden_dim=128, output_length=max_caption_len, output_dim=128, depth=2)
model = Seq2Seq(input_dim=128, input_length=1, hidden_dim=128, output_length=max_caption_len,output_dim=128, peek=True)
image_model.add(model)
image_model.add(TimeDistributed(Dense(vocab_size)))
image_model.add(Activation('softmax'))
image_model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) return image_model
Part IV 模型训练
之前写过固定长度的验证码的序列准确率可以达到 99%,项目可以参考这里。
另外,我们在用 Keras 训练的时候会有一个 acc,这个 acc 是指的一个字符的准确率,并不是这一串序列的准确率。也就是说在可以预期的情况下,如果你的一个字符的准确率达到了 99%,那么如果你的序列长度是 5 的时候,理论上你的序列准确率是 0.99^5 = 0.95, 如果像我们一样序列长度是 7,则为 0.99^8=0.923。
所以当你要看到实际的验证集上的准确率的时候,应该自己写一个 callback 的类来评测,只有当序列中所有的字符都和 label 一样才可以算正确。
class ValidateAcc(Callback):
def __init__(self, image_model, val_data, val_label, model_output):
self.image_model = image_model
self.val = val_data
self.val_label = val_label
self.model_output = model_output def on_epoch_end(self, epoch, logs={}): # 每个 epoch 结束后会调用该方法
print 'n———————————--------'
self.image_model.load_weights(self.model_output+'weights.%02d.hdf5' % epoch)
r = self.image_model.predict(val, verbose=0)
y_predict = np.asarray([np.argmax(i, axis=1) for i in r])
val_true = np.asarray([np.argmax(i, axis = 1) for i in self.val_label])
length = len(y_predict) * 1.0
correct = 0
for (true,predict) in zip(val_true,y_predict): print true,predict if list(true) == list(predict):
correct += 1
print "Validation set acc is:", correct/length print 'n———————————--------'val_acc_check_pointer = ValidateAcc(image_model,val,val_label,model_output)
记录每个 epoch 的模型结果
check_pointer = ModelCheckpoint(filepath=model_output + "weights.{epoch:02d}.hdf5")
训练
image_model.fit(train, train_label,shuffle=True, batch_size=16, nb_epoch=20, validation_split=0.2, callbacks=[check_pointer, val_acc_check_pointer])
Part V 训练结果
在 39866 张生成的验证码上,27906 张作为训练,11960 张作为验证集。
第一种模型:
序列训练了大约 80 轮,在验证集上最高的准确率为 0.9264, 但是很容易变化比如多跑一轮就可能变成 0.7,主要原因还是因为预测的时候考虑的是整个序列而不是单个字符,只要有一个字符没有预测准确整个序列就是错误的。
第二种模型:
第二个模型也就是上面的 create_imgText,验证集上的最高准确率差不多是 0.9655(当然我没有很仔细的去调参,感觉调的好的话两个模型应该是差不多的,验证集达到 0.96 之后相对稳定)。
Part VI 其它
看起来还是觉得 keras 实现简单的模型会比较容易,稍微变形一点的模型就很纠结了,比较好的是基础的模型用上其他包都可以实现。keras 2.0.x 开始的版本跟 1.0.x 还是有些差异的,而且 recurrentshop 现在也是支持 2.0 版本的。如果在建模型的时候想更 flexible 一点的话,还是用 tensorflow 会比较好,可以调整的东西也比较多,那下一篇可以写一下 img2txt 的 tensorflow 版本。
Part VII 代码
完整源代码:https://github.com/Slyne/CaptchaVariLength
Part VIII 后续
现在的这两个模型还是需要指定最大的长度,后面有时间会在训练集最多只有 8 个字符的情况下,利用 rnn 的最后一层进一步对于有 9 个以及以上字符的验证码效果,看看是不是可以再进一步的扩展到任意长度。(又立了一个 flag~)
- 深入理解Android Build系统
- Mac Jenkins搭建 Android/IOS自动打包环境
- javascript 红皮高程(11)
- javascript 红皮高程(8)
- javascript 红皮高程(7)
- javascript 红皮高程(17)-- 左移(<<)
- javascript 红皮高程(17)-- 按位异或(XOR)
- javascript 红皮高程(17)
- javascript 红皮高程(16)
- javascript 红皮高程(15)
- javascript 红皮高程(21)-- 乘性操作符
- javascript 红皮高程(20)-- 逻辑或
- javascript 红皮高程(19)-- 逻辑与
- 技术分享 | 浅谈 RAS
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- kubernetes(八) kubernetes的使用
- kubernetes(九) kubernetes控制器
- kubernetes(十) kubernetes service,ingress&cm,secret
- kubernetes(十一) 存储& statefulset控制器
- kubernetes(十二) 准入控制和helm v3包管理
- JS Flowchart Diagrams
- kubernetes(十三) k8s 业务上线流程(手动版)
- java+appium+安卓模拟器实现app自动化Demo
- webdriver使用已打开过的chrome
- Shortcodes
- Unexpected EOF 远程主机强迫关闭了一个现有的连接 如何处理
- npm 使用问题
- 接口自动化测试框架-AIM
- hexo 图片显示问题及使用typora设置图片路径
- 接口自动化项目实践