Python TensorFlow:在数据集过滤(filter)中实现集合的 .issubset() 判断

Python Tensorflow Dataset Filter Set .issubset()

我有一个 TensorFlow 数据集:

def fake_sequence():
    """Return a random 100-character string drawn mostly from ``ABCD``.

    Every position holds a letter from ``A``-``D`` with probability 0.999 and
    a letter from ``E``-``H`` otherwise.
    """
    n_chars = 100
    normal_letters = ["A", "B", "C", "D"]
    mutant_letters = ["E", "F", "G", "H"]

    # Draw the candidates first and the selection mask last, one character at
    # a time, so the sequence of random draws stays the same for a given seed.
    normal = [np.random.choice(normal_letters) for _ in range(n_chars)]
    mutant = [np.random.choice(mutant_letters) for _ in range(n_chars)]
    keep_normal = np.random.choice(
        a=[True, False], size=n_chars, p=[0.999, 0.001]
    )

    chars = np.where(keep_normal, normal, mutant)
    return "".join(chars)


# Generate 100 random sequences and wrap them in a tf.data.Dataset with one
# string element per sequence.
seqs = [fake_sequence() for _ in range(100)]
ds = tf.data.Dataset.from_tensor_slices(seqs)

我想使用以下 pythonic 函数对其进行过滤:

def python_filter(x):
    """Return True when every element of ``x`` is one of A, B, C or D."""
    allowed = {"A", "B", "C", "D"}
    return set(x) <= allowed

不幸的是,用 @tf.function 装饰不起作用。有哪位大侠能帮帮我吗?这是我目前所拥有的。

def filter(x):
    """Split ``x`` into single-byte strings and return the distinct ones.

    Work in progress: the subset check against {"A", "B", "C", "D"} is not
    implemented yet, so this returns the unique characters rather than a
    boolean.
    """
    chars = tf.strings.bytes_split(x)
    distinct_chars, _ = tf.unique(chars)
    # TODO: tensorflow function for x.issubset({"A", "B", "C", "D"})
    return distinct_chars

# Intended usage: keep only the sequences accepted by `filter`.
# NOTE(review): Dataset.filter expects the predicate to yield a scalar
# boolean; confirm `filter` returns one before relying on this line.
ds = ds.filter(filter)

您可以使用 tf.lookup.StaticHashTable 和 tf.cond 来实现您想要的效果:

import tensorflow as tf
import numpy as np

def fake_sequence(length=100, mutation_rate=0.001):
    """Return a random string over ``ABCD`` with occasional ``EFGH`` mutations.

    Args:
        length: Number of characters to generate. Defaults to 100.
        mutation_rate: Per-position probability of taking the character from
            ``EFGH`` instead of ``ABCD``. Defaults to 0.001.

    Returns:
        A ``str`` containing ``length`` characters.

    Raises:
        ValueError: If ``mutation_rate`` is not within ``[0, 1]``.
    """
    if not 0.0 <= mutation_rate <= 1.0:
        raise ValueError(
            f"mutation_rate must be within [0, 1], got {mutation_rate!r}"
        )

    # Draw whole arrays at once instead of calling np.random.choice per
    # character.
    seq = np.random.choice(["A", "B", "C", "D"], size=length)
    mutate = np.random.choice(["E", "F", "G", "H"], size=length)
    # True keeps the regular letter, False swaps in the mutated one.
    mask = np.random.choice(
        a=[True, False],
        size=length,
        p=[1.0 - mutation_rate, mutation_rate],
    )
    return "".join(np.where(mask, seq, mutate))


# Generate 100 random sequences and wrap them in a tf.data.Dataset with one
# string element per sequence.
seqs = [fake_sequence() for _ in range(100)]
ds = tf.data.Dataset.from_tensor_slices(seqs)

# Lookup table assigning each of the letters A-H an integer id (1-8); any key
# that is not listed resolves to the default value -1.
keys_tensor = tf.constant(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'])
vals_tensor = tf.constant([1, 2, 3, 4, 5, 6, 7, 8])
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
    default_value=-1)

def filter(x):
    """Return a scalar bool tensor: are all characters of ``x`` in {A, B, C, D}?

    TensorFlow counterpart of ``set(x).issubset({"A", "B", "C", "D"})``.
    Letters are mapped to integer ids through the module-level ``table``;
    characters missing from the table map to its default value (-1) and
    therefore never match an allowed id.

    Args:
        x: A scalar string tensor holding one sequence.

    Returns:
        A scalar boolean tensor, True when every distinct character of ``x``
        is one of the allowed letters. An empty string yields True, matching
        Python's behaviour for the empty set.
    """
    subset = tf.constant(["A", "B", "C", "D"])
    chars = tf.unique(tf.strings.bytes_split(x))[0]
    char_ids = table.lookup(chars)
    allowed_ids = table.lookup(subset)

    def _all_ids_allowed():
        # Broadcast (n, 1) against (1, k): entry [i, j] is True when the i-th
        # distinct character equals the j-th allowed letter. Unlike comparing
        # the two sorted id vectors element-wise, this also works when the
        # sequence uses fewer distinct letters than the allowed set contains
        # (a proper subset), where the shapes would otherwise not match.
        matches = tf.equal(char_ids[:, tf.newaxis], allowed_ids[tf.newaxis, :])
        return tf.reduce_all(tf.reduce_any(matches, axis=1))

    # Cheap early exit: more distinct characters than allowed letters can
    # never be a subset.
    return tf.cond(
        tf.shape(char_ids)[0] > tf.shape(allowed_ids)[0],
        lambda: tf.constant(False),
        _all_ids_allowed,
    )

# `map` replaces each sequence with the value computed by `filter`, which is
# handy for inspecting the predicate's output; use `ds.filter(filter)` instead
# to keep only the sequences that pass.
ds = ds.map(filter)
for x in ds.take(5):
  print(x)

tf.lookup.StaticHashTable只是将所有字母映射为整数值,更容易比较。