用火炬中的最后一个非零值替换所有零
Replace all zeros with last non-zero value in torch
有什么有效的方法可以用 torch 中的最后一个非零值替换张量中的所有零吗?
例如,如果我有张量:
tensor([[1, 0, 0, 4, 0, 5, 0, 0],
[0, 3, 0, 6, 0, 0, 8, 0]])
输出应该是:
tensor([[1, 1, 1, 4, 4, 5, 5, 5],
[0, 3, 3, 6, 6, 6, 8, 8]])
我目前有以下代码:
def replace_zeros_with_prev_nonzero(tensor):
output = tensor.clone()
for i in range(len(output)):
prev_value = 0
for j in range(len(tensor[i])):
if tensor[i,j] == 0:
output[i,j] = prev_value
else:
prev_value = tensor[i,j].item()
return output
但感觉有点笨拙,我相信必须有更好的方法来做到这一点。那么是否可以用更少的行来编写它,或者更好的是在不将张量视为数组的情况下并行化操作?
您可以通过对第一个维度进行矢量化来移除其中一个循环。
def replace_zeros_with_prev_nonzero(tensor):
output = tensor.clone()
for i in range(1, tensor.shape[1]):
mask = tensor[:, i] == 0
output[mask, i] = output[mask, i-1]
return output
output[mask, i] = output[mask, i-1]
将 0 替换为之前的值(如果最初为 0,除了第 0 个索引之外,它本身将被替换)。
有什么有效的方法可以用 torch 中的最后一个非零值替换张量中的所有零吗?
例如,如果我有张量:
tensor([[1, 0, 0, 4, 0, 5, 0, 0],
[0, 3, 0, 6, 0, 0, 8, 0]])
输出应该是:
tensor([[1, 1, 1, 4, 4, 5, 5, 5],
[0, 3, 3, 6, 6, 6, 8, 8]])
我目前有以下代码:
def replace_zeros_with_prev_nonzero(tensor):
output = tensor.clone()
for i in range(len(output)):
prev_value = 0
for j in range(len(tensor[i])):
if tensor[i,j] == 0:
output[i,j] = prev_value
else:
prev_value = tensor[i,j].item()
return output
但感觉有点笨拙,我相信必须有更好的方法来做到这一点。那么是否可以用更少的行来编写它,或者更好的是在不将张量视为数组的情况下并行化操作?
您可以通过对第一个维度进行矢量化来移除其中一个循环。
def replace_zeros_with_prev_nonzero(tensor):
output = tensor.clone()
for i in range(1, tensor.shape[1]):
mask = tensor[:, i] == 0
output[mask, i] = output[mask, i-1]
return output
output[mask, i] = output[mask, i-1]
将 0 替换为之前的值(如果最初为 0,除了第 0 个索引之外,它本身将被替换)。