Pytorch 闪电:'CIFAR10DataModule' 对象没有属性 'train_loader'
Pytorch lightning: 'CIFAR10DataModule' object has no attribute 'train_loader'
你能告诉我为什么导入 CUFAR10DataModule() 失败吗?
起初,我运行 GoogleColab 上的代码,
from pl_bolts.datamodules import CIFAR10DataModule
dm = CIFAR10DataModule()
然后,执行验证码
from torch.optim import Adam
optimizer = Adam(finetune_layer.parameters(), lr=1e-4)
for epoch in range(10):
for batch in dm.train_loader:
x, y = batch
with torch.no_grad():
features = backbone(x)
preds = finetune_layer(features)
loss = cross_entropy(preds, y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
print(loss.item())
但是,运行输入代码后返回了消息AttributeError: 'CIFAR10DataModule' object has no attribute 'train_loader'
。
当代码为运行确认dm
,
for batch in dm.train_dataloader:
x, y = batch
print(x.shape, y.shape)
break
错误显示 TypeError: 'method' object is not iterable
。
代码和例子看起来一样,但是我想知道为什么会产生这样的错误?
您的代码有两个问题:
首先,获取底层 PyTorch 数据加载器的方式是 dm.train_dataloader()
而不是 dm.train_loader
。它是一个 函数,而不是 属性。
for batch in dm.train_dataloader():
x, y = batch
...
其次,由于您尝试使用没有 Trainer
的 LightningDataModule
,您需要手动调用
dm.prepare_data()
dm.setup()
.. 为了让数据加载器可以通过 .train_dataloader()
.
你能告诉我为什么导入 CUFAR10DataModule() 失败吗?
起初,我运行 GoogleColab 上的代码,
from pl_bolts.datamodules import CIFAR10DataModule
dm = CIFAR10DataModule()
然后,执行验证码
from torch.optim import Adam
optimizer = Adam(finetune_layer.parameters(), lr=1e-4)
for epoch in range(10):
for batch in dm.train_loader:
x, y = batch
with torch.no_grad():
features = backbone(x)
preds = finetune_layer(features)
loss = cross_entropy(preds, y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
print(loss.item())
但是,运行输入代码后返回了消息AttributeError: 'CIFAR10DataModule' object has no attribute 'train_loader'
。
当代码为运行确认dm
,
for batch in dm.train_dataloader:
x, y = batch
print(x.shape, y.shape)
break
错误显示 TypeError: 'method' object is not iterable
。
代码和例子看起来一样,但是我想知道为什么会产生这样的错误?
您的代码有两个问题:
首先,获取底层 PyTorch 数据加载器的方式是 dm.train_dataloader()
而不是 dm.train_loader
。它是一个 函数,而不是 属性。
for batch in dm.train_dataloader():
x, y = batch
...
其次,由于您尝试使用没有 Trainer
的 LightningDataModule
,您需要手动调用
dm.prepare_data()
dm.setup()
.. 为了让数据加载器可以通过 .train_dataloader()
.