Numpy 等效于 "tf.tensor_scatter_nd_add" 方法



问题是在标题真的,我正在寻找在scipy/numpy/等方法。(不是TensorFlow),它封装了tf中描述的行为。tensor_scatter_nd_add,但对Numpy数组而不是张量。

我遇到了scipy. nimage .sum方法,但是无法让它复制我下面给出的例子。

无论您认为哪种方法合适,都必须能够重现TF文档中提供的rank-3示例:

    indices = tf.constant([[0], [2]])
    updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                            [7, 7, 7, 7], [8, 8, 8, 8]],
                           [[5, 5, 5, 5], [6, 6, 6, 6],
                            [7, 7, 7, 7], [8, 8, 8, 8]]])
    tensor = tf.ones([4, 4, 4],dtype=tf.int32)
    updated = tf.tensor_scatter_nd_add(tensor, indices, updates)
    print(updated)
希望有人以前解决过类似的问题,可以在这里提供帮助-提前感谢!

我可以确认以下函数为我捕获了所需的行为:

    def scatter_nd_add_numpy(target, indices, updates):
        indices = tuple(indices.reshape(-1, indices.shape[-1]).T)
        np.add.at(target, indices, updates)
        return target

感谢Remy在这个stackoverflow线程上的回答。

相关内容

  • 没有找到相关文章

最新更新