连接从循环内生成的 N 个 pytorch 张量(形状相同)
Concatenate N pytorch tensors (of the same shape) generated from within loop
从一个循环中返回相同形状的张量,我想尽可能简洁地连接它们,并尽可能以 python/pytorchly 的方式连接它们。
当前解决方案:
import torch
for object_id in object_ids:
dataset = Dataset(object_id)
image_tensor = dataset.get_random_image_tensor()
if 'concatenated_image_tensors' in locals():
concatenated_image_tensors = torch.cat((merged_image_tensors, image_tensor))
else:
concatenated_image_tensors = image_tensor
有没有更好的方法?
一个好的方法是首先附加到 python list,然后在最后连接整个 list。否则,每次调用 torch.cat
时,您最终都会在内存中移动数据。
all_img = []
for object_id in object_ids:
dataset = Dataset(object_id)
image_tensor = dataset.get_random_image_tensor()
all_img.append(image_tensor)
all_img = torch.cat(all_img)
从一个循环中返回相同形状的张量,我想尽可能简洁地连接它们,并尽可能以 python/pytorchly 的方式连接它们。
当前解决方案:
import torch
for object_id in object_ids:
dataset = Dataset(object_id)
image_tensor = dataset.get_random_image_tensor()
if 'concatenated_image_tensors' in locals():
concatenated_image_tensors = torch.cat((merged_image_tensors, image_tensor))
else:
concatenated_image_tensors = image_tensor
有没有更好的方法?
一个好的方法是首先附加到 python list,然后在最后连接整个 list。否则,每次调用 torch.cat
时,您最终都会在内存中移动数据。
all_img = []
for object_id in object_ids:
dataset = Dataset(object_id)
image_tensor = dataset.get_random_image_tensor()
all_img.append(image_tensor)
all_img = torch.cat(all_img)