如何在tensorflow中一次只加载一批图像



我正在从这里读取cifar100图像,我想从pickle文件中批量读取图像。对于我使用的这行代码(已加载%d个示例。"%num(加载训练数据集中存在的所有图像。然后使用tf.data我可以读取批次。但当它加载所有图像时,我的记忆就会被利用,甚至不会开始训练。我正在使用类似的东西。这个链接使用tfrecords,我想使用pickle读取cifar数据。那么,有人知道我如何从cifar100 pickle文件中只读取一批数据,这样我的内存就不会变满吗?

def read_data():
def in_data():
all_images = []
all_labels = []
with open("%s%s" % ("./data/cifar-100-python/", "train"),"rb") as fo:
dict = pickle.load(fo, encoding='latin1')
images = np.array(dict['data'])
labels = np.array(dict['fine_labels'])
num = images.shape[0]
# images = normalize(images)
images = images.astype(dtype=np.float32)
labels = labels.astype(dtype=np.int32)
images = np.reshape(images, [num, 3, 32, 32])
images = np.transpose(images, [0, 2, 3, 1])
print("Loaded %d examples." % num)
#print('BeforeLables: ', labels)
labels = one_hot_encode(labels)
#print('afterLables: ', labels)
all_images.append(images)
all_labels.append(labels)
# print('SIZE:', len(all_images))
all_images = np.concatenate(all_images)
all_labels = np.concatenate(all_labels)
self.size = len(all_images)
img_dataset = tf.data.Dataset.from_tensor_slices(all_images).batch(2)
label_dataset = tf.data.Dataset.from_tensor_slices(all_labels).batch(2)
dataset = tf.data.Dataset.zip((img_dataset, label_dataset)).repeat(None)
images, labels = dataset.make_one_shot_iterator().get_next()
print('image Shape:',images.shape)
return images, labels
return in_data

我建议您阅读有关导入数据的教程。有一个非常有用和相似的例子。在这个例子中,我们没有使用from_sensor_slice将图像数据嵌入到计算图中。相反,我们将文件名嵌入到图形中。

此外,如果你的数据是一个太大的数据,无法加载,你必须将其初步拆分为多个文件

最新更新