使用 PyTorch 的 DCGAN 鉴别器精度度量
DCGANs discriminator accuracy metric using PyTorch
我正在使用 PyTorch 实现 DCGAN。
它工作得很好,因为我可以获得合理质量的生成图像,但是现在我想通过使用指标(主要是本指南介绍的指标)来评估 GAN 模型的健康状况 https://machinelearningmastery.com/practical-guide-to-gan-failure-modes/
他们的实现使用 Keras,该 SDK 允许您在编译模型时定义所需的指标,请参阅 https://keras.io/api/models/model/。在这种情况下,鉴别器的准确性,即它成功地将图像识别为真实图像或生成图像的百分比。
使用 PyTorch SDK,我似乎找不到可以帮助我从我的模型中轻松获取此指标的类似功能。
Pytorch 是否提供能够从模型中定义和提取通用指标的功能?
Pure PyTorch 不提供开箱即用的指标,但您自己定义这些指标非常容易。
也没有“从模型中提取指标”这样的东西。指标就是指标,它们测量(在这种情况下是判别器的准确性),它们不是模型固有的。
二进制精度
在您的例子中,您正在寻找二进制精度指标。下面的代码适用于 logits
([=13= 输出的非归一化概率],可能是没有激活的最后 nn.Linear
层)或 probabilities
(最后 nn.Linear
后跟 sigmoid
激活):
import typing
import torch
class BinaryAccuracy:
def __init__(
self,
logits: bool = True,
reduction: typing.Callable[
[
torch.Tensor,
],
torch.Tensor,
] = torch.mean,
):
self.logits = logits
if logits:
self.threshold = 0
else:
self.threshold = 0.5
self.reduction = reduction
def __call__(self, y_pred, y_true):
return self.reduction(((y_pred > self.threshold) == y_true.bool()).float())
用法:
metric = BinaryAccuracy()
target = torch.randint(2, size=(64,))
outputs = torch.randn(size=(64, 1))
print(metric(outputs, target))
PyTorch Lightning 或其他第三方
您也可以使用PyTorch Lightning or other framework on top of PyTorch which defines metrics like accuracy
我正在使用 PyTorch 实现 DCGAN。
它工作得很好,因为我可以获得合理质量的生成图像,但是现在我想通过使用指标(主要是本指南介绍的指标)来评估 GAN 模型的健康状况 https://machinelearningmastery.com/practical-guide-to-gan-failure-modes/
他们的实现使用 Keras,该 SDK 允许您在编译模型时定义所需的指标,请参阅 https://keras.io/api/models/model/。在这种情况下,鉴别器的准确性,即它成功地将图像识别为真实图像或生成图像的百分比。
使用 PyTorch SDK,我似乎找不到可以帮助我从我的模型中轻松获取此指标的类似功能。
Pytorch 是否提供能够从模型中定义和提取通用指标的功能?
Pure PyTorch 不提供开箱即用的指标,但您自己定义这些指标非常容易。
也没有“从模型中提取指标”这样的东西。指标就是指标,它们测量(在这种情况下是判别器的准确性),它们不是模型固有的。
二进制精度
在您的例子中,您正在寻找二进制精度指标。下面的代码适用于 logits
([=13= 输出的非归一化概率],可能是没有激活的最后 nn.Linear
层)或 probabilities
(最后 nn.Linear
后跟 sigmoid
激活):
import typing
import torch
class BinaryAccuracy:
def __init__(
self,
logits: bool = True,
reduction: typing.Callable[
[
torch.Tensor,
],
torch.Tensor,
] = torch.mean,
):
self.logits = logits
if logits:
self.threshold = 0
else:
self.threshold = 0.5
self.reduction = reduction
def __call__(self, y_pred, y_true):
return self.reduction(((y_pred > self.threshold) == y_true.bool()).float())
用法:
metric = BinaryAccuracy()
target = torch.randint(2, size=(64,))
outputs = torch.randn(size=(64, 1))
print(metric(outputs, target))
PyTorch Lightning 或其他第三方
您也可以使用PyTorch Lightning or other framework on top of PyTorch which defines metrics like accuracy