使用jax.lax.scan重写for循环



我很难理解JAX文档。有人能告诉我如何用jax.lax.scan重写这样的简单代码吗?

numbers = numpy.array( [ [3.0, 14.0], [15.0, -7.0], [16.0, -11.0] ])
evenNumbers = 0
for row in numbers:
for n in row:
if n % 2 == 0:
evenNumbers += 1

假设一个解决方案应该演示概念,而不是优化所示的示例,则要jax.lax.scan-ned的函数必须与预期签名匹配,并且任何动态条件都必须用jax.lax.cond替换。下面的代码是我能想到的最接近原始代码的,但请注意,我不是jaxpert。

import jax
import jax.numpy as jnp
def f(carry, row):
even = 0
for n in row:
even += jax.lax.cond(n % 2 == 0, lambda: 1, lambda: 0)
return carry + even, even
numbers = jnp.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])
jax.lax.scan(f, 0, numbers)

输出

(DeviceArray(2, dtype=int32, weak_type=True),
DeviceArray([1, 0, 1], dtype=int32, weak_type=True))

最新更新