连接从循环内生成的 N 个 pytorch 张量(形状相同)

Concatenate N pytorch tensors (of the same shape) generated from within loop

从一个循环中返回相同形状的张量,我想尽可能简洁地连接它们,并尽可能以 python/pytorchly 的方式连接它们。

当前解决方案:

import torch

for object_id in object_ids:
    
    dataset = Dataset(object_id)

    image_tensor = dataset.get_random_image_tensor()

    if 'concatenated_image_tensors' in locals():
        concatenated_image_tensors = torch.cat((merged_image_tensors, image_tensor))
    else:
        concatenated_image_tensors = image_tensor

有没有更好的方法?

一个好的方法是首先附加到 python list,然后在最后连接整个 list。否则,每次调用 torch.cat 时,您最终都会在内存中移动数据。

all_img = []
for object_id in object_ids:
    dataset = Dataset(object_id)
    image_tensor = dataset.get_random_image_tensor()
    all_img.append(image_tensor)

all_img = torch.cat(all_img)