我定义了一个字典A
,并希望找到给定一批值a
的键:
def dictionary(r):
return dict(enumerate(r))
def get_key(val, my_dict):
for key, value in my_dict.items():
if np.array_equal(val,value):
return key
# dictionary
A = jnp.array([[0, 0],[1,1],[2,2],[3,3]])
A = dictionary(A)
a = jnp.array([[[1, 1],[2, 2], [3,3]],[[0, 0],[3, 3], [2,2]]])
keys = jax.vmap(jax.vmap(get_key, in_axes=(0,None)), in_axes=(0,None))(a, A)
预期的输出应该是:keys = [[1,2,3],[0,3,2]]
为什么我得到None
作为输出?
JAX通过跟踪函数来进行类似vmap
的转换,这意味着它们用值的抽象表示替换值,以提取函数中编码的操作序列(参见如何在JAX中思考以获得此概念的良好介绍)。
这意味着要正确使用vmap
,函数只能使用JAX方法,不能使用numpy方法,所以使用np.array_equal
打破了抽象。
不幸的是,它没有任何真正的替代品,因为没有机制可以在具体的Python字典中查找抽象的JAX值。如果您想对JAX值进行字典查找,则应该避免转换,而只使用Python循环:
keys = jnp.array([[get_key(x, A) for x in row] for row in a])
另一方面,我怀疑这更像是XY问题;您的目标不是在ajax转换中查找字典值,而是解决一些问题。也许你应该问一个关于如何解决问题的问题,而不是如何用你已经尝试过的解决方案来回避问题。
但是,如果您不愿意直接使用字典,那么与JAX兼容的另一种get_key
实现可能看起来像这样:
def get_key(val, my_dict):
keys = jnp.array(list(my_dict.keys()))
values = jnp.array(list(my_dict.values()))
return keys[jnp.where((values == val).all(-1), size=1)]