你如何解码 Tensorflow 中的 one-hot 标签?

How do you decode one-hot labels in Tensorflow?

一直在寻找,但似乎找不到任何示例说明如何在 TensorFlow 中从单热值解码或转换回单个整数。

我使用了 tf.one_hot 并能够训练我的模型,但我对如何在分类后理解标签感到困惑。我的数据是通过我创建的 TFRecords 文件输入的。我考虑过在文件中存储一个文本标签,但无法让它工作。似乎 TFRecords 无法存储文本字符串或者我弄错了。

您可以使用tf.argmax找出矩阵中最大元素的索引。由于您的一个热向量将是一维的,并且只有一个 1 和其他 0,因此假设您正在处理单个向量,这将起作用。

index = tf.argmax(one_hot_vector, axis=0)

对于batch_size * num_classes更标准的矩阵,使用axis=1得到大小为batch_size * 1的结果。

由于 one-hot 编码通常只是一个具有 batch_size 行和 num_classes 列的矩阵,并且每一行都是零,并且有一个非零值对应于所选择的 class,你可以使用tf.argmax()来恢复整数标签的向量:

BATCH_SIZE = 3
NUM_CLASSES = 4
one_hot_encoded = tf.constant([[0, 1, 0, 0],
                               [1, 0, 0, 0],
                               [0, 0, 0, 1]])

# Compute the argmax across the columns.
decoded = tf.argmax(one_hot_encoded, axis=1)

# ...
print sess.run(decoded)  # ==> array([1, 0, 3])
data = np.array([1, 5, 3, 8])
print(data)


def encode(data):
    print('Shape of data (BEFORE encode): %s' % str(data.shape))
    encoded = to_categorical(data)
    print('Shape of data (AFTER  encode): %s\n' % str(encoded.shape))
    return encoded


encoded_data = encode(data)
print(encoded_data)

def decode(datum):
    return np.argmax(datum)

decoded_Y = []
print("****************************************")
for i in range(encoded_data.shape[0]):
    datum = encoded_data[i]
    print('index: %d' % i)
    print('encoded datum: %s' % datum)
    decoded_datum = decode(encoded_data[i])
    print('decoded datum: %s' % decoded_datum)
    decoded_Y.append(decoded_datum)


print("****************************************")

print(decoded_Y)

tf.argmax is depreciated (all links within the answers on this page are thus 404) and now tf.math.argmax should be used .

用法:

import tensorflow as tf
a = [1, 10, 26.9, 2.8, 166.32, 62.3]
b = tf.math.argmax(input = a)
c = tf.keras.backend.eval(b)
# c = 4
# here a[4] = 166.32 which is the largest element of a across axis 0

注意:你也可以用numpy来做到这一点。