如何在jax中对特定功能进行vmap

  • 本文关键字:功能 vmap jax python jax
  • 更新时间 :
  • 英文 :


我有这个函数,它适用于单个向量:

def vec_to_board(vector, player, dim, reverse=False):
player_board = np.zeros(dim * dim)
player_pos = np.argwhere(vector == player)
if not reverse:
player_board[mapping[player_pos.T]] = 1
else:
player_board[reverse_mapping[player_pos.T]] = 1
return np.reshape(player_board, [dim, dim])

但是,我希望它适用于一批向量。

到目前为止我尝试过的:

states = jnp.array([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2], [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2]])
batch_size = 1
b_states = vmap(vec_to_board)((states, 1, 4), batch_size)

这行不通。然而,如果我理解正确的话,vmap应该能够处理批量的这种转换吗?

在尝试vmap此函数时会遇到几个问题:

  1. 这个函数是根据numpy数组定义的,而不是jax数组。我怎么知道?JAX数组是不可变的,所以像arr[idx] = 1这样的东西会引发错误。您需要用等效的JAX操作来替换这些操作(请参阅JAX Sharp Bits:in-place updates(,并确保您的函数使用JAX数组操作而不是numpy数组操作
  2. 您的函数使用动态形状的数组;例如player_pos具有取决于vector == player中非零条目的数量的形状。您必须根据静态形状的数组来重写函数。在jnp.argwhere文档字符串中对此进行了一些讨论;例如,如果您事先知道在数组中期望有多少True条目,则可以指定size来实现这一点

祝你好运!

最新更新