PyTorch:计算接近(+/- 公差)参考张量值的张量值的数量

PyTorch: count the number of tensor values that are near (+/- a tolerance) the values of a reference tensor

我有 2 个具有多个维度的任意形状的张量。

我想计算 predicted_tensor 中接近目标张量值的值的数量。

使用 for 循环应该是这样的:

targets = torch.flatten(target_tensor)
predicted = torch.flatten(predicted_tensor)

correct_values = 0
tolerance = 0.1

for i, prediction in enumerate(predicted):
    target = targets[i]
    if (target - tolerance < prediction < target + tolerance):
        correct_values =+ 1

但是,for 循环对于性能来说并不是一个好主意。

我正在寻找矢量化解决方案。我试过了:

torch.sum(target - tolerance < prediction < target + tolerance)

但是我得到了:

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

在 Julia 中,它只是添加一个点以精确表示它是元素明智的。

关于如何使用带有短向量化解决方案的 PyTorch 实现它的任何想法?

谢谢

我想你在找 torch.isclose:

correct_values = torch.isclose(prediction, target, atol=tolerance, rtol=0).sum()