在 TensorFlow 中查找 table,键为字符串,值为字符串列表
Lookup table in TensorFlow with key is string and value is list of strings
我想在 TensorFlow 中生成一个查找 table,键是字符串,值是字符串列表。但是目前 tf.lookup 中似乎没有 类 支持这个。有什么想法吗?
我认为没有针对该用例的实现,但您可以尝试结合 tf.lookup.StaticHashTable
和 tf.gather
来创建您自己的自定义查找 table。您只需要确保您的键和字符串列表的顺序正确。例如keya
对应第一个字符串列表,keyb
对应第二个字符串列表,以此类推。这是一个工作示例:
class TensorLookup:
def __init__(self, keys, strings):
self.keys = keys
self.strings = strings
self.table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(self.keys, tf.range(tf.shape(self.keys)[0])),
default_value=-1)
def lookup(self, key):
index = self.table.lookup(key)
return tf.cond(tf.reduce_all(tf.equal(index, -1)), lambda: tf.constant(['']), lambda: tf.gather(self.strings, index))
keys = tf.constant(['a', 'b', 'c', 'd', 'e'])
strings = tf.ragged.constant([['fish', 'eating', 'cats'],
['cats', 'everywhere'],
['you', 'are', 'a', 'fine', 'lad'],
['a', 'mountain', 'over', 'there'],
['bravo', 'at', 'charlie']
])
tensor_dict = TensorLookup(keys = keys, strings = strings)
print(tensor_dict.lookup(tf.constant('a')))
print(tensor_dict.lookup(tf.constant('b')))
print(tensor_dict.lookup(tf.constant('c')))
print(tensor_dict.lookup(tf.constant('r'))) # expected empty value since the r key does not exist
tf.Tensor([b'fish' b'eating' b'cats'], shape=(3,), dtype=string)
tf.Tensor([b'cats' b'everywhere'], shape=(2,), dtype=string)
tf.Tensor([b'you' b'are' b'a' b'fine' b'lad'], shape=(5,), dtype=string)
tf.Tensor([b''], shape=(1,), dtype=string)
我故意使用参差不齐的张量来适应不同长度的字符串列表。
我想在 TensorFlow 中生成一个查找 table,键是字符串,值是字符串列表。但是目前 tf.lookup 中似乎没有 类 支持这个。有什么想法吗?
我认为没有针对该用例的实现,但您可以尝试结合 tf.lookup.StaticHashTable
和 tf.gather
来创建您自己的自定义查找 table。您只需要确保您的键和字符串列表的顺序正确。例如keya
对应第一个字符串列表,keyb
对应第二个字符串列表,以此类推。这是一个工作示例:
class TensorLookup:
def __init__(self, keys, strings):
self.keys = keys
self.strings = strings
self.table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(self.keys, tf.range(tf.shape(self.keys)[0])),
default_value=-1)
def lookup(self, key):
index = self.table.lookup(key)
return tf.cond(tf.reduce_all(tf.equal(index, -1)), lambda: tf.constant(['']), lambda: tf.gather(self.strings, index))
keys = tf.constant(['a', 'b', 'c', 'd', 'e'])
strings = tf.ragged.constant([['fish', 'eating', 'cats'],
['cats', 'everywhere'],
['you', 'are', 'a', 'fine', 'lad'],
['a', 'mountain', 'over', 'there'],
['bravo', 'at', 'charlie']
])
tensor_dict = TensorLookup(keys = keys, strings = strings)
print(tensor_dict.lookup(tf.constant('a')))
print(tensor_dict.lookup(tf.constant('b')))
print(tensor_dict.lookup(tf.constant('c')))
print(tensor_dict.lookup(tf.constant('r'))) # expected empty value since the r key does not exist
tf.Tensor([b'fish' b'eating' b'cats'], shape=(3,), dtype=string)
tf.Tensor([b'cats' b'everywhere'], shape=(2,), dtype=string)
tf.Tensor([b'you' b'are' b'a' b'fine' b'lad'], shape=(5,), dtype=string)
tf.Tensor([b''], shape=(1,), dtype=string)
我故意使用参差不齐的张量来适应不同长度的字符串列表。