Python Tensorflow itertools groupby:在 tf.data.Dataset.filter() 中使用 itertools.groupby()
Python Tensorflow itertools groupby: using itertools.groupby() in tf.data.Dataset.filter()
我正在尝试对 tf.data.Dataset
应用过滤器,它会删除其中一组 > 50% 的所有字符串。这是我的 Dataset
:
import tensorflow as tf
strings = [
["ABCDEFGABCDEFG\tUseless\tLabel1"],
["AAAAAAAADEFGAB\tUseless\tLabel2"],
["HIJKLMNHIJKLMN\tUseless\tLabel3"],
["HIJKLMMMMMMMNH\tUseless\tLabel4"],
]
ds = tf.data.Dataset.from_tensor_slices(strings)
def _clean(x):
x = tf.strings.split(x, "\t")
return x[0], x[2]
def _filter(x):
s = tf.strings.bytes_split(x)
_, _, count = tf.unique_with_counts(s)
percent = tf.reduce_max(count) / tf.shape(s)[0]
return tf.less_equal(percent, 0.5)
ds = ds.map(_clean)
ds = ds.filter(lambda x, y: _filter(x))
for x, y in ds:
tf.print(x, y)
这会产生以下错误:
TypeError: Failed to convert elements of tf.RaggedTensor(values=Tensor("StringsByteSplit/StringSplit:1", shape=(None,), dtype=string), row_splits=Tensor("StringsByteSplit/RaggedFromValueRowIds/RowPartitionFromValueRowIds/concat:0", shape=(None,), dtype=int64)) to Tensor. Consider casting elements to a supported type.
有什么方法可以在 tf.data.Dataset
图中解决这个问题?
您可以使用 tf.strings
:
来解决这个问题
import tensorflow as tf
def filter_data(x):
s = tf.strings.strip(tf.strings.regex_replace(x, '', ' '))
s = tf.strings.split(s, sep=" ")
_, _, count = tf.unique_with_counts(s)
return tf.less_equal(tf.reduce_max(count) / tf.shape(s)[0], 0.25)
ds = tf.data.Dataset.from_tensor_slices([["AAAABBBCC", "Label1"], ["AAAAAABC", "Label2"], ["ABBAABCCCCAB", "Label3"], ["ABDC", "Label4"]])
ds = ds.map(lambda x: (x[0], x[1]))
ds = ds.filter(lambda x, y: filter_data(x))
for x, y in ds:
tf.print(x, y)
"ABDC" "Label4"
但是,我会重新考虑 25% 的阈值,因为示例数据集中的所有样本都高于此阈值,因此不会添加到数据集中。我已将第四个示例添加到您的数据集中,以表明该方法适用于 tf.less_equal
.
以AAAABBBCC
为例,A
出现次数最多(4次)除以字符串总长度(9),得到4/9=0.44
,意思是被排除在数据集中。也许这种行为是需要的。不管怎样,我只是想告诉你这件事。
我正在尝试对 tf.data.Dataset
应用过滤器,它会删除其中一组 > 50% 的所有字符串。这是我的 Dataset
:
import tensorflow as tf
strings = [
["ABCDEFGABCDEFG\tUseless\tLabel1"],
["AAAAAAAADEFGAB\tUseless\tLabel2"],
["HIJKLMNHIJKLMN\tUseless\tLabel3"],
["HIJKLMMMMMMMNH\tUseless\tLabel4"],
]
ds = tf.data.Dataset.from_tensor_slices(strings)
def _clean(x):
x = tf.strings.split(x, "\t")
return x[0], x[2]
def _filter(x):
s = tf.strings.bytes_split(x)
_, _, count = tf.unique_with_counts(s)
percent = tf.reduce_max(count) / tf.shape(s)[0]
return tf.less_equal(percent, 0.5)
ds = ds.map(_clean)
ds = ds.filter(lambda x, y: _filter(x))
for x, y in ds:
tf.print(x, y)
这会产生以下错误:
TypeError: Failed to convert elements of tf.RaggedTensor(values=Tensor("StringsByteSplit/StringSplit:1", shape=(None,), dtype=string), row_splits=Tensor("StringsByteSplit/RaggedFromValueRowIds/RowPartitionFromValueRowIds/concat:0", shape=(None,), dtype=int64)) to Tensor. Consider casting elements to a supported type.
有什么方法可以在 tf.data.Dataset
图中解决这个问题?
您可以使用 tf.strings
:
import tensorflow as tf
def filter_data(x):
s = tf.strings.strip(tf.strings.regex_replace(x, '', ' '))
s = tf.strings.split(s, sep=" ")
_, _, count = tf.unique_with_counts(s)
return tf.less_equal(tf.reduce_max(count) / tf.shape(s)[0], 0.25)
ds = tf.data.Dataset.from_tensor_slices([["AAAABBBCC", "Label1"], ["AAAAAABC", "Label2"], ["ABBAABCCCCAB", "Label3"], ["ABDC", "Label4"]])
ds = ds.map(lambda x: (x[0], x[1]))
ds = ds.filter(lambda x, y: filter_data(x))
for x, y in ds:
tf.print(x, y)
"ABDC" "Label4"
但是,我会重新考虑 25% 的阈值,因为示例数据集中的所有样本都高于此阈值,因此不会添加到数据集中。我已将第四个示例添加到您的数据集中,以表明该方法适用于 tf.less_equal
.
以AAAABBBCC
为例,A
出现次数最多(4次)除以字符串总长度(9),得到4/9=0.44
,意思是被排除在数据集中。也许这种行为是需要的。不管怎样,我只是想告诉你这件事。