PyTorch:为什么我的数据集 class 给出索引超出范围错误?
PyTorch: Why is my dataset class giving index out of range errors?
我想弄清楚为什么我的数据集给出了超出范围的索引错误。
考虑这个火炬数据集:
# prepare torch data set
class MSRH5Processor(torch.utils.data.Dataset):
def __init__(self, type, shard=False, **args):
# init configurable string
self.type = type
# init shard for sampling large ds if specified
self.shard = shard
# set seed if given
self.seed = args
# set loc
self.file_path = 'C:\data\h5py_embeds\'
# set file paths
self.val_embed_path = self.file_path + 'msr_dev_bert_embeds.h5'
# if true, initialize the dev data
if self.type == 'dev':
# embeds are shaped: [layers, tokens, features]
self.embeddings = h5py.File(self.val_embed_path, 'r')["embeds"]
def __len__(self):
return len(self.embeddings)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
if self.type == 'dev':
sample = {'embeddings': self.embeddings[idx]}
return sample
# load dataset
processor = MSRH5Processor(type='dev', shard=False)
# check length
len(processor) # 22425
# iterate over the samples
count = 0
for step, batch in enumerate(processor):
count += 1
# error: Index (22425) out of range (0-22424)
with h5py.File('C:\w266\h5py_embeds\msr_dev_bert_embeds.h5', 'r') as f:
print(f['embeds'].attrs['last_index']) # 22425
print(f['embeds'].shape) # (22425, 128, 768)
print(len(f['embeds'])) # 22425
如果我手动将数据集长度改为100
或22424
,我仍然会得到同样的错误。是什么告诉 PyTorch 寻找索引 22425?
如果我要制作一个 CSV 数据集,有 1000 个观察值(其中 len = 1000
),它会在 999 而不是 1000 时停止将索引输入到 __getitem__()
方法中。
编辑:
这似乎只是数据集 class 和 H5py 文件的问题。如果我使用 torch 数据加载器,它将 运行 我的数据集的自然长度。虽然,我很想知道 Torch 是如何为我的 H5 文件获取这个数字的,这导致它的行为与 CSV 不同。
要将 Dataset
用作可迭代对象,您必须使用 Sequence 语义实现 __iter__
方法或 __getitem__
。当方法 __getitem__
为某个索引 idx
引发 IndexError
时迭代停止
您的数据集的问题在于:
self.embeddings = h5py.File(self.val_embed_path, 'r')["embeds"]
实际上是 h5py._hl.dataset.Dataset
类型,在索引外请求时引发 ValueError
您必须在 class 构造函数中加载整个嵌入,以便在超出索引时访问 numpy 数组将引发 IndexError
或在 [= 上重新抛出 IndexError
17=] 在 __getitem__
我想弄清楚为什么我的数据集给出了超出范围的索引错误。
考虑这个火炬数据集:
# prepare torch data set
class MSRH5Processor(torch.utils.data.Dataset):
def __init__(self, type, shard=False, **args):
# init configurable string
self.type = type
# init shard for sampling large ds if specified
self.shard = shard
# set seed if given
self.seed = args
# set loc
self.file_path = 'C:\data\h5py_embeds\'
# set file paths
self.val_embed_path = self.file_path + 'msr_dev_bert_embeds.h5'
# if true, initialize the dev data
if self.type == 'dev':
# embeds are shaped: [layers, tokens, features]
self.embeddings = h5py.File(self.val_embed_path, 'r')["embeds"]
def __len__(self):
return len(self.embeddings)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
if self.type == 'dev':
sample = {'embeddings': self.embeddings[idx]}
return sample
# load dataset
processor = MSRH5Processor(type='dev', shard=False)
# check length
len(processor) # 22425
# iterate over the samples
count = 0
for step, batch in enumerate(processor):
count += 1
# error: Index (22425) out of range (0-22424)
with h5py.File('C:\w266\h5py_embeds\msr_dev_bert_embeds.h5', 'r') as f:
print(f['embeds'].attrs['last_index']) # 22425
print(f['embeds'].shape) # (22425, 128, 768)
print(len(f['embeds'])) # 22425
如果我手动将数据集长度改为100
或22424
,我仍然会得到同样的错误。是什么告诉 PyTorch 寻找索引 22425?
如果我要制作一个 CSV 数据集,有 1000 个观察值(其中 len = 1000
),它会在 999 而不是 1000 时停止将索引输入到 __getitem__()
方法中。
编辑:
这似乎只是数据集 class 和 H5py 文件的问题。如果我使用 torch 数据加载器,它将 运行 我的数据集的自然长度。虽然,我很想知道 Torch 是如何为我的 H5 文件获取这个数字的,这导致它的行为与 CSV 不同。
要将 Dataset
用作可迭代对象,您必须使用 Sequence 语义实现 __iter__
方法或 __getitem__
。当方法 __getitem__
为某个索引 idx
IndexError
时迭代停止
您的数据集的问题在于:
self.embeddings = h5py.File(self.val_embed_path, 'r')["embeds"]
实际上是 h5py._hl.dataset.Dataset
类型,在索引外请求时引发 ValueError
您必须在 class 构造函数中加载整个嵌入,以便在超出索引时访问 numpy 数组将引发 IndexError
或在 [= 上重新抛出 IndexError
17=] 在 __getitem__