Variable-Length Image Captcha Training

Training variable-length image captcha recognition with LSTM and CTC loss (the network in the script below actually uses bidirectional GRUs).
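CTC lets the network emit one prediction per image column (time step) without any character-level alignment: repeated predictions are merged and a special blank class fills the gaps, which is what makes variable-length labels possible. As a toy illustration of the best-path decoding rule that decode_batch applies later (the alphabet and per-frame outputs here are made up):

import itertools

letters = 'abc'                      # toy alphabet; index 3 == len(letters) is the CTC blank
frames = [0, 0, 3, 1, 1, 3, 1, 2]    # made-up per-frame argmax indices: "aa-bb-bc"

collapsed = [k for k, _ in itertools.groupby(frames)]   # merge repeats -> [0, 3, 1, 3, 1, 2]
decoded = ''.join(letters[c] for c in collapsed if c < len(letters))
print(decoded)                       # -> "abbc": repeats merged, blanks dropped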
GitHub project: https://github.com/JansonJo/captcha_ocr.git
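The script reads every path and hyperparameter from ./config/config_demo.yaml. The repository ships its own demo config; purely as an illustration of the keys the code expects, the following sketch writes a compatible file. All values here are placeholder assumptions, not the project's defaults — adjust the paths, image size, and the filename regex to your own dataset.

# Sketch of the structure ./config/config_demo.yaml needs, written out via yaml.dump.
# Every key below is read by the training script; the VALUES are placeholder assumptions.
import os
import yaml

example_cfg = {
    'System': {
        'GpuMemoryFraction': 0.7,          # only used if the commented-out GPU line is enabled
        'TrainSetPath': './data/train/',   # assumed directory of labelled training images
        'TestSetPath': './data/test/',     # this script uses it for both validation and test
        'MaxTextLenth': 8,                 # (sic) key spelling matches the script
        'IMG_W': 200,
        'IMG_H': 64,
        'ModelName': 'captcha_model.h5',
        'LabelRegex': '^([0-9A-Za-z]+)_',  # group(1) must capture the label from the filename
        'Alphabet': '',                    # empty -> digits + upper- and lower-case letters
    },
    'NeuralNet': {
        'RNNSize': 128,
        'Dropout': 0.25,
    },
    'TrainParam': {
        'EarlyStoping': {                  # (sic) key spelling matches the script
            'monitor': 'val_loss',
            'patience': 10,
            'mode': 'min',
            'baseline': None,
        },
        'Epochs': 100,
        'BatchSize': 32,
        'TestBatchSize': 8,
        'TestSetNum': 1000,
    },
}

os.makedirs('./config', exist_ok=True)
with open('./config/config_demo.yaml', 'w', encoding='utf-8') as out:
    yaml.dump(example_cfg, out, default_flow_style=False)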

# coding=utf-8
"""
Convert three-channel (BGR) images to grayscale for training.
"""
import itertools
import os
import re
import random
import string
from collections import Counter
from os.path import join
import yaml
import cv2
import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
from keras.layers import Input, Dense, Activation, Dropout, BatchNormalization, Reshape, Lambda
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.merge import add, concatenate
from keras.layers.recurrent import GRU
from keras.models import Model, load_model

f = open('./config/config_demo.yaml', 'r', encoding='utf-8')
cfg = f.read()
cfg_dict = yaml.load(cfg)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# config.gpu_options.per_process_gpu_memory_fraction = cfg_dict['System']['GpuMemoryFraction']
session = tf.Session(config=config)
K.set_session(session)

# System config
TRAIN_SET_PATH = cfg_dict['System']['TrainSetPath']
VALID_SET_PATH = cfg_dict['System']['TestSetPath']
TEST_SET_PATH = cfg_dict['System']['TestSetPath']
MAX_TEXT_LEN = cfg_dict['System']['MaxTextLenth']
IMG_W = cfg_dict['System']['IMG_W']
IMG_H = cfg_dict['System']['IMG_H']
MODEL_NAME = cfg_dict['System']['ModelName']
LABEL_REGEX = cfg_dict['System']['LabelRegex']
ALPHABET = cfg_dict['System']['Alphabet']

# NeuralNet config
RNN_SIZE = cfg_dict['NeuralNet']['RNNSize']
DROPOUT = cfg_dict['NeuralNet']['Dropout']

# TrainParam config
MONITOR = cfg_dict['TrainParam']['EarlyStoping']['monitor']
PATIENCE = cfg_dict['TrainParam']['EarlyStoping']['patience']
MODE = cfg_dict['TrainParam']['EarlyStoping']['mode']
BASELINE = cfg_dict['TrainParam']['EarlyStoping']['baseline']
EPOCHS = cfg_dict['TrainParam']['Epochs']
BATCH_SIZE = cfg_dict['TrainParam']['BatchSize']
TEST_BATCH_SIZE = cfg_dict['TrainParam']['TestBatchSize']
TEST_SET_NUM = cfg_dict['TrainParam']['TestSetNum']


def get_counter(dirpath):
    letters = ''
    lens = []
    for root, dirs, files in os.walk(dirpath):
        for filename in files:
            m = re.search(LABEL_REGEX, filename, re.M | re.I)
            description = m.group(1)
            lens.append(len(description))
            letters += description
    print('Max plate length in "%s":' % dirpath, max(Counter(lens).keys()))
    return Counter(letters)


c_val = get_counter(VALID_SET_PATH)
c_train = get_counter(TRAIN_SET_PATH)
letters_train = set(c_train.keys())
letters_val = set(c_val.keys())
print('letters_train: %s' % ''.join(sorted(letters_train)))
print('letters_val: %s' % ''.join(sorted(letters_val)))
if letters_train == letters_val:
    print('Letters in train and val do match')
else:
    raise Exception('Letters in train and val don\'t match')
# print(len(letters_train), len(letters_val), len(letters_val | letters_train))
# letters = sorted(list(letters_train))
# letters = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
letters = ALPHABET
if len(letters) == 0:
    letters = string.digits + string.ascii_uppercase + string.ascii_lowercase
class_num = len(letters) + 1   # plus 1 for blank
print('Alphabet Letters:', ''.join(letters))


# Input data generator
def labels_to_text(labels):
    # return ''.join(list(map(lambda x: letters[int(x)], labels)))
    return ''.join([letters[int(x)] if int(x) != len(letters) else '' for x in labels])


def text_to_labels(text):
    # return list(map(lambda x: letters.index(x), text))
    return [letters.find(x) if letters.find(x) > -1 else len(letters) for x in text]


def is_valid_str(s):
    for ch in s:
        if not ch in letters:
            return False
    return True


class TextImageGenerator:

    def __init__(self,
                 dirpath,
                 tag,
                 img_w, img_h,
                 batch_size,
                 downsample_factor,
                 max_text_len=MAX_TEXT_LEN):
        self.img_h = img_h
        self.img_w = img_w
        self.batch_size = batch_size
        self.max_text_len = max_text_len
        self.downsample_factor = downsample_factor
        img_dirpath = dirpath
        self.samples = []
        for filename in os.listdir(img_dirpath):
            name, ext = os.path.splitext(filename)
            if ext in ['.png', '.jpg']:
                img_filepath = join(img_dirpath, filename)
                m = re.search(LABEL_REGEX, filename, re.M | re.I)
                description = m.group(1)
                if len(description) < MAX_TEXT_LEN:
                    description = description + '_' * (MAX_TEXT_LEN - len(description))
                # if is_valid_str(description):
                #     self.samples.append([img_filepath, description])
                self.samples.append([img_filepath, description])

        self.n = len(self.samples)
        self.indexes = list(range(self.n))
        self.cur_index = 0

        # build data:
        self.imgs = np.zeros((self.n, self.img_h, self.img_w))
        self.texts = []
        for i, (img_filepath, text) in enumerate(self.samples):
            img = cv2.imread(img_filepath)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)     # cv2 loads images in BGR order by default
            img = cv2.resize(img, (self.img_w, self.img_h))
            img = img.astype(np.float32)
            img /= 255
            # width and height are backwards from typical Keras convention
            # because width is the time dimension when it gets fed into the RNN
            self.imgs[i, :, :] = img
            self.texts.append(text)

    @staticmethod
    def get_output_size():
        return len(letters) + 1

    def next_sample(self):
        self.cur_index += 1
        if self.cur_index >= self.n:
            self.cur_index = 0
            random.shuffle(self.indexes)
        return self.imgs[self.indexes[self.cur_index]], self.texts[self.indexes[self.cur_index]]

    def next_batch(self):
        while True:
            # width and height are backwards from typical Keras convention
            # because width is the time dimension when it gets fed into the RNN
            if K.image_data_format() == 'channels_first':
                X_data = np.ones([self.batch_size, 1, self.img_w, self.img_h])
            else:
                X_data = np.ones([self.batch_size, self.img_w, self.img_h, 1])
            Y_data = np.ones([self.batch_size, self.max_text_len])
            input_length = np.ones((self.batch_size, 1)) * (self.img_w // self.downsample_factor - 2)
            label_length = np.zeros((self.batch_size, 1))
            source_str = []

            for i in range(self.batch_size):
                img, text = self.next_sample()
                img = img.T
                if K.image_data_format() == 'channels_first':
                    img = np.expand_dims(img, 0)
                else:
                    img = np.expand_dims(img, -1)
                X_data[i] = img
                Y_data[i] = text_to_labels(text)
                source_str.append(text)
                text = text.replace("_", "")  # important step
                label_length[i] = len(text)

            inputs = {
                'the_input': X_data,
                'the_labels': Y_data,
                'input_length': input_length,
                'label_length': label_length,
                # 'source_str': source_str
            }
            outputs = {'ctc': np.zeros([self.batch_size])}
            yield (inputs, outputs)


tiger = TextImageGenerator(VALID_SET_PATH, 'val', IMG_W, IMG_H, 8, 4)

for inp, out in tiger.next_batch():
    print('Text generator output (data which will be fed into the neural network):')
    print('1) the_input (image)')
    if K.image_data_format() == 'channels_first':
        img = inp['the_input'][0, 0, :, :]
    else:
        img = inp['the_input'][0, :, :, 0]
    # plt.imshow(img.T, cmap='gray')
    # plt.show()
    print('2) the_labels (plate number): %s is encoded as %s' %
          (labels_to_text(inp['the_labels'][0]), list(map(int, inp['the_labels'][0]))))
    print('3) input_length (width of image that is fed to the loss function): %d == %d / 4 - 2' %
          (inp['input_length'][0], tiger.img_w))
    print('4) label_length (length of plate number): %d' % inp['label_length'][0])
    break


# Loss and train functions, network architecture
def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args
    # the 2 is critical here since the first couple outputs of the RNN
    # tend to be garbage:
    y_pred = y_pred[:, 2:, :]
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)


downsample_factor = 4


def train(img_w=IMG_W, img_h=IMG_H, dropout=DROPOUT, batch_size=BATCH_SIZE, rnn_size=RNN_SIZE):
    # Input Parameters
    # Network parameters
    conv_filters = 16
    kernel_size = (3, 3)
    pool_size = 2
    time_dense_size = 32

    if K.image_data_format() == 'channels_first':
        input_shape = (1, img_w, img_h)
    else:
        input_shape = (img_w, img_h, 1)

    global downsample_factor
    downsample_factor = pool_size ** 2
    tiger_train = TextImageGenerator(TRAIN_SET_PATH, 'train', img_w, img_h, batch_size, downsample_factor)
    tiger_val = TextImageGenerator(VALID_SET_PATH, 'val', img_w, img_h, batch_size, downsample_factor)

    act = 'relu'
    input_data = Input(name='the_input', shape=input_shape, dtype='float32')
    inner = Conv2D(conv_filters, kernel_size, padding='same',
                   activation=None, kernel_initializer='he_normal',
                   name='conv1')(input_data)
    inner = BatchNormalization()(inner)  # add BN
    inner = Activation(act)(inner)
    inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max1')(inner)
    inner = Conv2D(conv_filters, kernel_size, padding='same',
                   activation=None, kernel_initializer='he_normal',
                   name='conv2')(inner)
    inner = BatchNormalization()(inner)  # add BN
    inner = Activation(act)(inner)
    inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max2')(inner)

    conv_to_rnn_dims = (img_w // (pool_size ** 2), (img_h // (pool_size ** 2)) * conv_filters)
    inner = Reshape(target_shape=conv_to_rnn_dims, name='reshape')(inner)

    # cuts down input size going into RNN:
    inner = Dense(time_dense_size, activation=None, name='dense1')(inner)
    inner = BatchNormalization()(inner)  # add BN
    inner = Activation(act)(inner)
    if dropout:
        inner = Dropout(dropout)(inner)  # prevent overfitting

    # Two layers of bidirectional GRUs
    # GRU seems to work as well, if not better than LSTM:
    gru_1 = GRU(rnn_size, return_sequences=True, kernel_initializer='he_normal', name='gru1')(inner)
    gru_1b = GRU(rnn_size, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru1_b')(inner)
    gru1_merged = add([gru_1, gru_1b])
    gru_2 = GRU(rnn_size, return_sequences=True, kernel_initializer='he_normal', name='gru2')(gru1_merged)
    gru_2b = GRU(rnn_size, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru2_b')(gru1_merged)
    inner = concatenate([gru_2, gru_2b])
    if dropout:
        inner = Dropout(dropout)(inner)  # prevent overfitting

    # transforms RNN output to character activations:
    inner = Dense(tiger_train.get_output_size(), kernel_initializer='he_normal',
                  name='dense2')(inner)
    y_pred = Activation('softmax', name='softmax')(inner)
    base_model = Model(inputs=input_data, outputs=y_pred)
    base_model.summary()

    labels = Input(name='the_labels', shape=[tiger_train.max_text_len], dtype='float32')
    input_length = Input(name='input_length', shape=[1], dtype='int64')
    label_length = Input(name='label_length', shape=[1], dtype='int64')
    # Keras doesn't currently support loss funcs with extra parameters
    # so CTC loss is implemented in a lambda layer
    loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, labels, input_length, label_length])

    model = Model(inputs=[input_data, labels, input_length, label_length], outputs=loss_out)

    # the loss calc occurs elsewhere, so use a dummy lambda func for the loss
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adadelta')

    # if not load:
    # captures output of softmax so we can decode the output during visualization
    # test_func = K.function([input_data], [y_pred])

    earlystoping = EarlyStopping(monitor=MONITOR, patience=PATIENCE, verbose=1, mode=MODE, baseline=BASELINE)
    train_model_path = './tmp/train_' + MODEL_NAME
    checkpointer = ModelCheckpoint(filepath=train_model_path,
                                   verbose=1,
                                   save_best_only=True)
    if os.path.exists(train_model_path):
        model.load_weights(train_model_path)
        print('load model weights:%s' % train_model_path)

    evaluator = Evaluate(model)
    model.fit_generator(generator=tiger_train.next_batch(),
                        steps_per_epoch=tiger_train.n,
                        epochs=EPOCHS,
                        initial_epoch=1,
                        validation_data=tiger_val.next_batch(),
                        validation_steps=tiger_val.n,
                        callbacks=[checkpointer, earlystoping, evaluator])

    base_model.save('./model/' + MODEL_NAME)
    print('----train end----')


# For a real OCR application, this should be beam search with a dictionary
# and language model.  For this example, best path is sufficient.
def decode_batch(out):
    ret = []
    for j in range(out.shape[0]):
        out_best = list(np.argmax(out[j, 2:], 1))
        out_best = [k for k, g in itertools.groupby(out_best)]
        outstr = ''
        for c in out_best:
            if c < len(letters):
                outstr += letters[c]
        ret.append(outstr)
    return ret


class Evaluate(Callback):
    def __init__(self, model):
        self.accs = []
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        acc = evaluate(self.model)
        self.accs.append(acc)


# Test on validation images
def evaluate(model):
    global downsample_factor
    tiger_test = TextImageGenerator(TEST_SET_PATH, 'test', IMG_W, IMG_H, TEST_BATCH_SIZE, downsample_factor)
    net_inp = model.get_layer(name='the_input').input
    net_out = model.get_layer(name='softmax').output
    predict_model = Model(inputs=net_inp, outputs=net_out)

    equalsIgnoreCaseNum = 0.00
    equalsNum = 0.00
    totalNum = 0.00
    for inp_value, _ in tiger_test.next_batch():
        batch_size = inp_value['the_input'].shape[0]
        X_data = inp_value['the_input']
        # net_out_value = sess.run(net_out, feed_dict={net_inp: X_data})
        net_out_value = predict_model.predict(X_data)
        pred_texts = decode_batch(net_out_value)
        labels = inp_value['the_labels']
        texts = []
        for label in labels:
            text = labels_to_text(label)
            texts.append(text)
        for i in range(batch_size):
            # print('Predict: %s ---> Label: %s' % (pred_texts[i], texts[i]))
            totalNum += 1
            if pred_texts[i] == texts[i]:
                equalsNum += 1
            if pred_texts[i].lower() == texts[i].lower():
                equalsIgnoreCaseNum += 1
            else:
                print('Predict: %s ---> Label: %s' % (pred_texts[i], texts[i]))
        if totalNum >= TEST_SET_NUM:
            break
    print('---Result---')
    print('Test num: %d, accuracy: %.5f, ignoreCase accuracy: %.5f' %
          (totalNum, equalsNum / totalNum, equalsIgnoreCaseNum / totalNum))
    return equalsIgnoreCaseNum / totalNum


if __name__ == '__main__':
    train()
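Training saves the recognition network (base_model: image in, per-time-step softmax out) to ./model/<ModelName>. A minimal single-image prediction sketch, reusing the preprocessing from TextImageGenerator and the decode_batch function above; it assumes it runs in the same module as the script, and the sample file name in the usage comment is hypothetical:

# Minimal prediction sketch (assumes the trained base model saved above and the same
# preprocessing/decoding used by TextImageGenerator and decode_batch).
def predict_image(img_path, model_path='./model/' + MODEL_NAME):
    model = load_model(model_path)                      # base_model: image in, softmax out
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = cv2.resize(img, (IMG_W, IMG_H)).astype(np.float32) / 255
    img = img.T                                         # width becomes the time dimension
    if K.image_data_format() == 'channels_first':
        x = img[np.newaxis, np.newaxis, :, :]
    else:
        x = img[np.newaxis, :, :, np.newaxis]
    net_out = model.predict(x)
    return decode_batch(net_out)[0]                     # best-path CTC decoding

# Hypothetical usage:
# print(predict_image('./data/test/abcd_1.png'))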

Reposted from: https://www.cnblogs.com/CoolJayson/p/10602040.html
