Python TensorFlow:在数据集过滤(Dataset.filter)中实现集合的 .issubset() 判断
Python Tensorflow Dataset Filter Set .issubset()
我有一个 TensorFlow 数据集:
def fake_sequence():
    """Return a random 100-character string over A-D with rare E-H letters.

    For every position a base letter (A-D) and a mutation letter (E-H) are
    drawn; the base letter is kept with probability 0.999, otherwise the
    mutation letter drawn for that position is used.
    """
    length = 100
    bases, mutations = ["A", "B", "C", "D"], ["E", "F", "G", "H"]
    original = [np.random.choice(bases) for _ in range(length)]
    replacement = [np.random.choice(mutations) for _ in range(length)]
    # True keeps the base letter at that position, False swaps in the mutation.
    keep = np.random.choice(a=[True, False], size=length, p=[0.999, 0.001])
    picked = np.where(keep, original, replacement)
    return "".join(picked)
# Draw 100 random sequences and wrap them in a dataset whose elements are
# the individual strings.
seqs = [fake_sequence() for _ in range(100)]
ds = tf.data.Dataset.from_tensor_slices(seqs)
我想使用以下 pythonic 函数对其进行过滤:
def python_filter(x):
    """Return True when every character of ``x`` is one of A, B, C or D."""
    allowed = {"A", "B", "C", "D"}
    return set(x) <= allowed
不幸的是,用 @tf.function
装饰不起作用。有哪位大侠能帮帮我吗?这是我目前所拥有的。
def filter(x):
    """Split ``x`` into single bytes and return the distinct ones.

    Work in progress: the subset test itself is not implemented yet, so the
    distinct characters are returned instead of a boolean.
    """
    chars = tf.strings.bytes_split(x)
    distinct, _ = tf.unique(chars)
    # tensorflow function for x.issubset({"A", "B", "C", "D"})
    return distinct
# Dataset.filter expects the predicate to return a scalar boolean; the draft
# `filter` above still returns the distinct characters instead.
ds = ds.filter(filter)
您可以使用 tf.lookup.StaticHashTable
和 tf.cond
来解决您想要的问题:
import tensorflow as tf
import numpy as np
def fake_sequence():
    """Return a random 100-character string over A-D with rare E-H letters.

    For every position a base letter (A-D) and a mutation letter (E-H) are
    drawn; the base letter is kept with probability 0.999, otherwise the
    mutation letter drawn for that position is used.
    """
    length = 100
    bases, mutations = ["A", "B", "C", "D"], ["E", "F", "G", "H"]
    original = [np.random.choice(bases) for _ in range(length)]
    replacement = [np.random.choice(mutations) for _ in range(length)]
    # True keeps the base letter at that position, False swaps in the mutation.
    keep = np.random.choice(a=[True, False], size=length, p=[0.999, 0.001])
    picked = np.where(keep, original, replacement)
    return "".join(picked)
# Draw 100 random sequences and wrap them in a dataset whose elements are
# the individual strings.
seqs = [fake_sequence() for _ in range(100)]
ds = tf.data.Dataset.from_tensor_slices(seqs)
# Static lookup table mapping each letter A-H to an integer id (1-8); any key
# that is not listed resolves to the default value -1.
keys_tensor = tf.constant(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'])
vals_tensor = tf.constant([1, 2, 3, 4, 5, 6, 7, 8])
table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
default_value=-1)
def filter(x):
    """Return a scalar bool: are all characters of ``x`` in {A, B, C, D}?

    Mirrors ``set(x).issubset({"A", "B", "C", "D"})`` for a scalar string
    tensor. The previous version compared the sorted ids for equality, which
    only held when all four letters were present and failed (shape mismatch
    or a wrong ``False``) for sequences using fewer distinct letters.

    Args:
        x: scalar string tensor (one dataset element).

    Returns:
        Scalar boolean tensor; True for the empty string as well.
    """
    subset = tf.constant(["A", "B", "C", "D"])
    distinct = tf.unique(tf.strings.bytes_split(x))[0]
    # Compare integer ids from the module-level `table`; characters missing
    # from the table get its default id and therefore never match the subset.
    ids = table.lookup(distinct)
    allowed_ids = table.lookup(subset)

    def _every_id_allowed():
        # [k, 1] vs [n] broadcasts to a [k, n] membership matrix.
        matches = tf.equal(tf.expand_dims(ids, -1), allowed_ids)
        return tf.reduce_all(tf.reduce_any(matches, axis=-1))

    # More distinct characters than allowed ones can never be a subset.
    return tf.cond(
        tf.shape(ids)[0] > tf.shape(allowed_ids)[0],
        lambda: tf.constant(False),
        _every_id_allowed,
    )
# Evaluate the predicate on every element and print the first five results.
ds = ds.map(filter)
for flag in ds.take(5):
    print(flag)
tf.lookup.StaticHashTable
只是将所有字母映射为整数值,更容易比较。
我有一个 TensorFlow 数据集:
def fake_sequence():
    """Return a random 100-character string over A-D with rare E-H letters.

    For every position a base letter (A-D) and a mutation letter (E-H) are
    drawn; the base letter is kept with probability 0.999, otherwise the
    mutation letter drawn for that position is used.
    """
    length = 100
    bases, mutations = ["A", "B", "C", "D"], ["E", "F", "G", "H"]
    original = [np.random.choice(bases) for _ in range(length)]
    replacement = [np.random.choice(mutations) for _ in range(length)]
    # True keeps the base letter at that position, False swaps in the mutation.
    keep = np.random.choice(a=[True, False], size=length, p=[0.999, 0.001])
    picked = np.where(keep, original, replacement)
    return "".join(picked)
# Draw 100 random sequences and wrap them in a dataset whose elements are
# the individual strings.
seqs = [fake_sequence() for _ in range(100)]
ds = tf.data.Dataset.from_tensor_slices(seqs)
我想使用以下 pythonic 函数对其进行过滤:
def python_filter(x):
    """Return True when every character of ``x`` is one of A, B, C or D."""
    allowed = {"A", "B", "C", "D"}
    return set(x) <= allowed
不幸的是,用 @tf.function
装饰不起作用。有哪位大侠能帮帮我吗?这是我目前所拥有的。
def filter(x):
    """Split ``x`` into single bytes and return the distinct ones.

    Work in progress: the subset test itself is not implemented yet, so the
    distinct characters are returned instead of a boolean.
    """
    chars = tf.strings.bytes_split(x)
    distinct, _ = tf.unique(chars)
    # tensorflow function for x.issubset({"A", "B", "C", "D"})
    return distinct
# Dataset.filter expects the predicate to return a scalar boolean; the draft
# `filter` above still returns the distinct characters instead.
ds = ds.filter(filter)
您可以使用 tf.lookup.StaticHashTable
和 tf.cond
来解决您想要的问题:
import tensorflow as tf
import numpy as np
def fake_sequence():
    """Return a random 100-character string over A-D with rare E-H letters.

    For every position a base letter (A-D) and a mutation letter (E-H) are
    drawn; the base letter is kept with probability 0.999, otherwise the
    mutation letter drawn for that position is used.
    """
    length = 100
    bases, mutations = ["A", "B", "C", "D"], ["E", "F", "G", "H"]
    original = [np.random.choice(bases) for _ in range(length)]
    replacement = [np.random.choice(mutations) for _ in range(length)]
    # True keeps the base letter at that position, False swaps in the mutation.
    keep = np.random.choice(a=[True, False], size=length, p=[0.999, 0.001])
    picked = np.where(keep, original, replacement)
    return "".join(picked)
# Draw 100 random sequences and wrap them in a dataset whose elements are
# the individual strings.
seqs = [fake_sequence() for _ in range(100)]
ds = tf.data.Dataset.from_tensor_slices(seqs)
# Static lookup table mapping each letter A-H to an integer id (1-8); any key
# that is not listed resolves to the default value -1.
keys_tensor = tf.constant(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'])
vals_tensor = tf.constant([1, 2, 3, 4, 5, 6, 7, 8])
table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
default_value=-1)
def filter(x):
    """Return a scalar bool: are all characters of ``x`` in {A, B, C, D}?

    Mirrors ``set(x).issubset({"A", "B", "C", "D"})`` for a scalar string
    tensor. The previous version compared the sorted ids for equality, which
    only held when all four letters were present and failed (shape mismatch
    or a wrong ``False``) for sequences using fewer distinct letters.

    Args:
        x: scalar string tensor (one dataset element).

    Returns:
        Scalar boolean tensor; True for the empty string as well.
    """
    subset = tf.constant(["A", "B", "C", "D"])
    distinct = tf.unique(tf.strings.bytes_split(x))[0]
    # Compare integer ids from the module-level `table`; characters missing
    # from the table get its default id and therefore never match the subset.
    ids = table.lookup(distinct)
    allowed_ids = table.lookup(subset)

    def _every_id_allowed():
        # [k, 1] vs [n] broadcasts to a [k, n] membership matrix.
        matches = tf.equal(tf.expand_dims(ids, -1), allowed_ids)
        return tf.reduce_all(tf.reduce_any(matches, axis=-1))

    # More distinct characters than allowed ones can never be a subset.
    return tf.cond(
        tf.shape(ids)[0] > tf.shape(allowed_ids)[0],
        lambda: tf.constant(False),
        _every_id_allowed,
    )
# Evaluate the predicate on every element and print the first five results.
ds = ds.map(filter)
for flag in ds.take(5):
    print(flag)
tf.lookup.StaticHashTable
只是将所有字母映射为整数值,更容易比较。