AttributeError: 'DataModuleClass' object has no attribute 'training_dataset'

AttributeError: 'DataModuleClass' object has no attribute 'training_dataset'

我正在尝试通过编写一个非常简单的 DataModuleClass 来学习 PyTorch Lightning。在 prepare_data()setup() 之后,我正在尝试检查这些功能是否正常工作。所以,我试图从 setup() 获取 trainingvalidation 数据集。但是我收到一个错误

AttributeError: 'DataModuleClass' object has no attribute 'training_dataset'

代码

def prepare_data(self):
    x = np.random.uniform(0, 10, 10)
    e = np.random.normal(0, self.sigma, len(x))
    
    # Making target or labels
    y = x + e
    
    # Marging x and e for 2 features
    X = np.transpose(np.array([x, e]))

    # Converting numpy array to Tensor
    self.x_train_tensor = torch.from_numpy(X).float().to(device)
    self.y_train_tensor = torch.from_numpy(y).float().to(device)
    
    training_dataset = TensorDataset(self.x_train_tensor, self.y_train_tensor)

    self.training_dataset = training_dataset

def setup(self):
    data = self.training_dataset
    self.train_data, self.val_data = random_split(data, [8, 2])
    return self.train_data, self.val_data
    
    
def train_dataloader(self):
    return DataLoader(self.train_data)

def val_dataloader(self):
    return DataLoader(self.val_data)
    
obj = DataModuleClass()
print(obj.setup())  

你能告诉我为什么会出现这个错误吗?

从我看来代码的方式来看。

DataModuleClass的变量self.training_datasetprepare_data中初始化,setup在第一行需要它。

但是你调用了 setup 而没有调用 training_dataset

如果每次创建 DataModuleClass 对象时都希望调用 prepare_data,那么最好将 prepare_data 放在 __init__ 中。喜欢

def __init__(self, other_params):
    ..... all your code previously in __init__
    self.prepare_data()  # put this in the last line of this function

但如果您不需要,则需要在 setup

之前调用 prepare_data
obj = DataModuleClass()
obj.prepare_data()
print(obj.setup())  

或将 prepare_data 放入 setup 本身。

def setup(self):
    self.prepare_data()
    data = self.training_dataset
    self.train_data, self.val_data = random_split(data, [8, 2])
    return self.train_data, self.val_data

编辑 1:查看 self.train_dataself.val_data

的实际值

setup 返回的对象是 torch.utils.data.dataset.Subset。 基本上有两种获取张量的方法。

1。像列表一样对待它们

train_data, val_data = obj.setup()
print(train_data[0])

2。使用 for 循环

train_data, val_data = obj.setup()
for data in train_data:
    print(data)

备注

我不确定你会得到张量或 TensorDataset。如果是后者,则再次使用相同的技巧,例如

train_data, val_data = obj.setup()
train_tensor_data = train_data[0]
print(train_tensor_data[0])