本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:
在定西等地区,都构建了全面的区域性战略布局,加强发展的系统性、市场前瞻性、产品创新能力,以专注、极致的服务理念,为客户提供网站设计、成都网站制作 网站设计制作定制网站,公司网站建设,企业网站建设,高端网站设计,成都全网营销,成都外贸网站建设,定西网站建设费用合理。前言记录下pytorch里如何使用lmdb的code,自用
制作部分的Codecode就是ASTER里数据制作部分的代码改了点,aster_train.txt里面就算图片的完整路径每行一个,图片同目录下有同名的txt,里面记着jpg的标签
import os import lmdb # install lmdb by "pip install lmdb" import cv2 import numpy as np from tqdm import tqdm import six from PIL import Image import scipy.io as sio from tqdm import tqdm import re def checkImageIsValid(imageBin): if imageBin is None: return False imageBuf = np.fromstring(imageBin, dtype=np.uint8) img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) imgH, imgW = img.shape[0], img.shape[1] if imgH * imgW == 0: return False return True def writeCache(env, cache): with env.begin(write=True) as txn: for k, v in cache.items(): txn.put(k.encode(), v) def _is_difficult(word): assert isinstance(word, str) return not re.match('^[\w]+$', word) def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True): """ Create LMDB dataset for CRNN training. ARGS: outputPath : LMDB output path imagePathList : list of image path labelList : list of corresponding groundtruth texts lexiconList : (optional) list of lexicon lists checkValid : if true, check the validity of every image """ assert(len(imagePathList) == len(labelList)) nSamples = len(imagePathList) env = lmdb.open(outputPath, map_size=1099511627776)#大空间1048576GB cache = {} cnt = 1 for i in range(nSamples): imagePath = imagePathList[i] label = labelList[i] if len(label) == 0: continue if not os.path.exists(imagePath): print('%s does not exist' % imagePath) continue with open(imagePath, 'rb') as f: imageBin = f.read() if checkValid: if not checkImageIsValid(imageBin): print('%s is not a valid image' % imagePath) continue #数据库中都是二进制数据 imageKey = 'image-%09d' % cnt#9位数不足填零 labelKey = 'label-%09d' % cnt cache[imageKey] = imageBin cache[labelKey] = label.encode() if lexiconList: lexiconKey = 'lexicon-%09d' % cnt cache[lexiconKey] = ' '.join(lexiconList[i]) if cnt % 1000 == 0: writeCache(env, cache) cache = {} print('Written %d / %d' % (cnt, nSamples)) cnt += 1 nSamples = cnt-1 cache['num-samples'] = str(nSamples).encode() writeCache(env, cache) print('Created dataset with %d samples' % nSamples) def get_sample_list(txt_path:str): with open(txt_path,'r') as fr: jpg_list=[x.strip() for x in fr.readlines() if os.path.exists(x.replace('.jpg','.txt').strip())] txt_content_list=[] for jpg in jpg_list: label_path=jpg.replace('.jpg','.txt') with open(label_path,'r') as fr: try: str_tmp=fr.readline() except UnicodeDecodeError as e: print(label_path) raise(e) txt_content_list.append(str_tmp.strip()) return jpg_list,txt_content_list if __name__ == "__main__": txt_path='/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt' lmdb_output_path = '/home/gpu-server/project/aster/dataset/train' imagePathList,labelList=get_sample_list(txt_path) createDataset(lmdb_output_path, imagePathList, labelList)