在 TensorFlow 模型中的每一行上使用 softmax 激活输出矩阵

Outputting matrix with softmax activation on each row in TensorFlow model

我正在构建一个 TensorFlow (2.0) 模型,该模型将图像(30x100 矩阵)作为输入并希望能够得到形式为

的输出(和标签)
[
  [0.0, 0.3, 0.7, 0.0],
  [1.0, 0.0, 0.0, 0.0],
  [0.2, 0.2, 0.4, 0.2]
]

即每一行都有一个单独的 softmax 激活,这意味着它们总和为 1。在训练数据中,每行中只有一个元素为 1,其余为 0,我相信这被称为 one-hot 编码。因此我的问题是;如何配置最后一层(和损失函数)以在我的标签上进行单热编码?

假设行彼此独立,在我看来你的情况可以是多输出情况。您可以有 3 个输出,每个输出有 5 个值。

input1 = Input(shape = (input_shape))
# some layers
x = Dense(512, activation='relu')(input1)
# .....

outputs1 = Dense(5, activation='softmax', name="row1")(x)
outputs2 = Dense(5, activation='softmax', name="row2")(x)
outputs3 = Dense(5, activation='softmax', name="row3")(x)
model = Model(input1, [outputs1, outputs2, outputs3])

它看起来是多标签分类的变体。为了方便起见,我们可以连接 3 个 softmax 层的输出并使用二元交叉熵损失。对于预测,我们可以重塑连接的输出。

工作代码

inputs = Input(shape=(2,))
output = Dense(4, activation='relu')(inputs)

output_1 = Dense(4, activation='softmax')(output)
output_2 = Dense(4, activation='softmax')(output)
output_3 = Dense(4, activation='softmax')(output)

# concatenate the outputs
output = concatenate([output_1, output_2, output_3], axis=1)

model = Model(inputs=inputs, outputs=output)
model.compile(loss='binary_crossentropy', optimizer='adam') 

# Two training examples each of 2 features
x = np.array([[1,2],
              [2,1]])

# Output labels
y = np.array([[[0,0,0,1],[0,0,1,0],[0,1,0,0]],
              [[0,0,1,0],[0,0,1,0],[0,0,0,1]]])

# Fatten the y per sample to match the model output shape
model.fit(x, y.reshape(len(x), -1))

# Predications
y_hat = model.predict(x)
y_hat = y_hat.reshape(len(x),-1,4)

print (y_hat)

输出:

array([[[0.0418521 , 0.63207364, 0.06171123, 0.26436302],
        [0.54364955, 0.19503883, 0.09884372, 0.16246797],
        [0.06045745, 0.09223039, 0.7325132 , 0.11479893]],

       [[0.05648099, 0.40420422, 0.12369599, 0.41561884],
        [0.64175993, 0.14215547, 0.07769462, 0.13838997],
        [0.07918497, 0.1764104 , 0.57678604, 0.1676186 ]]], dtype=float32)

您还可以使用 y_hat.sum(axis=2)

验证每一行总和是否为 1