如何按单个特征的值过滤张量流数据集

How to filter tensorflow dataset by value of single feature

如何根据单个特征的值过滤tensorflow数据集?

我花了很多时间了解如何使用 filter 方法过滤 tensorflow 数据集,不幸的是文档对我来说不够清晰https://www.tensorflow.org/api_docs/python/tf/data/Dataset#filter,也许对某些人有用。

在下面的示例中,目标是:Select 如果特征名称“Status”的值等于 'success' 并且特征名称 'Cost' 的值 >0。

dataset = tf.data.experimental.make_csv_dataset('file1.csv',....)
dataset = dataset.unbatch().filter(lambda x, y: True if x["Status"] == 'success' else False)
dataset = dataset.filter(lambda x, y: True if x["Cost"] > 0.0 else False)

你可以试试这样:

import tensorflow as tf
import pandas as pd

df = pd.DataFrame(data={'Status': ['Success', 'Failure','Failure', 'Success'], 'Cost': [0.0, 1.0, 1.0, 2.0]})
df.to_csv('data.csv', index=False)

dataset = tf.data.experimental.make_csv_dataset('/content/data.csv', batch_size=2, num_epochs = 1)
dataset = dataset.unbatch().filter(lambda x: x["Status"] == 'Success' and x["Cost"] > 0.0)

for x in dataset:
  print(x['Status'], x['Cost'])
tf.Tensor(b'Success', shape=(), dtype=string) tf.Tensor(2.0, shape=(), dtype=float32)