How can I resolve an IndexError when looping over train_y?

I don't understand where my mistake is:

train_dataset = TensorDataset(input_ids, attention_masks, labels)

print('target train 0:', len(np.where(train_y == 0)[0]))
print('target train 1:', len(np.where(train_y == 1)[0]))
print('target train 2:', len(np.where(train_y == 2)[0]))
print('target train 3:', len(np.where(train_y == 3)[0]))
print('target train 4:', len(np.where(train_y == 4)[0]))
print('target train 5:', len(np.where(train_y == 5)[0]))
print('target train 6:', len(np.where(train_y == 6)[0]))

>> target train 0: 6834
target train 1: 1200
target train 2: 0
target train 3: 4397
target train 4: 1112
target train 5: 0
target train 6: 3281

'''How many examples do classes have?'''

class_sample_count = np.array(
    [len(np.where(train_y == t)[0]) for t in np.unique(train_y)])
print("How many examples do classes have?\n", class_sample_count)

# >> [6834 1200 4397 1112 3281]

weight = 1. / class_sample_count

print("weights: ", weight)

# >> weights: [0.00014633 0.00083333 0.00022743 0.00089928 0.00030479]

samples_weight = np.array([weight[t] for t in train_y])

# >> IndexError: index 6 is out of bounds for axis 0 with size 5

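np.unique(train_y) returns only the labels that actually occur in train_y. Classes 2 and 5 have no samples, so it returns 5 labels (0, 1, 3, 4, 6). As a result class_sample_count, and therefore weight, has only 5 elements (the counts 6834, 1200, 4397, 1112, 3281). The label values in train_y, however, still range from 0 to 6.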

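You then iterate over train_y and use each label t directly as an index into weight. As soon as t is 6, weight[6] raises the IndexError, because the valid indices of weight are 0 to 4. Note that labels 3 and 4 do not raise, but they silently pick the wrong entries: weight[3] is the weight of class 4 and weight[4] is the weight of class 6.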
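One possible fix, as a minimal sketch: let np.unique also return, for every sample, the position of its label within the sorted unique labels, and index weight with that position instead of the raw label. The small train_y below is a made-up stand-in for your data (labels 2 and 5 are absent, as in your case); the names classes, inverse and counts are my own.

```python
import numpy as np

# Hypothetical stand-in for train_y: labels 2 and 5 never occur.
train_y = np.array([0, 0, 0, 1, 3, 3, 4, 6, 6])

# classes: sorted labels that actually occur
# inverse: for each sample, the position of its label inside `classes`
# counts:  number of samples per entry of `classes`
classes, inverse, counts = np.unique(
    train_y, return_inverse=True, return_counts=True)

# One weight per class that is present, same order as `classes`.
weight = 1.0 / counts

# Index with the position (0..len(classes)-1), not with the raw label.
samples_weight = weight[inverse]

print("classes:", classes)
print("counts: ", counts)
print("samples_weight:", samples_weight)
```

Because inverse only ever contains values from 0 to len(classes) - 1, the lookup cannot go out of bounds, and every sample gets the weight of its own class rather than a neighbouring one.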