用火炬中的最后一个非零值替换所有零

Replace all zeros with last non-zero value in torch

有什么有效的方法可以用 torch 中的最后一个非零值替换张量中的所有零吗?

例如,如果我有张量:

tensor([[1, 0, 0, 4, 0, 5, 0, 0],
        [0, 3, 0, 6, 0, 0, 8, 0]])

输出应该是:

tensor([[1, 1, 1, 4, 4, 5, 5, 5],
        [0, 3, 3, 6, 6, 6, 8, 8]])

我目前有以下代码:

def replace_zeros_with_prev_nonzero(tensor):
    output = tensor.clone()
    for i in range(len(output)):
        prev_value = 0
        for j in range(len(tensor[i])):
            if tensor[i,j] == 0:
                output[i,j] = prev_value
            else:
                prev_value = tensor[i,j].item()      
    return output

但感觉有点笨拙,我相信必须有更好的方法来做到这一点。那么是否可以用更少的行来编写它,或者更好的是在不将张量视为数组的情况下并行化操作?

您可以通过对第一个维度进行矢量化来移除其中一个循环。

def replace_zeros_with_prev_nonzero(tensor):
    output = tensor.clone()
    for i in range(1, tensor.shape[1]):
        mask = tensor[:, i] == 0
        output[mask, i] = output[mask, i-1]

    return output

output[mask, i] = output[mask, i-1] 将 0 替换为之前的值(如果最初为 0,除了第 0 个索引之外,它本身将被替换)。