CRNN是OCR領域非常經典且被廣泛使用的識別算法,其理論基礎可以參考我上一篇文章,本文將著重講解CRNN代碼實現過程以及識別效果。
數據處理
利用圖像處理技術我們手工大批量生成文字圖像,一共360萬張圖像樣本,效果如下:
我們劃分了訓練集和測試集(10:1),并單獨存儲為兩個文本文件:
文本文件里的標簽格式如下:
我們獲取到的是最原始的數據集,在圖像深度學習訓練中我們一般都會把原始數據集轉化為lmdb格式以方便后續的網絡訓練。因此我們也需要對該數據集進行lmdb格式轉化。下面代碼就是用于lmdb格式轉化,思路比較簡單,就是首先讀入圖像和對應的文本標簽,先使用字典將該組合存儲起來(cache),再利用lmdb包的put函數把字典(cache)存儲的k,v寫成lmdb格式存儲好(cache當有了1000個元素就put一次)。
import lmdb
import cv2
import numpy as np
import osdef checkImageIsValid(imageBin):if imageBin is None:return Falsetry:imageBuf = np.fromstring(imageBin, dtype=np.uint8)img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)imgH, imgW = img.shape[0], img.shape[1]except:return Falseelse:if imgH * imgW == 0:return Falsereturn Truedef writeCache(env, cache):with env.begin(write=True) as txn:for k, v in cache.items():txn.put(k, v)def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):"""Create LMDB dataset for CRNN training.ARGS:outputPath : LMDB output pathimagePathList : list of image pathlabelList : list of corresponding groundtruth textslexiconList : (optional) list of lexicon listscheckValid : if true, check the validity of every image"""assert (len(imagePathList) == len(labelList))nSamples = len(imagePathList)env = lmdb.open(outputPath, map_size=1099511627776)cache = {}cnt = 1for i in range(nSamples):imagePath = ''.join(imagePathList[i]).split()[0].replace('\n', '').replace('\r\n', '')# print(imagePath)label = ''.join(labelList[i])print(label)# if not os.path.exists(imagePath):# print('%s does not exist' % imagePath)# continuewith open('.' + imagePath, 'r') as f:imageBin = f.read()if checkValid:if not checkImageIsValid(imageBin):print('%s is not a valid image' % imagePath)continueimageKey = 'image-%09d' % cntlabelKey = 'label-%09d' % cntcache[imageKey] = imageBincache[labelKey] = labelif lexiconList:lexiconKey = 'lexicon-%09d' % cntcache[lexiconKey] = ' '.join(lexiconList[i])if cnt % 1000 == 0:writeCache(env, cache)cache = {}print('Written %d / %d' % (cnt, nSamples))cnt += 1print(cnt)nSamples = cnt - 1cache['num-samples'] = str(nSamples)writeCache(env, cache)print('Created dataset with %d samples' % nSamples)OUT_PATH = '../crnn_train_lmdb'
IN_PATH = './train.txt'if __name__ == '__main__':outputPath = OUT_PATHif not os.path.exists(OUT_PATH):os.mkdir(OUT_PATH)imgdata = open(IN_PATH)imagePathList = list(imgdata)labelList = []for line in imagePathList:word = line.split()[1]labelList.append(word)createDataset(outputPath, imagePathList, labelList)
我們運行上面的代碼,可以得到訓練集和測試集的lmdb
在數據準備部分還有一個操作需要強調的,那就是文字標簽數字化,即我們用數字來表示每一個文字(漢字,英文字母,標點符號)。比如“我”字對應的id是1,“l”對應的id是1000,“?”對應的id是90,如此類推,這種編解碼工作使用字典數據結構存儲即可,訓練時先把標簽編碼(encode),預測時就將網絡輸出結果解碼(decode)成文字輸出。
class strLabelConverter(object):"""Convert between str and label.NOTE:Insert `blank` to the alphabet for CTC.Args:alphabet (str): set of the possible characters.ignore_case (bool, default=True): whether or not to ignore all of the case."""def __init__(self, alphabet, ignore_case=False):self._ignore_case = ignore_caseif self._ignore_case:alphabet = alphabet.lower()self.alphabet = alphabet + '-' # for `-1` indexself.dict = {}for i, char in enumerate(alphabet):# NOTE: 0 is reserved for 'blank' required by wrap_ctcself.dict[char] = i + 1def encode(self, text):"""Support batch or single str.Args:text (str or list of str): texts to convert.Returns:torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.torch.IntTensor [n]: length of each text."""length = []result = []for item in text:item = item.decode('utf-8', 'strict')length.append(len(item))for char in item:index = self.dict[char]result.append(index)text = result# print(text,length)return (torch.IntTensor(text), torch.IntTensor(length))def decode(self, t, length, raw=False):"""Decode encoded texts back into strs.Args:torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.torch.IntTensor [n]: length of each text.Raises:AssertionError: when the texts and its length does not match.Returns:text (str or list of str): texts to convert."""if length.numel() == 1:length = length[0]assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),length)if raw:return ''.join([self.alphabet[i - 1] for i in t])else:char_list = []for i in range(length):if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):char_list.append(self.alphabet[t[i] - 1])return ''.join(char_list)else:# batch modeassert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())texts = []index = 0for i in range(length.numel()):l = length[i]texts.append(self.decode(t[index:index + l], torch.IntTensor([l]), raw=raw))index += lreturn texts
網絡設計
根據CRNN的論文描述,CRNN是由CNN-》RNN-》CTC三大部分架構而成,分別對應卷積層、循環層和轉錄層。首先CNN部分用于底層的特征提取,RNN采取了BiLSTM,用于學習關聯序列信息并預測標簽分布,CTC用于序列對齊,輸出預測結果。
為了將特征輸入到Recurrent Layers,做如下處理:
- 首先會將圖像縮放到 32×W×3 大小
- 然后經過CNN后變為 1×(W/4)× 512
- 接著針對LSTM,設置 T=(W/4) , D=512 ,即可將特征輸入LSTM。
以上是理想訓練時的操作,但是CRNN論文提到的網絡輸入是歸一化好的100×32大小的灰度圖像,即高度統一為32個像素。下面是CRNN的深度神經網絡結構圖,CNN采取了經典的VGG16,值得注意的是,在VGG16的第3第4個max pooling層CRNN采取的是1×2的矩形池化窗口(w×h),這有別于經典的VGG16的2×2的正方形池化窗口,這個改動是因為文本圖像多數都是高較小而寬較長,所以其feature map也是這種高小寬長的矩形形狀,如果使用1×2的池化窗口則更適合英文字母識別(比如區分i和l)。VGG16部分還引入了BatchNormalization模塊,旨在加速模型收斂。還有值得注意一點,CRNN的輸入是灰度圖像,即圖像深度為1。CNN部分的輸出是512x1x16(c×h×w)的特征向量。
接下來分析RNN層。RNN部分使用了雙向LSTM,隱藏層單元數為256,CRNN采用了兩層BiLSTM來組成這個RNN層,RNN層的輸出維度將是(s,b,class_num) ,其中class_num為文字類別總數。
值得注意的是:Pytorch里的LSTM單元接受的輸入都必須是3維的張量(Tensors).每一維代表的意思不能弄錯。第一維體現的是序列(sequence)結構,第二維度體現的是小塊(mini-batch)結構,第三位體現的是輸入的元素(elements of input)。如果在應用中不適用小塊結構,那么可以將輸入的張量中該維度設為1,但必須要體現出這個維度。
LSTM的輸入
input of shape (seq_len, batch, input_size): tensor containing the features of the input sequence.
The input can also be a packed variable length sequence.
input shape(a,b,c)
a:seq_len -> 序列長度
b:batch
c:input_size 輸入特征數目
根據LSTM的輸入要求,我們要對CNN的輸出做些調整,即把CNN層的輸出調整為[seq_len, batch, input_size]形式,下面為具體操作:先使用squeeze函數移除h維度,再使用permute函數調整各維順序,即從原來[w, b, c]的調整為[seq_len, batch, input_size],具體尺寸為[16,batch,512],調整好之后即可以將該矩陣送入RNN層。
x = self.cnn(x)
b, c, h, w = x.size()
# print(x.size()): b,c,h,w
assert h == 1 # "the height of conv must be 1"
x = x.squeeze(2) # remove h dimension, b *512 * width
x = x.permute(2, 0, 1) # [w, b, c] = [seq_len, batch, input_size]
x = self.rnn(x)
RNN層輸出格式如下,因為我們采用的是雙向BiLSTM,所以輸出維度將是hidden_unit * 2
Outputs: output, (h_n, c_n)
output of shape (seq_len, batch, num_directions * hidden_size)
h_n of shape (num_layers * num_directions, batch, hidden_size)
c_n (num_layers * num_directions, batch, hidden_size)
然后我們再通過線性變換操作self.embedding1 = torch.nn.Linear(hidden_unit * 2, 512)
是的輸出維度再次變為512,繼續送入第二個LSTM層。第二個LSTM層后繼續接線性操作torch.nn.Linear(hidden_unit * 2, class_num)
使得整個RNN層的輸出為文字類別總數。
import torch
import torch.nn.functional as Fclass Vgg_16(torch.nn.Module):def __init__(self):super(Vgg_16, self).__init__()self.convolution1 = torch.nn.Conv2d(1, 64, 3, padding=1)self.pooling1 = torch.nn.MaxPool2d(2, stride=2)self.convolution2 = torch.nn.Conv2d(64, 128, 3, padding=1)self.pooling2 = torch.nn.MaxPool2d(2, stride=2)self.convolution3 = torch.nn.Conv2d(128, 256, 3, padding=1)self.convolution4 = torch.nn.Conv2d(256, 256, 3, padding=1)self.pooling3 = torch.nn.MaxPool2d((1, 2), stride=(2, 1)) # notice stride of the non-square poolingself.convolution5 = torch.nn.Conv2d(256, 512, 3, padding=1)self.BatchNorm1 = torch.nn.BatchNorm2d(512)self.convolution6 = torch.nn.Conv2d(512, 512, 3, padding=1)self.BatchNorm2 = torch.nn.BatchNorm2d(512)self.pooling4 = torch.nn.MaxPool2d((1, 2), stride=(2, 1))self.convolution7 = torch.nn.Conv2d(512, 512, 2)def forward(self, x):x = F.relu(self.convolution1(x), inplace=True)x = self.pooling1(x)x = F.relu(self.convolution2(x), inplace=True)x = self.pooling2(x)x = F.relu(self.convolution3(x), inplace=True)x = F.relu(self.convolution4(x), inplace=True)x = self.pooling3(x)x = self.convolution5(x)x = F.relu(self.BatchNorm1(x), inplace=True)x = self.convolution6(x)x = F.relu(self.BatchNorm2(x), inplace=True)x = self.pooling4(x)x = F.relu(self.convolution7(x), inplace=True)return x # b*512x1x16class RNN(torch.nn.Module):def __init__(self, class_num, hidden_unit):super(RNN, self).__init__()self.Bidirectional_LSTM1 = torch.nn.LSTM(512, hidden_unit, bidirectional=True)self.embedding1 = torch.nn.Linear(hidden_unit * 2, 512)self.Bidirectional_LSTM2 = torch.nn.LSTM(512, hidden_unit, bidirectional=True)self.embedding2 = torch.nn.Linear(hidden_unit * 2, class_num)def forward(self, x):x = self.Bidirectional_LSTM1(x) # LSTM output: output, (h_n, c_n)T, b, h = x[0].size() # x[0]: (seq_len, batch, num_directions * hidden_size)x = self.embedding1(x[0].view(T * b, h)) # pytorch view() reshape as [T * b, nOut]x = x.view(T, b, -1) # [16, b, 512]x = self.Bidirectional_LSTM2(x)T, b, h = x[0].size()x = self.embedding2(x[0].view(T * b, h))x = x.view(T, b, -1)return x # [16,b,class_num]# output: [s,b,class_num]
class CRNN(torch.nn.Module):def __init__(self, class_num, hidden_unit=256):super(CRNN, self).__init__()self.cnn = torch.nn.Sequential()self.cnn.add_module('vgg_16', Vgg_16())self.rnn = torch.nn.Sequential()self.rnn.add_module('rnn', RNN(class_num, hidden_unit))def forward(self, x):x = self.cnn(x)b, c, h, w = x.size()# print(x.size()): b,c,h,wassert h == 1 # "the height of conv must be 1"x = x.squeeze(2) # remove h dimension, b *512 * widthx = x.permute(2, 0, 1) # [w, b, c] = [seq_len, batch, input_size]# x = x.transpose(0, 2)# x = x.transpose(1, 2)x = self.rnn(x)return x
損失函數設計
剛剛完成了CNN層和RNN層的設計,現在開始設計轉錄層,即將RNN層輸出的結果翻譯成最終的識別文字結果,從而實現不定長的文字識別。pytorch沒有內置的CTC loss,所以只能去Github下載別人實現的CTC loss來完成損失函數部分的設計。安裝CTC-loss的方式如下:
git clone https://github.com/SeanNaren/warp-ctc.git
cd warp-ctc
mkdir build; cd build
cmake ..
make
cd ../pytorch_binding/
python setup.py install
cd ../build
cp libwarpctc.so ../../usr/lib
待安裝完畢后,我們可以直接調用CTC loss了,以一個小例子來說明ctc loss的用法。
import torch
from warpctc_pytorch import CTCLoss
ctc_loss = CTCLoss()
# expected shape of seqLength x batchSize x alphabet_size
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
labels = torch.IntTensor([1, 2])
label_sizes = torch.IntTensor([2])
probs_sizes = torch.IntTensor([2])
probs.requires_grad_(True) # tells autograd to compute gradients for probs
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()
CTCLoss(size_average=False, length_average=False)# size_average (bool): normalize the loss by the batch size (default: False)# length_average (bool): normalize the loss by the total number of frames in the batch. If True, supersedes size_average (default: False)forward(acts, labels, act_lens, label_lens)# acts: Tensor of (seqLength x batch x outputDim) containing output activations from network (before softmax)# labels: 1 dimensional Tensor containing all the targets of the batch in one large sequence# act_lens: Tensor of size (batch) containing size of each output sequence from the network# label_lens: Tensor of (batch) containing label length of each example
從上面的代碼可以看出,CTCLoss的輸入為[probs, labels, probs_sizes, label_sizes],即預測結果、標簽、預測結果的數目和標簽數目。那么我們仿照這個例子開始設計CRNN的CTC LOSS。
preds = net(image)
preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size)) # preds.size(0)=w=16
cost = criterion(preds, text, preds_size, length) / batch_size # 這里的length就是包含每個文本標簽的長度的list,除以batch_size來求平均loss
cost.backward()
網絡訓練設計
接下來我們需要完善具體的訓練流程,我們還寫了個trainBatch函數用于bacth形式的梯度更新。
def trainBatch(net, criterion, optimizer, train_iter):data = train_iter.next()cpu_images, cpu_texts = databatch_size = cpu_images.size(0)lib.dataset.loadData(image, cpu_images)t, l = converter.encode(cpu_texts)lib.dataset.loadData(text, t)lib.dataset.loadData(length, l)preds = net(image)#print("preds.size=%s" % preds.size)preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size)) # preds.size(0)=w=22cost = criterion(preds, text, preds_size, length) / batch_size # length= a list that contains the len of text label in a batchnet.zero_grad()cost.backward()optimizer.step()return cost
整個網絡訓練的流程如下:CTC-LOSS對象->CRNN網絡對象->image,text,len的tensor初始化->優化器初始化,然后開始循環每個epoch,指定迭代次數就進行模型驗證和模型保存。CRNN論文提到所采用的優化器是Adadelta,但是經過我實驗看來,Adadelta的收斂速度非常慢,所以改用了RMSprop優化器,模型收斂速度大幅度提升。
criterion = CTCLoss()net = Net.CRNN(n_class)print(net)net.apply(lib.utility.weights_init)image = torch.FloatTensor(Config.batch_size, 3, Config.img_height, Config.img_width)text = torch.IntTensor(Config.batch_size * 5)length = torch.IntTensor(Config.batch_size)if cuda:net.cuda()image = image.cuda()criterion = criterion.cuda()image = Variable(image)text = Variable(text)length = Variable(length)loss_avg = lib.utility.averager()optimizer = optim.RMSprop(net.parameters(), lr=Config.lr)#optimizer = optim.Adadelta(net.parameters(), lr=Config.lr)#optimizer = optim.Adam(net.parameters(), lr=Config.lr,#betas=(Config.beta1, 0.999))for epoch in range(Config.epoch):train_iter = iter(train_loader)i = 0while i < len(train_loader):for p in net.parameters():p.requires_grad = Truenet.train()cost = trainBatch(net, criterion, optimizer, train_iter)loss_avg.add(cost)i += 1if i % Config.display_interval == 0:print('[%d/%d][%d/%d] Loss: %f' %(epoch, Config.epoch, i, len(train_loader), loss_avg.val()))loss_avg.reset()if i % Config.test_interval == 0:val(net, test_dataset, criterion)# do checkpointingif i % Config.save_interval == 0:torch.save(net.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(Config.model_dir, epoch, i))
訓練過程與測試設計
下面這幅圖表示的就是CRNN訓練過程,文字類別數為6732,一共訓練20個epoch,batch_Szie設置為64,所以一共是51244次迭代/epoch。
在迭代4個epoch時,loss降到0.1左右,acc上升到0.98。
接下來我們設計推斷預測部分的代碼,首先需初始化CRNN網絡,載入訓練好的模型,讀入待預測的圖像并resize為高為32的灰度圖像,接著講該圖像送入網絡,最后再將網絡輸出解碼成文字即可輸出。
import time
import torch
import os
from torch.autograd import Variable
import lib.convert
import lib.dataset
from PIL import Image
import Net.net as Net
import alphabets
import sys
import Configos.environ['CUDA_VISIBLE_DEVICES'] = "4"crnn_model_path = './bs64_model/netCRNN_9_48000.pth'
IMG_ROOT = './test_images'
running_mode = 'gpu'
alphabet = alphabets.alphabet
nclass = len(alphabet) + 1def crnn_recognition(cropped_image, model):converter = lib.convert.strLabelConverter(alphabet) # 標簽轉換image = cropped_image.convert('L') # 圖像灰度化### Testing images are scaled to have height 32. Widths are# proportionally scaled with heights, but at least 100 pixelsw = int(image.size[0] / (280 * 1.0 / Config.infer_img_w))#scale = image.size[1] * 1.0 / Config.img_height#w = int(image.size[0] / scale)transformer = lib.dataset.resizeNormalize((w, Config.img_height))image = transformer(image)if torch.cuda.is_available():image = image.cuda()image = image.view(1, *image.size())image = Variable(image)model.eval()preds = model(image)_, preds = preds.max(2)preds = preds.transpose(1, 0).contiguous().view(-1)preds_size = Variable(torch.IntTensor([preds.size(0)]))sim_pred = converter.decode(preds.data, preds_size.data, raw=False) # 預測輸出解碼成文字print('results: {0}'.format(sim_pred))if __name__ == '__main__':# crnn networkmodel = Net.CRNN(nclass)# 載入訓練好的模型,CPU和GPU的載入方式不一樣,需分開處理if running_mode == 'gpu' and torch.cuda.is_available():model = model.cuda()model.load_state_dict(torch.load(crnn_model_path))else:model.load_state_dict(torch.load(crnn_model_path, map_location='cpu'))print('loading pretrained model from {0}'.format(crnn_model_path))files = sorted(os.listdir(IMG_ROOT)) # 按文件名排序for file in files:started = time.time()full_path = os.path.join(IMG_ROOT, file)print("=============================================")print("ocr image is %s" % full_path)image = Image.open(full_path)crnn_recognition(image, model)finished = time.time()print('elapsed time: {0}'.format(finished - started))
識別效果和總結
首先我從測試集中抽取幾張圖像送入模型識別,識別全部正確。
我也隨機在一些文檔圖片、掃描圖像上截取了一段文字圖像送入我們該模型進行識別,識別效果也挺好的,基本識別正確,表明模型泛化能力很強。
我還截取了增值稅掃描發票上的文本圖像來看看我們的模型能否還可以表現出穩定的識別效果:
這里做個小小的總結:對于端到端不定長的文字識別,CRNN是最為經典的識別算法,而且實戰看來效果非常不錯。上面識別結果可以看出,雖然我們用于訓練的數據集是自己生成的,但是我們該模型對于pdf文檔、掃描圖像等都有很不錯的識別結果,如果需要繼續提升對特定領域的文本圖像的識別,直接大量加入該類圖像用于訓練即可。CRNN的完整代碼可以參考我的Github。