PyTorch最有效的雅可比/黑森计算



我正在寻找通过 Pytorch 获取函数的雅可比函数的最有效方法,到目前为止,我提出了以下解决方案:

# Setup
def func(X):
    return torch.stack((X.pow(2).sum(1),
                        X.pow(3).sum(1),
                        X.pow(4).sum(1)),1)  
X = Variable(torch.ones(1,int(1e5))*2.00094, requires_grad=True).cuda()
# Solution 1:
t = time()
Y = func(X)
J = torch.zeros(3, int(1e5))
for i in range(3):
    J[i] = grad(Y[0][i], X, create_graph=True, retain_graph=True, allow_unused=True)[0]
print(time()-t)
>>> Output: 0.002 s
# Solution 2:
def Jacobian(f,X):
    X_batch = Variable(X.repeat(3,1), requires_grad=True)
    f(X_batch).backward(torch.eye(3).cuda(), retain_graph=True)
    return X_batch.grad
t = time()
J2 = Jacobian(func,X)
print(time()-t)
>>> Output: 0.001 s

由于在第一种解决方案中使用循环与在第二种解决方案中使用循环之间似乎没有太大区别,因此我想问一下是否还有一种更快的方法来计算 pytorch 中的雅可比矩阵。

我的另一个问题是关于计算黑森的最有效方法是什么。

最后,有谁知道这样的事情是否可以在TensorFlow中更容易或更高效地完成?

>functorch可以进一步加快计算速度。 例如,此代码来自批处理雅可比计算的functorch文档(Hessian 也可以(:

batch_size = 64
Din = 31
Dout = 33
weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")
bias = torch.randn(Dout)
def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()
x = torch.randn(batch_size, Din)
compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)

最有效的方法可能是使用 PyTorch 自己的内置函数:

torch.autograd.functional.jacobian(func, x)
torch.autograd.functional.hessian(func, x)

我有一个类似的问题,我通过手动定义雅可比矩阵(手动计算导数(来解决。对于我的问题,这是可行的,但我可以想象情况并非总是如此。与第二种解决方案相比,计算时间加快了我的机器 (cpu( 上的一些因素。

# Solution 2
def Jacobian(f,X):
    X_batch = Variable(X.repeat(3,1), requires_grad=True)
    f(X_batch).backward(torch.eye(3).cuda(),  retain_graph=True)
    return X_batch.grad
%timeit Jacobian(func,X)
11.7 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Solution 3
def J_func(X):
    return torch.stack(( 
                 2*X,
                 3*X.pow(2),
                 4*X.pow(3)
                  ),1)
%timeit J_func(X)
539 µs ± 24.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

相关内容

  • 没有找到相关文章

最新更新