AttributeError: 'numpy.ndarray' object has no attribute 'unsqueeze'

AttributeError: 'numpy.ndarray' object has no attribute 'unsqueeze'

我是 运行 使用 pyhtorchnumpy 的训练代码。

这是plot_example函数:

def plot_example(low_res_folder, gen):
    files=os.listdir(low_res_folder)
    
    gen.eval()
    for file in files:
        image=Image.open("test_images/" + file)
        with torch.no_grad():
            upscaled_img=gen(
                config1.both_transform(image=np.asarray(image))["image"]
                .unsqueeze(0)
                .to(config1.DEVICE)
            )
        save_image(upscaled_img * 0.5 + 0.5, f"saved/{file}")
    gen.train()

我遇到的问题是 unsqueeze 属性引发错误:

File "E:\Downloads\esrgan-tf2-masteren\modules\train1.py", line 58, in train_fn
    plot_example("test_images/", gen)
    
File "E:\Downloads\esrgan-tf2-masteren\modules\utils1.py", line 46, in plot_example
    config1.both_transform(image=np.asarray(image))["image"]
    
AttributeError: 'numpy.ndarray' object has no attribute 'unsqueeze'

网络是GAN网络,gen()代表Generator

在进入任何 Pytorch 层之前,确保图像是形状为 [batch size, channels, height, width] 的张量。

给你 image=np.asarray(image)

我会删除这个 numpy 转换并保留它 torch.tensor。

或者,如果您真的希望它成为一个 numpy 数组,那么在它进入您的生成器之前,确保在它被解压缩之前使用 torch.from_numpy(),如本文档中所示,在您的 numpy 图像上:https://pytorch.org/docs/stable/generated/torch.from_numpy.html

如果您不想摆脱原来的转换,当然可以选择此功能。

萨萨克·耆那教