PyTorch 如何将张量每一行的最小值设置为零?
PyTorch how to set the minimum value along each row of a tensor to zero?
我们不知道输入张量的形状,我们不应该使用任何循环。仅 归约和索引操作 。我们如何将每行的最小值设置为零?
例如:
输入:
x = torch.tensor([[
[10, 20, 30]
[2, 5, 1]
]])
输出:
torch.tensor([
[0, 20, 30],
[2, 5, 0]
])
想不通也找不到相关问题。我卡住了。
这样做的一种方法是计算逐行最小值 x.min(dim=-1)
,获取最小值 x.min(dim=-1).values
(索引在多个最小元素的情况下不起作用),获取指示位置的掩码非最小元素使用比较并乘以它:
axis = -1 # Minimum iterating over the last dimension
min_values = x.min(dim=axis).values # Get minimal values
min_values_shape_corrected = min_values.unsqueeze(axis) # Reshape the minimal values so we can compare it with `x`
mask = (x != min_values_shape_corrected) # Get the mask of non-minimal elements
result = x * mask # Multiplying by a boolean mask leaves only True elements and sets False ones to 0
或在一行中
x * (x != x.min(dim=axis).values.unsqueeze(axis))
我们不知道输入张量的形状,我们不应该使用任何循环。仅 归约和索引操作 。我们如何将每行的最小值设置为零?
例如:
输入:
x = torch.tensor([[
[10, 20, 30]
[2, 5, 1]
]])
输出:
torch.tensor([
[0, 20, 30],
[2, 5, 0]
])
想不通也找不到相关问题。我卡住了。
这样做的一种方法是计算逐行最小值 x.min(dim=-1)
,获取最小值 x.min(dim=-1).values
(索引在多个最小元素的情况下不起作用),获取指示位置的掩码非最小元素使用比较并乘以它:
axis = -1 # Minimum iterating over the last dimension
min_values = x.min(dim=axis).values # Get minimal values
min_values_shape_corrected = min_values.unsqueeze(axis) # Reshape the minimal values so we can compare it with `x`
mask = (x != min_values_shape_corrected) # Get the mask of non-minimal elements
result = x * mask # Multiplying by a boolean mask leaves only True elements and sets False ones to 0
或在一行中
x * (x != x.min(dim=axis).values.unsqueeze(axis))