如何检查两个 Torch 张量或矩阵是否相等?

How to check if two Torch tensors or matrices are equal?

我需要一个 Torch 命令来检查两个张量是否具有相同的内容,并且 returns 如果它们具有相同的内容则为 TRUE。

例如:

local tens_a = torch.Tensor({9,8,7,6});
local tens_b = torch.Tensor({9,8,7,6});

if (tens_a EQUIVALENCE_COMMAND tens_b) then ... end

我应该在此脚本中使用什么来代替 EQUIVALENCE_COMMAND

我只是尝试使用 ==,但它不起作用。

torch.eq(a, b)

eq() 实现 == 运算符将 a 中的每个元素与 b(如果 b 是一个值)或 a 中的每个元素与它在 b 中的相应元素(如果 b 是张量)。


@deltheil 的替代方案:

torch.all(tens_a.eq(tens_b))

如果你想忽略浮点数常见的小精度差异,试试这个

torch.all(torch.lt(torch.abs(torch.add(tens_a, -tens_b)), 1e-12))

下面的解决方案对我有用:

torch.equal(tensorA, tensorB)

来自the documentation

True if two tensors have the same size and elements, False otherwise.

要比较张量,您可以在元素方面进行操作:

torch.eq 是元素明智的:

torch.eq(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
tensor([[True, False], [False, True]])

torch.equal 对整个张量正好:

torch.equal(torch.tensor([[1., 2.], [3, 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
# False
torch.equal(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.], [3., 4.]]))
# True

但是你可能会迷路,因为在某些时候你会想忽略一些小的差异。例如,浮点数 1.01.0000000001 非常接近,您可能认为它们是相等的。对于这种比较,您有 torch.allclose.

torch.allclose(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.000000001], [3., 4.]]))
# True

在某些时候,与元素的全部数量相比,检查元素方面有多少元素相等可能很重要。如果你有两个张量 dt1dt2 你会得到 dt1 的元素数为 dt1.nelement()

用这个公式你可以得到百分比:

print(torch.sum(torch.eq(dt1, dt2)).item()/dt1.nelement())

您可以将两个张量转换为 numpy 数组:

local tens_a = torch.Tensor((9,8,7,6));
local tens_b = torch.Tensor((9,8,7,6));

a=tens_a.numpy()
b=tens_b.numpy()

然后是

np.sum(a==b)
4

会让你很好地了解他们是多么平等。