ValueError: `validation_data` should be a tuple `(val_x, val_y, val_sample_weight)` or `(val_x, val_y)`. Found: <__main__.Generator object at>

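I know this question has been asked a few times already, but none of the answers fit my case.

I have a CSV file containing text (newspaper content) and labels, in column 0 and column 1 respectively.

I am trying to write my first custom generator for text classification, but I get this error: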

ValueError: `validation_data` should be a tuple `(val_x, val_y, val_sample_weight)` or `(val_x, val_y)`. Found: <__main__.Generator object at 0xd376a6e80>

Here is the class:

import csv

class Generator(object):

    def __init__(self, data_file):
        self.data_file = data_file
        self.length = -1

    def __iter__(self):
        while True:
            with open(self.data_file, 'r') as f:
                reader = csv.reader(f)
                for row in reader:
                    yield row[0], row[1]

    def __len__(self):
        if self.length == -1:
            n_rows = 0
            with open(self.data_file, 'r') as f:
                reader = csv.reader(f)
                for row in reader:
                    n_rows += 1
            self.length = n_rows
        return self.length

I also tried `yield row[0], row[1]`, and `return` as well. Neither worked.

Thanks for the help.

I ran into the same error until I made my generator class inherit from keras.utils.Sequence (see the fit_generator documentation). You could try this:

import csv

import keras

class Generator(keras.utils.Sequence):
    def __init__(self, data_file):
        self.data_file = data_file
        self.length = -1

    def __iter__(self):
        while True:
            with open(self.data_file, 'r') as f:
                reader = csv.reader(f)
                for row in reader:
                    yield row[0], row[1]

    def __len__(self):
        if self.length == -1:
            n_rows = 0
            with open(self.data_file, 'r') as f:
                reader = csv.reader(f)
                for row in reader:
                    n_rows += 1
            self.length = n_rows
        return self.length
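
One caveat, as far as I understand the Keras API: a keras.utils.Sequence is read by batch index, so Keras calls `__getitem__` and `__len__` rather than `__iter__`, and `__len__` is expected to be the number of batches, not the number of rows. Below is a minimal sketch of that shape. The class name `CsvBatchSequence`, the `batch_size` parameter, and the choice to load the whole file into memory are my own assumptions, not part of the original code, and the try/except fallback is only there so the indexing logic can be tried without Keras installed.

```python
import csv
import math

try:
    # Assumption: Keras is available through TensorFlow.
    from tensorflow.keras.utils import Sequence
except ImportError:
    # Fallback so the batching logic below can be run without Keras.
    Sequence = object


class CsvBatchSequence(Sequence):
    """Hypothetical batch-indexed reader for a two-column (text, label) CSV."""

    def __init__(self, data_file, batch_size=32):
        super().__init__()
        self.data_file = data_file
        self.batch_size = batch_size
        # Load every row once. Fine for small files; a large file would
        # need a different strategy (assumption made for brevity).
        with open(data_file, "r", newline="") as f:
            self.rows = [(row[0], row[1])
                         for row in csv.reader(f) if len(row) >= 2]

    def __len__(self):
        # Number of batches per epoch, not number of rows.
        return math.ceil(len(self.rows) / self.batch_size)

    def __getitem__(self, idx):
        # Return batch number `idx` as a (texts, labels) pair.
        if not 0 <= idx < len(self):
            raise IndexError("batch index out of range")
        start = idx * self.batch_size
        batch = self.rows[start:start + self.batch_size]
        texts = [text for text, _ in batch]
        labels = [label for _, label in batch]
        return texts, labels
```

The texts and labels come back as raw strings here. Before passing an object like this to fit_generator they would still need to be turned into numeric arrays (tokenizing the text, encoding the labels), which I have left out.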