jax的矢量化指南

  • 本文关键字:矢量化 jax python jax
  • 更新时间 :
  • 英文 :


假设我有一个函数(为简单起见,是两个序列之间的协方差,尽管这个问题更一般):

def cov(x, y):
return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))

现在我有了一个"数据框架"D(一个二维数组,它的列是我的级数)我想对cov进行矢量化这样,对矢量化函数的应用产生协方差矩阵。现在,有一种很明显的方法:

cov1 = jax.vmap(cov, in_axes=(None, 1))
cov2 = jax.vmap(cov1, in_axes=(1, None))

但是看起来有点笨拙。有"标准"吗?怎么做呢?

如果您想用vmap表示与嵌套for循环等效的逻辑,那么是的,它需要嵌套vmap。我认为您所写的可能是像这样的操作所能得到的最规范的,尽管如果使用decorator可能会更清楚一些:

from functools import partial
@partial(jax.vmap, in_axes=(1, None))
@partial(jax.vmap, in_axes=(None, 1))
def cov(x, y):
return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))

对于这个特殊的函数,请注意,如果您愿意,可以使用单个点积表示相同的东西:

result = jnp.dot((x - x.mean(0)).T, (y - y.mean(0)))

相关内容

  • 没有找到相关文章

最新更新