使用PyTorch的dcgan鉴别器精度度量



我正在使用PyTorch实现dcgan。

它工作得很好,因为我可以得到合理质量的生成图像,但是现在我想通过使用度量来评估GAN模型的健康状况,主要是本指南介绍的那些https://machinelearningmastery.com/practical-guide-to-gan-failure-modes/

它们的实现使用Keras,该SDK允许您在编译模型时定义您想要的指标,请参阅https://keras.io/api/models/model/。在这种情况下,鉴别器的准确性,即成功识别图像为真实图像或生成图像的百分比。

使用PyTorch SDK,我似乎找不到类似的功能来帮助我轻松地从模型中获取此指标。

Pytorch是否提供了能够从模型中定义和提取公共指标的功能?

Pure PyTorch提供开箱即用的指标,但自己定义这些指标非常容易。

也没有"从模型中提取指标"这样的事情。度量就是度量,它们度量(在这种情况下是判别器的准确性),它们不是模型固有的。

二进制精度

在您的例子中,您正在寻找二进制精度度量。下面的代码适用于logits(由discriminator输出的非归一化概率,可能是没有激活的最后一层nn.Linear)或probabilities(最后一层nn.Linear之后是sigmoid激活):

import typing
import torch

class BinaryAccuracy:
def __init__(
self,
logits: bool = True,
reduction: typing.Callable[
[
torch.Tensor,
],
torch.Tensor,
] = torch.mean,
):
self.logits = logits
if logits:
self.threshold = 0
else:
self.threshold = 0.5
self.reduction = reduction
def __call__(self, y_pred, y_true):
return self.reduction(((y_pred > self.threshold) == y_true.bool()).float())

用法:

metric = BinaryAccuracy()
target = torch.randint(2, size=(64,))
outputs = torch.randn(size=(64, 1))
print(metric(outputs, target))

PyTorch Lightning或其他第三方

你也可以在PyTorch之上使用PyTorch闪电或其他框架,这些框架定义了精度等指标

最新更新