假设我有一个函数(为简单起见,是两个序列之间的协方差,尽管这个问题更一般):
def cov(x, y):
return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))
现在我有了一个"数据框架"D
(一个二维数组,它的列是我的级数)我想对cov
进行矢量化这样,对矢量化函数的应用产生协方差矩阵。现在,有一种很明显的方法:
cov1 = jax.vmap(cov, in_axes=(None, 1))
cov2 = jax.vmap(cov1, in_axes=(1, None))
但是看起来有点笨拙。有"标准"吗?怎么做呢?
如果您想用vmap
表示与嵌套for
循环等效的逻辑,那么是的,它需要嵌套vmap。我认为您所写的可能是像这样的操作所能得到的最规范的,尽管如果使用decorator可能会更清楚一些:
from functools import partial
@partial(jax.vmap, in_axes=(1, None))
@partial(jax.vmap, in_axes=(None, 1))
def cov(x, y):
return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))
对于这个特殊的函数,请注意,如果您愿意,可以使用单个点积表示相同的东西:
result = jnp.dot((x - x.mean(0)).T, (y - y.mean(0)))