tensorflow tf.py_func加载pickle迭代程序抛出错误,形状未知



我正在编写一个tf.data管道,以便稍后输入到keras。问题是我的数据是以pickle文件的形式存在的。我有一个传递给tf数据的文件名列表,我将使用内部的自定义tf.py_func调用pickle来加载它。

当我试图从数据集构建迭代器时,问题就出现了,给出了

"无法转换值,),types:(tf.float32,tf.foat32)>转换为TensorFlow DType。">

我认为这是因为tensorflow无法推断加载的pickle数据的形状。我有点不知道如何继续,或者这在tf数据中是否可能。

dataset = tf.data.Dataset.from_tensor_slices(dataset_filepath_list)
def parse_input_data_function(filename):
# pickle file is a tuple, (data, label)
histogram_data, label = pickle.load(open(filename, 'rb'))
histogram_data = historgram_data.transpose(1, 0)
histogram_data = historgram_data.reshape([-1, 8, 32])
return histogram_data.astype('float32'), float(label)
dataset = dataset.map(
lambda filename : tuple(tf.py_func(
parse_input_data_function, [filename], [tf.float32, 
tf.float32])))
dataset = dataset.shuffle(len(dataset_filename_list))
.batch(batch_size).repeat()
# this line is where the error occurs
training_iterator = tf.data.Iterator.from_structure(dataset, 
dataset.output_shapes)

您的问题是向tf.data.Iterator.from_structure传递了错误的参数。它应该采用(output_types,output_shapes),但您给出的是一个数据集及其形状。试试这个:

training_iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)

然后,将迭代器与相应的数据集一起使用:

input_data, output_data = training_iterator.get_next()
train_init = training_iterator.make_initializer(dataset)

相关内容

最新更新