从数据集 Tensorflow 中删除坏数据

Drop bad data from dataset Tensorflow

我有一个使用 tf.data 的训练管道。在数据集中有一些坏元素,在我的例子中是 0 值。我如何根据它们的值删除这些坏数据元素?由于数据集很大,我希望能够在训练时从管道中删除它们。

根据以下伪代码假设:

def parse_function(element):
    height = element['height']
    if height <= 0: skip() #How to skip this value

    labels = element['label']
    features['height'] = height

    return features, labels

ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.map(parse_function)

建议是根据特征值使用 ds.skip(1),或者提供某种中性 weight/loss?

您可以使用 tf.data.Dataset.filter:

def filter_func(elem):
    """ return True if the element is to be kept """
    return tf.math.greater(elem['height'],0)

ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.filter(filter_func)

假设 element 是您代码中的数据框,那么它将是:

def parse_function(element):
    element = element.query('height>0')

    labels = element['label']
    features['height'] = element['height']

    return features, labels

ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.map(parse_function)

`