PyTorch使用外部库保留梯度



我有一个GAN,它返回一个预测的torch.tensor。为了指导这个网络,我有一个损失函数,它是二进制交叉熵损失(BCELoss(和Wasserstein距离的总和。然而,为了计算Wasserstein距离,我使用了SciPy库中的scipy.stats.wasserstein_distance函数。如您所知,此函数需要两个NumPy数组作为输入。因此,为了使用这个函数,我将我的预测张量和地面实况张量转换为NumPy阵列,如下所示

pred_np = pred_tensor.detach().cpu().clone().numpy().ravel()
target_np = target_tensor.detach().cpu().clone().numpy().ravel()
W_loss = wasserstein_distance(pred_np, target_np)

然后,通过将W_lossBCELoss相加来获得总损耗。我现在展示这一部分,因为这有点多余,与我的问题无关。

我担心的是我正在分离梯度,所以我认为在优化和更新模型参数时,它不会考虑W_loss。我有点新手,所以我希望我的问题很清楚,并感谢你提前给出答案。

将一个而非的对象添加为一个需要_grad的张量,本质上就是添加一个常量。常数的导数为零,所以这个附加项对网络的权重没有任何作用。

tl;博士:你需要在pytorch中重写损失计算(或者只是找到一个现有的实现,互联网上有很多(。

最新更新