Pytorch 闪电:'CIFAR10DataModule' 对象没有属性 'train_loader'

Pytorch lightning: 'CIFAR10DataModule' object has no attribute 'train_loader'

你能告诉我为什么导入 CUFAR10DataModule() 失败吗?

起初,我运行 GoogleColab 上的代码,

from pl_bolts.datamodules import CIFAR10DataModule
dm = CIFAR10DataModule()

然后,执行验证码

from torch.optim import Adam
optimizer = Adam(finetune_layer.parameters(), lr=1e-4)

for epoch in range(10):
  for batch in dm.train_loader:
    x, y = batch
    with torch.no_grad():
      features = backbone(x)

    preds = finetune_layer(features)
    loss = cross_entropy(preds, y)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(loss.item())

但是,运行输入代码后返回了消息AttributeError: 'CIFAR10DataModule' object has no attribute 'train_loader'

当代码为运行确认dm,

for batch in dm.train_dataloader:
  x, y = batch
  print(x.shape, y.shape)
  break

错误显示 TypeError: 'method' object is not iterable

代码和例子看起来一样,但是我想知道为什么会产生这样的错误?

您的代码有两个问题:

首先,获取底层 PyTorch 数据加载器的方式是 dm.train_dataloader() 而不是 dm.train_loader。它是一个 函数,而不是 属性

for batch in dm.train_dataloader():
    x, y = batch
    ...

其次,由于您尝试使用没有 TrainerLightningDataModule,您需要手动调用

dm.prepare_data()
dm.setup()

.. 为了让数据加载器可以通过 .train_dataloader().