Sort TensorFlow HashTable by value

My code:

import tensorflow as tf

h_table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=[0, 1, 2, 3, 4, 5],
        values=[12.3, 11.1, 51.5, 34.3, 87.3, 57.8]
    ),
    # The default must be a float to match the dtype of the values.
    default_value=tf.constant(-1.),
    name="h_table"
)

I want to sort this h_table by value, so that the new hash table is:

keys = [4, 5, 2, 3, 0, 1], values = [87.3, 57.8, 51.5, 34.3, 12.3, 11.1]

The equivalent process in Python is:

h_table = {0: 12.3, 1: 11.1, 2: 51.5, 3: 34.3, 4: 87.3, 5: 57.8}
h_sorted = dict(sorted(h_table.items(), key=lambda x: x[1], reverse=True))

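How can I implement this dictionary operation in TensorFlow using only tensors?

You have to create a new tf.lookup.StaticHashTable, because it cannot be changed once it has been initialized: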

import tensorflow as tf

h_table = tf.lookup.StaticHashTable(
      initializer=tf.lookup.KeyValueTensorInitializer(
          keys=[0, 1, 2, 3, 4, 5],
          values=[12.3, 11.1, 51.5, 34.3, 87.3, 57.8]
      ),
      default_value=tf.constant(-1.),
      name="h_table"
    )

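# Read the original keys and values back from the initializer.
# Note: these are private attributes and may change between TensorFlow versions.
keys = h_table._initializer._keys
values = h_table._initializer._values

# Indices that order the values from largest to smallest.
value_indices = tf.argsort(values, direction="DESCENDING")
keys = tf.gather(keys, value_indices)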

new_h_table = tf.lookup.StaticHashTable(
      initializer=tf.lookup.KeyValueTensorInitializer(
          keys=keys,
          values=h_table.lookup(keys)
      ),
      default_value=tf.constant(-1.),
      name="new_h_table"
    )

print(new_h_table._initializer._keys)
print(new_h_table._initializer._values)

Output:

tf.Tensor([4 5 2 3 0 1], shape=(6,), dtype=int32)
tf.Tensor([87.3 57.8 51.5 34.3 12.3 11.1], shape=(6,), dtype=float32)
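
If you still have the original key and value tensors at hand, you can avoid the private attributes altogether. The following is only a minimal sketch under that assumption: sort_table_by_value is a hypothetical helper name (not part of TensorFlow), and it relies on tf.argsort with direction="DESCENDING" plus tf.gather to reorder both tensors before building the new table.

```python
import tensorflow as tf


def sort_table_by_value(keys, values, default_value=-1.0, descending=True):
    """Hypothetical helper: build a new StaticHashTable from keys/values
    reordered by value. It does not modify any existing table."""
    keys = tf.convert_to_tensor(keys)
    values = tf.convert_to_tensor(values)

    # Indices that would sort the values in the requested direction.
    direction = "DESCENDING" if descending else "ASCENDING"
    order = tf.argsort(values, direction=direction)

    # Apply the same permutation to keys and values so pairs stay aligned.
    sorted_keys = tf.gather(keys, order)
    sorted_values = tf.gather(values, order)

    table = tf.lookup.StaticHashTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=sorted_keys,
            values=sorted_values,
        ),
        default_value=default_value,
    )
    return table, sorted_keys, sorted_values


table, sorted_keys, sorted_values = sort_table_by_value(
    keys=[0, 1, 2, 3, 4, 5],
    values=[12.3, 11.1, 51.5, 34.3, 87.3, 57.8],
)
print(sorted_keys.numpy())
print(sorted_values.numpy())
# Key 4 exists in the table; key 9 does not and falls back to the default.
print(table.lookup(tf.constant([4, 9])).numpy())
```

Keep in mind that lookups are done by key, so the order only shows up in the key and value tensors you pass in, not in what table.lookup returns. If you need the ordering later, it is probably simplest to keep sorted_keys and sorted_values around next to the table.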