根据另一个张量中的索引将张量中的每个值映射到新值

Mapping each value in a tensor to a new value depending on its index in another tensor

我正在使用 Tensorflow 2.0。我有一个值在 0 到 255 范围内的 (256 x 256) 张量,我们称它为 gray。每个值都是 10 个唯一值之一。我有另一个张量,uniqueValues,包含 10 个唯一值。我试图找到一种方法来创建一个新的 (256 x 256) 张量,result 其中 result 的第 i,j 个值等于 uniqueValues 的索引,其中gray 的第 i,j 个值出现:

  gray = tf.image.decode_png(png, channels=1)
  flattened = tf.reshape(gray, [-1])

  # creates a tensor of length 10 holding each unique value
  uniqueValues, idx = tf.unique(flattened)
  gray = tf.reshape(gray, (256, 256))

  # Convert the gray (256x256) tensor...
  # [[255 255 255 ... 255
  # ...
  #  255 15 15 ... 200]]

  # using 'uniqueValues'...
  # [ 15 200 255 ]

  # To result (256x256) tensor...
  # [[2 2 2 ... 2
  # ...
  #  2 0 0 ... 1 ]]

  # possibly using the tf.map_fn?
  result = tf.map_fn( # how to do this part?, gray)

  # now I can create the one-hot version of gray
  oneHot = tf.one_hot(result, 10)

一直在研究 tf.wheretf.equal,但我似乎无法让它工作。

为了防止其他人遇到这个问题,这里有一个基于使用 StaticHashTable:

的解决方案
import tensorflow as tf

# define mapping from keys to values...
lookupTable = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant([0, 76, 78, 117, 178, 202, 211, 225, 242, 255]),
        values=tf.constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    ),
    default_value=tf.constant(0)
)

  gray = tf.image.decode_png(png, channels=1)

  # cast source from uint8 to int32 because StaticHashMap only works 
  # with restricted set of types
  gray = tf.dtypes.cast(tf.reshape(gray, (256, 256)), tf.int32)

  # voila, works like a charm!
  result = tf.map_fn(lambda x: lookupTable.lookup(x), gray)
  oneHot = tf.one_hot(result, 10)