Tensorflow: NotImplementError: reduce() 转换目前不支持嵌套数据集作为输入



在Tensorflow 1.12中引入了tf.data.Dataset.reduce()tf.data.Dataset.window()方法。

从发行说明:

  • "新的tf.data.Dataset.reduce()API允许用户使用用户提供的reduce函数将有限数据集简化为单个元素。

  • "新的tf.data.Dataset.window()API允许用户创建输入数据集的有限窗口;当与tf.data.Dataset.reduce()API结合使用时,这允许用户实现自定义批处理。

但是如何使用这些功能?

def reduce_func(old_state, input_element):
pdb.set_trace()
return new_state
dataset = tf.data.Dataset.from_generator(frame_generator, (tf.string, tf.string))
dataset = dataset.window(2).reduce(np.int64(0), reduce_func)

这给出了一个 NotImplementError:

NotImplementError:reduce(( 转换当前
不支持嵌套数据集作为输入。

我使用张量流版本'1.12.0-rc1'

编辑: 从 https://www.tensorflow.org/versions/r1.12/api_docs/python/tf/contrib/data/sliding_window_batch

此函数已弃用。它将在未来的版本中删除。更新说明:使用tf.data.Dataset.window(size=window_size, shift=window_shift, stride=window_stride).flat_map(lambda x: x.batch(window.size))

但是,如果数据集是用

dataset = tf.data.Dataset.from_generator(frame_generator, (tf.string, tf.string))

因此,数据集中的每个项目都包含两个元素。然后有一个类型错误:

类型错误: (( 需要 1 个位置参数,但给出了 2

编辑: 使用 zip 解决

dataset = tf.data.Dataset.from_generator(frame_generator, (tf.string, tf.string))
window_size = 2
dataset = dataset.window(window_size).flat_map(lambda x,y: tf.data.Dataset.zip((x,y)).batch(window_size))
dataset = dataset.map(self.parse_function)

这里的问题是,当您调用.reduce()函数时,它会应用于外部数据集。 调用.window后的数据集现在是一个数据集,其中每个元素本身就是一个数据集。 您要做的是对窗口创建的那些单个数据集使用reduce。 您可以使用 map 执行此操作,然后将每个内部数据集映射到 reduce。

def reduce_func(old_state, input_element):
pdb.set_trace()
return new_state
dataset = tf.data.Dataset.from_generator(frame_generator, (tf.string, tf.string))
dataset = dataset.window(2).map(lambda ds: ds.reduce(np.int64(0), reduce_func))

最新更新