我是自动差分编程的新手,所以这可能是一个天真的问题。下面是我试图解决的问题的简化版本。
我有两个输入数组——大小为N
的向量A
和形状为(N, M)
的矩阵B
,以及大小为M
的参数向量theta
。我定义了一个新的数组CCD_ 7来获得一个大小为CCD_。然后,我获得位于C
的上四分位数和下四分位数的元素的索引,并使用它们创建新的数组A_low(theta) = A[lower quartile indices of C]
和A_high(theta) = A[upper quartile indices of C]
。很明显,这两者确实依赖于theta
,但是否有可能将A_low
和A_high
与theta
区分开来?
到目前为止,我的尝试似乎表明没有——我使用了autograd、JAX和tensorflow的python库,但它们都返回了零的梯度。(到目前为止,我尝试过的方法包括使用argsort或使用tf.top_k
提取相关的子数组。(
我寻求的帮助要么是证明导数没有定义(或无法解析计算(,要么是如果它确实存在,就如何估计它提出建议。我的最终目标是最小化f(A_low, A_high)
和theta
的一些函数。
这是我根据您的描述编写的JAX计算:
import numpy as np
import jax.numpy as jnp
import jax
N = 10
M = 20
rng = np.random.default_rng(0)
A = jnp.array(rng.random((N,)))
B = jnp.array(rng.random((N, M)))
theta = jnp.array(rng.random(M))
def f(A, B, theta, k=3):
C = B @ theta
_, i_upper = lax.top_k(C, k)
_, i_lower = lax.top_k(-C, k)
return A[i_lower], A[i_upper]
x, y = f(A, B, theta)
dx_dtheta, dy_dtheta = jax.jacobian(f, argnums=2)(A, B, theta)
导数都是零,我相信这是正确的,因为输出值的变化不取决于theta
值的变化。
但是,你可能会问,这怎么可能呢?毕竟,theta
会进入计算,如果为theta
输入不同的值,则会得到不同的输出。梯度怎么可能是零?
不过,您必须记住的是,差异化并不能衡量输入是否影响输出。它测量在输入变化极小的情况下输出的变化。
让我们用一个稍微简单一点的函数作为例子:
import jax
import jax.numpy as jnp
A = jnp.array([1.0, 2.0, 3.0])
theta = jnp.array([5.0, 1.0, 3.0])
def f(A, theta):
return A[jnp.argmax(theta)]
x = f(A, theta)
dx_dtheta = jax.grad(f, argnums=1)(A, theta)
这里,由于与上述相同的原因,将f
相对于theta
微分的结果全部为零。为什么?如果对theta
进行无穷小的更改,通常不会影响theta
的排序顺序。因此,从A
中选择的条目在θ发生微小变化的情况下不会发生变化,因此相对于θ的导数为零。
现在,你可能会争辩说,在某些情况下情况并非如此:例如,如果θ中的两个值非常接近,那么肯定会对其中一个值进行微小的扰动,从而改变它们各自的秩。这是真的,但这个过程产生的梯度是未定义的(输出的变化相对于输入的变化是不平滑的(。好消息是,这种不连续性是片面的:如果你在另一个方向上扰动,则秩没有变化,梯度是明确定义的。为了避免未定义的梯度,大多数autodiff系统将隐式地使用这种更安全的导数定义来进行基于秩的计算。
结果是,当你无限小地扰动输入时,输出的值不会改变,这是梯度为零的另一种说法。这并不是autodiff的失败——这是给定autodiff所建立的微分定义的正确梯度。此外,如果你试图在这些不连续处更改为导数的不同定义,你所希望的最好结果是未定义的输出,因此可以说导致零的定义更有用、更正确。