一次更新二维(或多维)Jax数组的多个索引的正确方法是什么?



一次更新2D(或多维)Jax数组的多个索引的正确方法是什么?

这是我之前关于批量更新1D Jax数组的后续问题,目标是避免在训练期间创建数百万个数组。

I have try:

x = jnp.zeros((3,3))
# Update 1 index at a time
x = x.at[2, 2].set(1) # or x = x.at[(2, 2)].set(1)
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 1.]]
# Nice, it works.
# but how about 2 indexes at the same time?
x = jnp.zeros((3,3))
x = x.at[(1, 0), (0, 1) ].set([1, 3])
print(x)
[[0. 3. 0.]
[1. 0. 0.]
[0. 0. 0.]]
It works again, but when I tried to update 3 or more indexes,
x = x.at[(1, 0), (0, 1), (1,1) ].set([1, 3, 6])
print(x)
IndexError: Too many indices for array: 3 non-None/Ellipsis indices for dim 2.

我花了一些时间浏览Jax的文档,但我找不到最好的方法。任何帮助吗?

.at中给出的值是行和列,而不是行/列对。在引用dim 2的错误消息中暗示了这一点(dim 0是行,dim 1是列,没有dim 2)。这应该给出期望的行为

x = x.at[(1, 0, 1), (0, 1, 1) ].set([1, 3, 6])
[[0. 3. 0.]
[1. 6. 0.]
[0. 0. 0.]]

相关内容

  • 没有找到相关文章

最新更新