给定一批样本找到x,y位置- python, jax



我想找到在尺寸为[batch, 4,4] =[2,4,4]的下一批抽样数组中'1'的位置。

import jax
import jax.numpy as jnp
a = jnp.array([[[0., 0., 0., 1.],
[0., 0., 0., 0.],
[0., 1., 0., 1.],
[0., 0., 1., 1.]],

[[1., 0., 1., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 1., 0., 1.]]])

我尝试通过批次的维度(与vmap),并使用jax函数找到坐标与

b = jax.vmap(jnp.where)(a)
print('b', b)

但是我得到一个错误,我不知道如何修复:

The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
This Tracer was created on line /home/imi/Desktop/Backflow/backflow/src/debug.py:17 (<module>)

我希望得到以下输出:

b = [[[0,3], [2,1],[2,3],[3,2],[3,3]],
[[0,0],[0,2],[1,0],[3,1],[3,3]]

[x,y]坐标的第一行对应于第一批中有'1'的位置,第二行对应于第二批中有'1'的位置。

vmap这样的JAX转换要求数组的大小是静态的,所以没有办法精确地执行您所想到的计算(因为1条目的数量,以及输出数组的大小,是依赖于数据的)。

但是如果你先验地知道每批有五个条目,你可以这样做:

from functools import partial
indices = jax.vmap(partial(jnp.where, size=5))(a)
print(jnp.stack(indices, axis=2))
[[[0 3]
[2 1]
[2 3]
[3 2]
[3 3]]
[[0 0]
[0 2]
[1 0]
[3 1]
[3 3]]]

如果您不知道先验有多少1条目,那么您有几个选项:一个是避免JAX转换并在每个批次上调用未转换的jnp.where:

result = [jnp.column_stack(jnp.where(b)) for b in a]
print(result)
[DeviceArray([[0, 3],
[2, 1],
[2, 3],
[3, 2],
[3, 3]], dtype=int32), DeviceArray([[0, 0],
[0, 2],
[1, 0],
[3, 1],
[3, 3]], dtype=int32)]

请注意,对于这种情况,通常不可能将结果存储在单个数组中,因为每个批处理中可能有不同数量的1条目,并且JAX不支持不规则数组。

另一个选项是将size设置为某个最大值,并输出填充结果:

max_size = a[0].size  # size of slice is the upper bound
fill_value = a[0].shape  # fill with out-of-bound indices
indices = jax.vmap(partial(jnp.where, size=max_size, fill_value=fill_value))(a)
print(jnp.stack(indices, axis=2))
[[[0 3]
[2 1]
[2 3]
[3 2]
[3 3]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]]
[[0 0]
[0 2]
[1 0]
[3 1]
[3 3]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]]]

有了填充的结果,您就可以编写剩余的代码来预测这些填充的值。

最新更新