Retrieving data from LMDB in Python throws an error



This is how the dataset is created:


def createDataset1(
    outputPath,
    imagePathList,
    labelList,
    lexiconList=None,
    validset_percent=10,
    testset_percent=0,
    random_seed=1111,
    checkValid=True,
):
    """
    Create LMDB dataset for CRNN training.
    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image paths
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image
    """
    train_path = os.path.join(outputPath, "training", "9M")
    valid_path = os.path.join(outputPath, "validation", "9M")
    # CAUTION: if train_path (lmdb) already exists, this function adds the dataset
    # into it, so remove the former one and re-create the lmdb.
    if os.path.exists(train_path):
        os.system(f"rm -r {train_path}")
    if os.path.exists(valid_path):
        os.system(f"rm -r {valid_path}")

    os.makedirs(train_path, exist_ok=True)
    os.makedirs(valid_path, exist_ok=True)
    gt_train_path = gt_file.replace(".txt", "_train.txt")
    gt_valid_path = gt_file.replace(".txt", "_valid.txt")
    data_log = open(gt_train_path, "w", encoding="utf-8")
    if testset_percent != 0:
        test_path = os.path.join(outputPath, "evaluation", dataset_name)
        if os.path.exists(test_path):
            os.system(f"rm -r {test_path}")
        os.makedirs(test_path, exist_ok=True)
        gt_test_path = gt_file.replace(".txt", "_test.txt")

    assert len(imagePathList) == len(labelList)
    nSamples = len(imagePathList)
    num_valid_dataset = int(nSamples * validset_percent / 100.0)
    num_test_dataset = int(nSamples * testset_percent / 100.0)
    num_train_dataset = nSamples - num_valid_dataset - num_test_dataset
    print("validation datasets: ", num_valid_dataset, "\n", "test datasets: ", num_test_dataset, "\n training datasets: ", num_train_dataset)
    env = lmdb.open(outputPath, map_size=1099511627776)
    cache = {}
    cnt = 1
    random.seed(random_seed)
    random.shuffle(imagePathList)
    for i in tqdm(range(nSamples)):
        data_log.write(imagePathList[i])
        imagePath = imagePathList[i]
        label = labelList[i]
        if len(label) == 0:
            continue
        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue
        embed_vec = fasttext_model[label]
        imageKey = 'image-%09d' % cnt
        labelKey = 'label-%09d' % cnt
        embedKey = 'embed-%09d' % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label.encode()
        cache[embedKey] = ' '.join(str(v) for v in embed_vec.tolist()).encode()
        if lexiconList:
            lexiconKey = 'lexicon-%09d' % cnt
            cache[lexiconKey] = ' '.join(lexiconList[i])
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))

        # finish train dataset and start validation dataset
        if i + 1 == num_train_dataset:
            print(f"# Train dataset: {num_train_dataset} is finished")
            cache["num-samples".encode()] = str(num_train_dataset).encode()
            writeCache(env, cache)
            data_log.close()

            # start validation set
            env = lmdb.open(valid_path, map_size=30 * 2 ** 30)
            cache = {}
            cnt = 0
            data_log = open(gt_valid_path, "w", encoding="utf-8")

        # finish train/valid dataset and start test dataset
        if (i + 1 == num_train_dataset + num_valid_dataset) and num_test_dataset != 0:
            print(f"# Valid dataset: {num_valid_dataset} is finished")
            cache["num-samples".encode()] = str(num_valid_dataset).encode()
            writeCache(env, cache)
            data_log.close()
            # start test set
            env = lmdb.open(test_path, map_size=30 * 2 ** 30)
            cache = {}
            cnt = 0  # not 1 at this time
            data_log = open(gt_test_path, "w", encoding="utf-8")

        cnt += 1

    if testset_percent == 0:
        cache["num-samples".encode()] = str(num_valid_dataset).encode()
        writeCache(env, cache)
        print(f"# Valid dataset: {num_valid_dataset} is finished")
    else:
        cache["num-samples".encode()] = str(num_test_dataset).encode()
        writeCache(env, cache)
        print(f"# Test dataset: {num_test_dataset} is finished")

This is how I am trying to retrieve the data:


class LmdbDataset(data.Dataset):
    def __init__(self, root, voc_type, max_len, num_samples, transform=None):
        super(LmdbDataset, self).__init__()
        if global_args.run_on_remote:
            dataset_name = os.path.basename(root)
            data_cache_url = "/cache/%s" % dataset_name
            if not os.path.exists(data_cache_url):
                os.makedirs(data_cache_url)
            if mox.file.exists(root):
                mox.file.copy_parallel(root, data_cache_url)
            else:
                raise ValueError("%s not exists!" % root)

            self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
        else:
            self.env = lmdb.open(root, max_readers=32, readonly=True)
        assert self.env is not None, "cannot create lmdb from %s" % root
        self.txn = self.env.begin()
        self.voc_type = voc_type
        self.transform = transform
        self.max_len = max_len
        # nums = b"num-samples"
        # print('NUM SAMPLES ------ \n', nums)
        nSamples = self.txn.get('num-samples'.encode())
        print("STRING nSamples :", nSamples)
        self.nSamples = int(self.txn.get(b"num-samples"))
        self.nSamples = min(self.nSamples, num_samples)
        assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS']
        self.EOS = 'EOS'
        self.PADDING = 'PADDING'
        self.UNKNOWN = 'UNKNOWN'
        self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
        self.char2id = dict(zip(self.voc, range(len(self.voc))))
        self.id2char = dict(zip(range(len(self.voc)), self.voc))
        self.rec_num_classes = len(self.voc)
        self.lowercase = (voc_type == 'LOWERCASE')

Whenever the code tries to call self.txn.get(b"num-samples"), I get the error below:

Traceback (most recent call last):
File "main.py", line 268, in <module>
main(args)
File "main.py", line 157, in main
train_dataset, train_loader = get_data_lmdb(args.synthetic_train_data_dir, args.voc_type, args.max_len, args.num_train,
File "main.py", line 66, in get_data_lmdb
dataset_list.append(LmdbDataset(data_dir_, voc_type, max_len, num_samples))
File "/Users/SEED/lib/datasets/dataset.py", line 189, in __init__
self.nSamples = int(self.txn.get(b"num-samples"))
TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'

I have tried many different suggestions online, including several Stack Overflow threads, but I cannot figure out what is going wrong.

What is causing this error, and how can I fix it?

Your code is dense and complex, and its data flow is hard to follow. It mixes a lot of computation with I/O, side effects, and even system calls (e.g. os.system(f"rm -r {test_path}"), where you should be using shutil.rmtree).
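For instance, here is a minimal sketch of that cleanup step using shutil.rmtree instead of a shell call, reusing the train_path/valid_path names from your function:

import os
import shutil

# Remove any previous LMDB directories with the standard library
# instead of shelling out to `rm -r`, then recreate them.
for path in (train_path, valid_path):
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)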

Try to break down each action you want to perform so that it mostly does one specific thing (see the sketch after this list):

  • logic, with no side effects
  • input (reading from a file or the network)
  • output (writing files, producing results)
  • filesystem operations/cleanup
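As an example of the first bullet, the split-size arithmetic in createDataset1 could live in a small pure function. compute_split_sizes is a hypothetical name, not something from your code:

def compute_split_sizes(n_samples, validset_percent, testset_percent):
    """Pure logic, no side effects: derive the three split sizes from percentages."""
    num_valid = int(n_samples * validset_percent / 100.0)
    num_test = int(n_samples * testset_percent / 100.0)
    num_train = n_samples - num_valid - num_test
    return num_train, num_valid, num_test

# e.g. compute_split_sizes(1000, 10, 0) == (900, 100, 0)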

At each stage you should also perform validation, and apply the rule of least power. If you expect self.txn to always contain 'num-samples', then you should use a lookup that fails loudly when the key is missing, rather than .get, which silently defaults to None. That makes it much easier to catch errors earlier in the chain.
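A transaction from the lmdb package only exposes .get() (there is no [] indexing on a transaction), so the fail-fast lookup needs a small helper. require below is a hypothetical name, a minimal sketch of the idea:

def require(txn, key: bytes) -> bytes:
    """Fetch a key from an LMDB transaction, raising instead of returning None."""
    value = txn.get(key)
    if value is None:
        raise KeyError(f"LMDB database has no entry for key {key!r}")
    return value

# In LmdbDataset.__init__ this would replace the silent .get():
# self.nSamples = int(require(self.txn, b"num-samples"))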

I also cannot tell what the lmdb module is here. Is that a library, or another file in your codebase? You should link to any libraries you use if they are not well known; I see several lmdb-related Python packages.

Let me dissect the code sample for you. Maybe this will help you solve the problem.

The log says:

self.nSamples = int(self.txn.get(b"num-samples"))
TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'

self.txn needs to behave like a dictionary for the get method to work there. When you try to get the value, there are two possible cases:

  1. 'num-samples' exists in the data, but nothing was actually stored under it
  2. there is no key named 'num-samples' in the data at all

Both produce a NoneType that the type conversion (to int) cannot handle. So you have to check whether that field exists in the data at all, by listing the stored keys and seeing whether the key is present.
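The lmdb package's transaction has no .keys() method, but you can list every key through a cursor. A minimal sketch, assuming root is the path you pass to LmdbDataset:

import lmdb

env = lmdb.open(root, max_readers=32, readonly=True)
with env.begin() as txn:
    for key, _ in txn.cursor():
        print(key)  # b'num-samples' should show up here if it was ever written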

Also, if

self.nSamples = int(self.txn.get(b"num-samples"))

is failing while the earlier statement

nSamples = self.txn.get('num-samples'.encode())

succeeds, then you can simply decode that variable and use it there:

Samples = nSamples.decode("utf-8")  # raises AttributeError if nSamples is None
try:
    int_samples = int(Samples)
except ValueError:  # int() on a non-numeric string raises ValueError, not TypeError
    print(Samples)
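Putting the two checks together, a hedged version of the lookup in __init__ could look like this (assuming the creation script really wrote the counter under b"num-samples"):

raw = self.txn.get(b"num-samples")
if raw is None:
    # Key missing: the LMDB at `root` is probably not the one the
    # creation script wrote its num-samples counter into.
    raise KeyError("num-samples not found in LMDB at %s" % root)
self.nSamples = int(raw.decode("utf-8"))  # decode() keeps the intent explicit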
