如何正确使用tf.scatter_update进行N维更新



我一直在尝试使用tf.scatter_update进行N维更新(在tf.scatter_nd由于形状不匹配而失败之后)。通常,这些将用于创建掩码,用于过滤传入张量的切片。

假设输入张量A具有形状(batch,i,j,k(深度))。我只对修改所有k全部bI,j值感兴趣。

MWE:

import tensorflow as tf
b, i, j, k = 64, 128, 128, 256
A = tf.random_uniform(shape=(64, 128, 128, 256), dtype='int32', seed=1234) # Batch, i, j, k
mask = tf.ones(shape=(b,i,j,k), dtype='int32')
# Placeholder for more complicated index Tensor. GPU Ignores OOB indices.
indices = tf.random_uniform(shape=(b, 25, k, 2), dtype='int32', seed=4321) # Index number, k, i-j coord.
updates = tf.random_uniform(shape=(i, j, k), dtype='int32', seed=1111)
scatter = tf.scatter_update(mask, indices, updates)
with tf.Session() as sess:
sess.run(scatter)

结果:

属性错误:"Tensor"对象没有属性"_lazy_read"

我已经通过Python Script、Python Notebook和使用/不使用Eager Execution进行了尝试。运气不好。

输入绝对必须是一个张量,因为我们的想法是在一系列操作的中途稀疏地更新这个张量。

关于tf.scatter_update,我是否缺少一些基本的东西?tf.scatter_nd会更合适吗?如果是,有什么区别,特别是更新的索引。

当引用tf.scatter_update文档时,示例是基本的,并使用常量;我很难将其应用于更现实的情况和问题。

Tensorflow的文档通过将ref参数输入为tf来使用所有scatter操作(如scatter_nd_add等)。变量:

ref:一个可变张量。必须是以下类型之一:blabla。可变张量应该来自"变量"节点

我也遇到了同样的问题,当它用于ref的tf变量时,效果很好。我想,所有其他论点都可以保持原样,但我没有彻底调查。

最新更新