如何包装numpy函数使其与jax.numpy一起工作



我有一些Jax代码需要使用自动微分,在部分代码中,我想从用NumPy编写的库中调用一个函数。当我现在尝试这个时,我得到

The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[4,22324])>with<JVPTrace(level=4/1)> with
primal = Traced<ShapedArray(float32[4,22324])>with<DynamicJaxprTrace(level=0/1)>
tangent = Traced<ShapedArray(float32[4,22324])>with<JaxprTrace(level=3/1)> with
pval = (ShapedArray(float32[4,22324]), None)
recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7fa89e8ffa80>, in_tracers=(Traced<ShapedArray(float32[22324,4]):JaxprTrace(level=3/1)>,), out_tracer_refs=[<weakref at 0x7fa89beb15e0; to 'JaxprTracer' at 0x7fa893b5ab80>], out_avals=[ShapedArray(float32[4,22324])], primitive=transpose, params={'permutation': (1, 0)}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7fa89e9312b0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

这是有意义的,因为NumPy是不可自微分的。

有没有办法包装用NumPy编写的函数,使其映射到等价的jax.numpy

实现这一点的一种糟糕方法是修改库,使其调用jax.numpy而不是numpy,但这会使适用性更加困难。

谢谢!

编辑2023年1月:JAX现在添加了许多回调方法来完成这类任务;看见https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html

不,通常情况下,对于在NumPy数组上操作的函数,无法将其自动转换为在JAX中实现的等效函数。这是因为JAX不是NumPy的API的100%忠实的一对一实现;相反,您应该将jax.numpy看作是围绕JAX提供的功能的类似NumPy的的包装器。

作为一个简单的例子,考虑以下代码:

np.array(['A', 'B', 'C'])

这没有等价的JAX-,因为JAX-neneneba XLA不支持字符串数组。

如果您想在代码中使用类似autodiff的JAX转换,那么在JAX中重写代码并没有任何捷径。只要不使用在阵列上运行的外部库(如SciPy、Scikit-Learn等(,用import jax.numpy as jnp替换import numpy as np可能会有很长的路要走。

此外,在进行此类替换时,请记住JAX的Sharp Bits,在这些地方jax.numpy的行为可能与原始NumPy代码不同。

import numpy as np
import jax.numpy as jnp
import jax
import inspect
import re
def function_np(x):
return np.maximum(0, x)
function_np_str = inspect.getsource(function_np) # getting the code as a string
function_jnp_str = re.sub(r"np", "jnp", function_code) #replacing all the 'np' with 'jnp'
# The line below creates a function defined in the 'jnp_function_str' string - which uses jnp instead of numpy
exec(jnp_activation_str)  

现在您有一个名为"functionjnp"的新函数,它使用jnp库。

它有点像拐杖,但它适用于简单的功能。

最新更新