如何从pytorch图中删除张量?

How to delete tensors from pytorch graph?

我正在对图像进行预测,以便在 for 循环中进行对象检测。 我实际上 运行 遇到了与 tensorflow 相同的问题,希望我可以用 pytorch 解决它。 至少现在我似乎已经找到了问题所在(天真地假设它对 tensorflow 来说是一样的)

我是这样预测的

 model = detection.fasterrcnn_resnet50_fpn(pretrained=True, 
    progress=True,pretrained_backbone=True).to(DEVICE)
    for i in tqdm(range(train.shape[0])):
        image = cv2.imread(train_img_paths[i])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image.transpose((2, 0, 1))
        image = image / 255.0
        image = np.expand_dims(image, axis=0)
        image = torch.FloatTensor(image)
        image = image.to(DEVICE)
        predictions = model(image)[0]

现在通过垃圾收集器,我发现每张图片都保留在图中。 有什么办法可以避开吗?

我无法将数据加载器或数据集与检测模型一起使用(与 tensorflow hub 相同)

不要忘记在进行测试时关闭梯度累积。您可以通过包装您的代码来做到这一点:

with torch.no_grad():
     model.eval()
     out = model(x)

或者如果你的代码是一个函数,使用装饰器来做同样的事情:

@torch.no_grad()
def model_proc(model,x):
    model.eval()
    return model(x)