如何在 TensorFlow 中将 multi-class one-hot 张量转换为 RGB?

How to convert multi-class one-hot tensor to RGB in TensorFlow?

我有一个形状为 [None, 128, 128, n_classes] 的张量。这是一个 one-hot tensor,其中最后一个索引包含多个 class 的分类值(总共有 n_classes)。 实际上,最后一个通道具有二进制值,表示每个像素的 class:例如当一个像素在通道C中有1时,表示它属于class C;该像素在其他地方将为 0。

现在,我想将这个单热张量转换为 RGB 图像,我想在 Tensorboard 上绘制它。每个 class 都必须与不同的颜色相关联,以便更容易理解。

知道怎么做吗?

谢谢,G.


编辑 2:

答案中添加了解决方案。


编辑 1:

我当前的实现(不工作):

def from_one_hot_to_rgb(incoming, palette=None):
    """ Assign a different color to each class in the input tensor """
    if palette is None:
        palette = {
            0: (0, 0, 0),
            1: (31, 12, 33),
            2: (13, 26, 33),
            3: (21, 76, 22),
            4: (22, 54, 66)
        }

    def _colorize(value):
        return palette[value]

    # from one-hot to grayscale:
    cmap = tf.expand_dims(tf.argmax(incoming, axis=-1), axis=-1)

    # flatten input tensor (pixels on the first axis):
    B, W, H, C = get_shape(camp)  # this returns batch_size, 128, 128, 5
    cmap_flat = tf.reshape(cmap, shape=[B * W * H, C])

    # assign a different color to each class:
    cmap = tf.map_fn(lambda pixel:
                     tf.py_func(_colorize, inp=[pixel], Tout=tf.int64),
                     cmap_flat)

    # back to original shape, but RGB output:
    cmap = tf.reshape(cmap, shape=[B, W, H, 3])

    return tf.cast(cmap, dtype=tf.float32)

我会用imshow* or matshow* from matplotlib to create the plot and then use 或同一问题的其他答案在张量板上显示。

import matplotlib.pyplot as plt

plt.imshow(tf.argmax(imgs[0], axis=-1))

这种方法的优点之一是您不必担心 class 到颜色的映射。


修复你已有的代码,首先你应该注意传递给colorize的参数是一个长度为1的numpy数组而不是一个int;这是不可哈希的,因此不能用于字典键。您可以将其转换为 int 类型,就像 palette[int(value)].

我在这里和那里更改了您的代码中的一些内容,并在大小为 1 的随机批次上对其进行了测试,最终代码如下所示:

def from_one_hot_to_rgb(incoming, palette=None):
    """ Assign a different color to each class in the input tensor """
    if palette is None:
        palette = {i: tf.constant(color, dtype='int64') for i, color in enumerate(
            ((0, 0, 0),
            (31, 12, 33),
            (13, 26, 33),
            (21, 76, 22),
            (22, 54, 66))
        )}

    # from one-hot to grayscale:
    B, W, H, _ = incoming.get_shape()   # this returns batch_size, 128, 128, 5
    cmap = tf.reshape(tf.argmax(incoming, axis=-1), [-1, 1])
    cmap = tf.map_fn(lambda value: palette[int(value)], cmap)

    # back to original shape, but RGB output:
    cmap = tf.reshape(cmap, shape=[B, W, H, 3])

    return tf.cast(cmap, dtype=tf.float32)

解决方案 1(慢)

一个可能的解决方案,类似于初始代码如下。请注意,由于 TensorFlow tf.map_fn

的已知 problem,这可能会非常慢
def from_one_hot_to_rgb_bkup(incoming, palette=None):

    if palette is None:
        palette = {i: tf.constant(color, dtype='int64') for i, color in enumerate(
            ((0, 0, 0),
            (31, 12, 33),
            (13, 26, 33),
            (21, 76, 22),
            (22, 54, 66))
        )}

    # from one-hot to grayscale:
    B, W, H, _ = get_shape(incoming)
    gray = tf.reshape(tf.argmax(incoming, axis=-1, output_type=tf.int32), [-1, 1], name='flatten')

    # assign colors to each class
    rgb = tf.map_fn(lambda pixel:
                    tf.py_func(lambda value: palette[int(value)], inp=[pixel], Tout=tf.int32),
                    gray, name='colorize')

    # back to original shape, but RGB output:
    rgb = tf.reshape(rgb, shape=[B, W, H, 3], name='back_to_rgb')

    return tf.cast(rgb, dtype=tf.float32)

解决方案 2(快速)

基于答案,更快的解决方案可以使用tf.gather:

def from_one_hot_to_rgb_bkup(incoming, palette=None):

    if palette is None:
        palette = {i: tf.constant(color, dtype='int64') for i, color in enumerate(
            ((0, 0, 0),
            (31, 12, 33),
            (13, 26, 33),
            (21, 76, 22),
            (22, 54, 66))
        )}

    _, W, H, _ = get_shape(incoming)
    palette = tf.constant(palette, dtype=tf.uint8)
    class_indexes = tf.argmax(incoming, axis=-1)

    class_indexes = tf.reshape(class_indexes, [-1])
    color_image = tf.gather(palette, class_indexes)
    color_image = tf.reshape(color_image, [-1, W, H, 3])

    color_image = tf.cast(color_image, dtype=tf.float32)