从 tf.data.Dataset.map() 返回数据集会导致对象没有属性'get_shape'错误'TensorSliceDataset'



>我正在使用数据集 API 创建输入管道。 我正在使用 tf.data.Dataset.map(( 方法,其模式类似于以下内容:

def mapped_fn(_):
X = tf.random_uniform([3,3])
y = tf.random_uniform([3,1])
dataset = tf.data.Dataset.from_tensor_slices((X,y))
return dataset
with tf.Session() as sess:
first = tf.random_uniform([1,2])         
unimportant_dataset = tf.data.Dataset.from_tensors(first)
dataset = unimportant_dataset.map(mapped_fn)
sess.run(dataset)

我收到以下错误:AttributeError: 'TensorSliceDataset' object has no attribute 'get_shape'

总体上下文是mapped_fn从 .tfrecords 文件中反序列化示例 protobuf(在本例中由unimportant_dataset表示(,重塑特征向量 (X(,并且需要返回一个数据集,其中包含由新特征向量(在本例中为形状(3,)(中的切片定义的元素。 我在返回ZipDataset时遇到了类似的错误。 提前感谢!

DomJack 关于Dataset.map()签名的回答是绝对正确的:它期望传递mapped_fn的返回值是一个或多个张量(或稀疏张量(。

如果您确实有一个返回Dataset的函数,则可以使用Dataset.flat_map()将所有返回的数据集平展并连接成单个数据集,如下所示:

def mapped_fn(_):
X = tf.random_uniform([3,3])
y = tf.random_uniform([3,1])
dataset = tf.data.Dataset.from_tensor_slices((X,y))
return dataset
# Generate 100 dummy elements.
unimportant_dataset = tf.data.Dataset.range(100)
# Convert each dummy element into a dataset of 3 nested elements, and concatenate them.
dataset = unimportant_dataset.flat_map(mapped_fn)

传递给tf.data.Dataset.mapmap_fn应该从调用数据集中获取单个示例的张量,并返回返回数据集的张量。

例如

def map_fn(example_proto):
features, labels = parse_example_proto(example_proto)
# do data augmentation here
return features, labels
dataset = tf.data.TfRecordsDataset(filenames)
dataset = dataset.repeat().shuffle().map(
map_fn, num_parallel_calls=8).prefetch(1)
features, labels = dataset.make_one_shot_iterator().get_next()

最新更新