Saving tensors to a .pt file in order to create a dataset
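My task is to create a dataset to test the functionality of the code we are working on.
The dataset must contain a set of tensors that will be used later in a generative model.
I'm trying to save the tensors to a .pt file, but I keep overwriting them, so the file I end up with contains only one tensor. I've read about torch.utils.data.dataset, but I can't figure out on my own how to use it.
Here is my code: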
import torch
import numpy as np
from torch.utils.data import Dataset
# variables that will be used to create the size of the tensors:
num_jets, num_particles, num_features = 1, 30, 3
for i in range(100):
    # tensor from a gaussian dist with mean=5, std=1 and shape=size:
    tensor = torch.normal(5, 1, size=(num_jets, num_particles, num_features))
    # We will need the tensors to be of the cpu type
    tensor = tensor.cpu()
    # save the tensor to 'tensor_dataset.pt'
    torch.save(tensor, 'tensor_dataset.pt')
# open the recently created .pt file inside a list
tensor_list = torch.load('tensor_dataset.pt')
# prints the list. Just one tensor inside .pt file
print(tensor_list)
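Cause: every iteration of the loop calls torch.save with the same file name, so each call overwrites the file written by the previous one. You never build a collection of tensors, and the file ends up holding only the last `tensor`.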
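To see the overwrite on its own, here is a minimal sketch. It saves two different tensors to the same path, and only the second one comes back from torch.load. The temporary directory and the file name `demo.pt` are illustrative choices and are not part of the original code.

```python
import os
import tempfile

import torch

with tempfile.TemporaryDirectory() as tmp_dir:
    # Hypothetical file name, used only for this demonstration
    path = os.path.join(tmp_dir, "demo.pt")

    first = torch.zeros(2)
    second = torch.ones(2)

    # Each torch.save call replaces the whole file; nothing is appended
    torch.save(first, path)
    torch.save(second, path)

    loaded = torch.load(path)

# Only the tensor from the last save survives
print(torch.equal(loaded, second))  # True
print(torch.equal(loaded, first))   # False
```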
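Solution: since you already know the size of each tensor, you can preallocate one tensor that is large enough for all 100 samples. Fill `lst_tensors` inside the loop, then call torch.save once after the loop: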
import torch

num_jets, num_particles, num_features = 1, 30, 3
# Preallocate one tensor for all 100 samples; torch.empty creates it on the CPU by default
lst_tensors = torch.empty(size=(100, num_jets, num_particles, num_features))
for i in range(100):
    # Fill slot i with a sample from a Gaussian with mean=5, std=1
    lst_tensors[i] = torch.normal(5, 1, size=(num_jets, num_particles, num_features))
# Save once, after the loop, so the file holds all 100 samples
torch.save(lst_tensors, 'tensor_dataset.pt')
tensor_list = torch.load('tensor_dataset.pt')
print(tensor_list.shape)  # torch.Size([100, 1, 30, 3])
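The question also mentions torch.utils.data. One possible way to expose the saved file as a `Dataset` is sketched below. It assumes the file holds a single tensor whose first dimension indexes the samples, as the code above produces. The class name `JetDataset` is hypothetical. The built-in `torch.utils.data.TensorDataset` would do the same job without a custom class.

```python
import os
import tempfile

import torch
from torch.utils.data import DataLoader, Dataset


class JetDataset(Dataset):
    """Hypothetical Dataset serving samples from a tensor saved with torch.save."""

    def __init__(self, path):
        # Load the stacked tensor: (num_samples, num_jets, num_particles, num_features)
        self.data = torch.load(path)

    def __len__(self):
        # Number of samples is the size of the first dimension
        return self.data.shape[0]

    def __getitem__(self, idx):
        # One sample of shape (num_jets, num_particles, num_features)
        return self.data[idx]


num_jets, num_particles, num_features = 1, 30, 3

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tensor_dataset.pt")
    # Draw all 100 samples in one call instead of filling them one by one
    samples = torch.normal(5, 1, size=(100, num_jets, num_particles, num_features))
    torch.save(samples, path)

    dataset = JetDataset(path)
    loader = DataLoader(dataset, batch_size=10, shuffle=True)
    batch = next(iter(loader))

print(len(dataset))  # 100
print(batch.shape)   # torch.Size([10, 1, 30, 3])
```

A single `torch.normal` call with the full shape gives the same kind of data as the loop above, because each element is drawn independently from the same distribution. The `DataLoader` then stacks individual samples back into batches for the generative model.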