PyTorch - 逐元素签名 min/max?

PyTorch - Element-wise signed min/max?

我可能遗漏了一些明显的东西,但我找不到计算方法。

给定两个张量,我想保留每个张量中的最小元素以及符号。

我考虑过

sign_x = torch.sign(x)
sign_y = torch.sign(y)
min = torch.min(torch.abs(x), torch.abs(y))

为了最终将符号与获得的最小值相乘,但是我没有办法将正确的符号乘以保留的每个元素,必须选择两个张量之一。

这是一种方法。将 torch.sign(x)torch.sign(y) 乘以表示 xymin 计算结果的布尔张量。然后取两个结果张量的逻辑或 (|) 将它们组合起来,并将其乘以 min 计算。

mins = torch.min(torch.abs(x), torch.abs(y))

xSigns = (mins == torch.abs(x)) * torch.sign(x)
ySigns = (mins == torch.abs(y)) * torch.sign(y)
finalSigns = xSigns.int() | ySigns.int()

result = mins * finalSigns

如果xy对于某个元素有相同的绝对值,在上面的代码中x符号优先。要使 y 优先,请交换顺序并使用 finalSigns = ySigns.int() | xSigns.int()