在 TensorFlow 中查找 table,键为字符串,值为字符串列表

Lookup table in TensorFlow with key is string and value is list of strings

我想在 TensorFlow 中生成一个查找 table,键是字符串,值是字符串列表。但是目前 tf.lookup 中似乎没有 类 支持这个。有什么想法吗?

我认为没有针对该用例的实现,但您可以尝试结合 tf.lookup.StaticHashTabletf.gather 来创建您自己的自定义查找 table。您只需要确保您的键和字符串列表的顺序正确。例如keya对应第一个字符串列表,keyb对应第二个字符串列表,以此类推。这是一个工作示例:

class TensorLookup:
  def __init__(self, keys, strings):
    self.keys = keys
    self.strings = strings
    self.table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(self.keys, tf.range(tf.shape(self.keys)[0])),
    default_value=-1)
  
  def lookup(self, key):
    index = self.table.lookup(key)
    return tf.cond(tf.reduce_all(tf.equal(index, -1)), lambda: tf.constant(['']), lambda: tf.gather(self.strings, index))

keys = tf.constant(['a', 'b', 'c', 'd', 'e'])
strings = tf.ragged.constant([['fish', 'eating', 'cats'], 
                              ['cats', 'everywhere'], 
                              ['you', 'are', 'a', 'fine', 'lad'], 
                              ['a', 'mountain', 'over', 'there'],
                              ['bravo', 'at', 'charlie'] 
                              ])

tensor_dict = TensorLookup(keys = keys, strings = strings)

print(tensor_dict.lookup(tf.constant('a')))
print(tensor_dict.lookup(tf.constant('b')))
print(tensor_dict.lookup(tf.constant('c')))
print(tensor_dict.lookup(tf.constant('r'))) # expected empty value since the r key does not exist
tf.Tensor([b'fish' b'eating' b'cats'], shape=(3,), dtype=string)
tf.Tensor([b'cats' b'everywhere'], shape=(2,), dtype=string)
tf.Tensor([b'you' b'are' b'a' b'fine' b'lad'], shape=(5,), dtype=string)
tf.Tensor([b''], shape=(1,), dtype=string)

我故意使用参差不齐的张量来适应不同长度的字符串列表。