如何从 FilterDataset/MapDataset 中删除张量

How can I remove a tensor from a FilterDataset/MapDataset

我有一个 video_id、user_id 和分数张量的数据集。我想将其过滤为仅得分高于阈值的正例,然后删除得分张量。

def decode_retrieval_positive(record_bytes):
    return tf.io.parse_single_example(
        # Data
        record_bytes,
        # Schema
        {"video_id": tf.io.FixedLenFeature([], dtype=tf.int64),
        "user_id": tf.io.FixedLenFeature([], dtype=tf.int64),
        "score": tf.io.FixedLenFeature([], dtype=tf.float32)}
    )

ratings_positive = ratings.map(
            decode_retrieval_positive
        ).filter(
            lambda x: x["score"] > 0.2
        ).map(
            lambda x: {"video_id": x["video_id"], "user_id": x["user_id"]}
        )

<MapDataset element_spec={'video_id': TensorSpec(shape=(), dtype=tf.int64, name=None), 'user_id': TensorSpec(shape=(), dtype=tf.int64, name=None)}>

这给了我这个错误:

2022-02-07 08:18:52.825318: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at example_parsing_ops.cc:94 : INVALID_ARGUMENT: Feature: score (data type: float) is required but could not be found.

一个解决方案是简单地制作一个新的 positive_ratings.tfrecord,但这会占用更多 space,我很生气我不能这样做。

您只需确保 x['score'] 具有浮点值。这是一个工作示例:

import tensorflow as tf
tf.random.set_seed(111)

# Create dummy data
dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: {'video_id': x, 'user_id': 2555 + x, 'score': tf.cast(x, dtype=tf.float32)*tf.random.normal(())})

dataset = dataset.filter(lambda x: x["score"] > 0.2)
dataset = dataset.map(lambda x: {"video_id": x["video_id"], "user_id": x["user_id"]})
for d in dataset:
  print(d)
{'video_id': <tf.Tensor: shape=(), dtype=int64, numpy=5>, 'user_id': <tf.Tensor: shape=(), dtype=int64, numpy=2560>}