如何在 PyTorch Lightning 中从 prepare_data() 获取数据集到 setup()
How to get dataset from prepare_data() to setup() in PyTorch Lightning
我使用 PyTorch Lightning 的 DataModules
方法在 prepare_data()
方法中使用 NumPy
创建了自己的数据集。现在,我想将数据传递到 setup()
方法以拆分为训练和验证。
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader, TensorDataset
import torch
from torch.autograd import Variable
from torchvision import transforms
np.random.seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class DataModuleClass(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.constant = 2
self.batch_size = 10
def prepare_data(self):
a = np.random.uniform(0, 500, 500)
b = np.random.normal(0, self.constant, len(a))
c = a + b
X = np.transpose(np.array([a, b]))
# Converting numpy array to Tensor
self.x_train_tensor = torch.from_numpy(X).float().to(device)
self.y_train_tensor = torch.from_numpy(c).float().to(device)
training_dataset = TensorDataset(self.x_train_tensor, self.y_train_tensor)
return training_dataset
def setup(self):
data = # What I have to write to get the data from prepare_data()
self.train_data, self.val_data = random_split(data, [400, 100])
def train_dataloader(self):
training_dataloader = setup() # Need to get the training data
return DataLoader(self.training_dataloader)
def val_dataloader(self):
validation_dataloader = prepare_data() # Need to get the validation data
return DataLoader(self.validation_dataloader)
obj = DataModuleClass()
print(obj.prepare_data())
与您上一个问题的答案相同...
def prepare_data(self):
a = np.random.uniform(0, 500, 500)
b = np.random.normal(0, self.constant, len(a))
c = a + b
X = np.transpose(np.array([a, b]))
# Converting numpy array to Tensor
self.x_train_tensor = torch.from_numpy(X).float().to(device)
self.y_train_tensor = torch.from_numpy(c).float().to(device)
training_dataset = TensorDataset(self.x_train_tensor, self.y_train_tensor)
self.training_dataset = training_dataset
def setup(self):
data = self.training_dataset
self.train_data, self.val_data = random_split(data, [400, 100])
def train_dataloader(self):
return DataLoader(self.train_data)
def val_dataloader(self):
return DataLoader(self.val_data)
我使用 PyTorch Lightning 的 DataModules
方法在 prepare_data()
方法中使用 NumPy
创建了自己的数据集。现在,我想将数据传递到 setup()
方法以拆分为训练和验证。
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader, TensorDataset
import torch
from torch.autograd import Variable
from torchvision import transforms
np.random.seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class DataModuleClass(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.constant = 2
self.batch_size = 10
def prepare_data(self):
a = np.random.uniform(0, 500, 500)
b = np.random.normal(0, self.constant, len(a))
c = a + b
X = np.transpose(np.array([a, b]))
# Converting numpy array to Tensor
self.x_train_tensor = torch.from_numpy(X).float().to(device)
self.y_train_tensor = torch.from_numpy(c).float().to(device)
training_dataset = TensorDataset(self.x_train_tensor, self.y_train_tensor)
return training_dataset
def setup(self):
data = # What I have to write to get the data from prepare_data()
self.train_data, self.val_data = random_split(data, [400, 100])
def train_dataloader(self):
training_dataloader = setup() # Need to get the training data
return DataLoader(self.training_dataloader)
def val_dataloader(self):
validation_dataloader = prepare_data() # Need to get the validation data
return DataLoader(self.validation_dataloader)
obj = DataModuleClass()
print(obj.prepare_data())
与您上一个问题的答案相同...
def prepare_data(self):
a = np.random.uniform(0, 500, 500)
b = np.random.normal(0, self.constant, len(a))
c = a + b
X = np.transpose(np.array([a, b]))
# Converting numpy array to Tensor
self.x_train_tensor = torch.from_numpy(X).float().to(device)
self.y_train_tensor = torch.from_numpy(c).float().to(device)
training_dataset = TensorDataset(self.x_train_tensor, self.y_train_tensor)
self.training_dataset = training_dataset
def setup(self):
data = self.training_dataset
self.train_data, self.val_data = random_split(data, [400, 100])
def train_dataloader(self):
return DataLoader(self.train_data)
def val_dataloader(self):
return DataLoader(self.val_data)