迭代 Torchtext.data.BucketIterator 对象抛出 AttributeError 'Field' 对象没有属性 'vocab'
Iterating over Torchtext.data.BucketIterator object throws AttributeError 'Field' object has no attribute 'vocab'
当我尝试查看批次时,通过打印 BucketIterator
对象的下一次迭代,抛出 AttributeError
。
tv_datafields=[("Tweet",TEXT), ("Anger",LABEL), ("Fear",LABEL), ("Joy",LABEL), ("Sadness",LABEL)]
train, vld = data.TabularDataset.splits(path="./data/", train="train.csv",validation="test.csv",format="csv", fields=tv_datafields)
train_iter, val_iter = BucketIterator.splits(
(train, vld),
batch_sizes=(64, 64),
device=-1,
sort_key=lambda x: len(x.Tweet),
sort_within_batch=False,
repeat=False
)
print(next(iter(train_dl)))
我不确定您遇到的具体错误,但在这种情况下,您可以使用以下代码迭代批次:
for i in train_iter:
print i.Tweet
print i.Anger
print i.Fear
print i.Joy
print i.Sadness
i.Tweet
(还有其他)是形状为 (input_data_length, batch_size)
.
的张量
因此,要查看单个批次数据(比方说批次 0),您可以执行 print i.Tweet[:,0]
。
val_iter
(如果需要,test_iter
)也是如此。
当我尝试查看批次时,通过打印 BucketIterator
对象的下一次迭代,抛出 AttributeError
。
tv_datafields=[("Tweet",TEXT), ("Anger",LABEL), ("Fear",LABEL), ("Joy",LABEL), ("Sadness",LABEL)]
train, vld = data.TabularDataset.splits(path="./data/", train="train.csv",validation="test.csv",format="csv", fields=tv_datafields)
train_iter, val_iter = BucketIterator.splits(
(train, vld),
batch_sizes=(64, 64),
device=-1,
sort_key=lambda x: len(x.Tweet),
sort_within_batch=False,
repeat=False
)
print(next(iter(train_dl)))
我不确定您遇到的具体错误,但在这种情况下,您可以使用以下代码迭代批次:
for i in train_iter:
print i.Tweet
print i.Anger
print i.Fear
print i.Joy
print i.Sadness
i.Tweet
(还有其他)是形状为 (input_data_length, batch_size)
.
因此,要查看单个批次数据(比方说批次 0),您可以执行 print i.Tweet[:,0]
。
val_iter
(如果需要,test_iter
)也是如此。