我有这个函数,它适用于单个向量:
def vec_to_board(vector, player, dim, reverse=False):
player_board = np.zeros(dim * dim)
player_pos = np.argwhere(vector == player)
if not reverse:
player_board[mapping[player_pos.T]] = 1
else:
player_board[reverse_mapping[player_pos.T]] = 1
return np.reshape(player_board, [dim, dim])
但是,我希望它适用于一批向量。
到目前为止我尝试过的:
states = jnp.array([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2], [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2]])
batch_size = 1
b_states = vmap(vec_to_board)((states, 1, 4), batch_size)
这行不通。然而,如果我理解正确的话,vmap应该能够处理批量的这种转换吗?
在尝试vmap
此函数时会遇到几个问题:
- 这个函数是根据numpy数组定义的,而不是jax数组。我怎么知道?JAX数组是不可变的,所以像
arr[idx] = 1
这样的东西会引发错误。您需要用等效的JAX操作来替换这些操作(请参阅JAX Sharp Bits:in-place updates(,并确保您的函数使用JAX数组操作而不是numpy数组操作 - 您的函数使用动态形状的数组;例如
player_pos
具有取决于vector == player
中非零条目的数量的形状。您必须根据静态形状的数组来重写函数。在jnp.argwhere
文档字符串中对此进行了一些讨论;例如,如果您事先知道在数组中期望有多少True条目,则可以指定size
来实现这一点
祝你好运!