我一直在尝试使用tf.scatter_update
进行N维更新(在tf.scatter_nd
由于形状不匹配而失败之后)。通常,这些将用于创建掩码,用于过滤传入张量的切片。
假设输入张量A具有形状(batch,i,j,k(深度))。我只对修改所有k和全部b的I,j值感兴趣。
MWE:
import tensorflow as tf
b, i, j, k = 64, 128, 128, 256
A = tf.random_uniform(shape=(64, 128, 128, 256), dtype='int32', seed=1234) # Batch, i, j, k
mask = tf.ones(shape=(b,i,j,k), dtype='int32')
# Placeholder for more complicated index Tensor. GPU Ignores OOB indices.
indices = tf.random_uniform(shape=(b, 25, k, 2), dtype='int32', seed=4321) # Index number, k, i-j coord.
updates = tf.random_uniform(shape=(i, j, k), dtype='int32', seed=1111)
scatter = tf.scatter_update(mask, indices, updates)
with tf.Session() as sess:
sess.run(scatter)
结果:
属性错误:"Tensor"对象没有属性"_lazy_read"我已经通过Python Script、Python Notebook和使用/不使用Eager Execution进行了尝试。运气不好。
输入绝对必须是一个张量,因为我们的想法是在一系列操作的中途稀疏地更新这个张量。
关于tf.scatter_update
,我是否缺少一些基本的东西?tf.scatter_nd
会更合适吗?如果是,有什么区别,特别是更新的索引。
当引用tf.scatter_update文档时,示例是基本的,并使用常量;我很难将其应用于更现实的情况和问题。
Tensorflow的文档通过将ref参数输入为tf来使用所有scatter操作(如scatter_nd_add等)。变量:
ref:一个可变张量。必须是以下类型之一:blabla。可变张量应该来自"变量"节点。
我也遇到了同样的问题,当它用于ref的tf变量时,效果很好。我想,所有其他论点都可以保持原样,但我没有彻底调查。