Keras 中的自定义损失与 softmax 到 one-hot

Custom loss in Keras with softmax to one-hot

我有一个输出 Softmax 的模型,我想开发一个自定义损失函数。期望的行为是:

1) Softmax 到 one-hot(通常我做 numpy.argmax(softmax_vector) 并在空向量中将该索引设置为 1,但这在损失函数中是不允许的)。

2) 将生成的单热向量乘以我的嵌入矩阵以获得嵌入向量(在我的上下文中:与给定词关联的词向量,其中词已被标记化并分配给索引,或 类 用于 Softmax 输出)。

3) 将这个向量与目标进行比较(这可能是一个正常的 Keras 损失函数)。

我知道一般如何编写自定义损失函数,但不会这样做。我发现了这个 closely related question(未回答),但我的情况有点不同,因为我想保留我的 softmax 输出。

可以在客户损失函数中混合使用 tensorflow 和 keras。一旦您可以访问所有 Tensorflow 功能,事情就会变得非常简单。我只是给你一个例子来说明这个函数是如何实现的。

import tensorflow as tf
def custom_loss(target, softmax):
    max_indices = tf.argmax(softmax, -1)

    # Get the embedding matrix. In Tensorflow, this can be directly done
    # with tf.nn.embedding_lookup
    embedding_vectors = tf.nn.embedding_lookup(you_embedding_matrix, max_indices)

    # Do anything you want with normal keras loss function
    loss = some_keras_loss_function(target, embedding_vectors)

    loss = tf.reduce_mean(loss)
    return loss

凡洛的回答指向了正确的方向,但最终行不通,因为涉及到不可导的操作。请注意,此类操作对于真实值是可以接受的(损失函数采用真实值和预测值,不可推导的操作仅适用于真实值)。

公平地说,这就是我首先要问的。 不可能做我想做的事,但我们可以得到类似的可推导的行为:

1) softmax 值的逐元素幂。这使得较小的值小得多。例如,4 的幂 [0.5, 0.2, 0.7] 变为 [0.0625, 0.0016, 0.2400]。请注意,0.2 与 0.7 相当,但 0.0016 相对于 0.24 可忽略不计。 my_power越高,最终结果越接近one-hot

soft_extreme = Lambda(lambda x: x ** my_power)(softmax)

2) 重要的是,softmax 和 one-hot 向量都被归一化了,但我们的 "soft_extreme" 没有。首先求数组的和:

norm = tf.reduce_sum(soft_extreme, 1)

3) 归一化 soft_extreme:

almost_one_hot = Lambda(lambda x: x / norm)(soft_extreme)

注意:在 1) 中将 my_power 设置得太高将导致 NaN。如果您需要更好的 softmax 到 one-hot 转换,那么您可以连续执行步骤 1 到 3 两次或更多次。

4) 最后我们需要字典中的向量。禁止查找,但我们可以使用矩阵乘法来获取平均向量。因为我们的 soft_normalized 类似于单热编码,所以这个平均值将类似于与最高参数(原始预期行为)相关联的向量。 (1) 中的 my_power 越高,这将越真实:

target_vectors = tf.tensordot(almost_one_hot, embedding_matrix, axes=[[1], [0]])

注意:直接使用批处理是不行的!就我而言,我重塑了 "one hot"(从 [batch, dictionary_length][batch, 1, dictionary_length] 使用 tf.reshape。然后平铺我的 embedding_matrix 批次,最后使用:

predicted_vectors = tf.matmul(reshaped_one_hot, tiled_embedding)

可能有更优雅的解决方案(或更少的内存消耗,如果平铺嵌入矩阵不是一个选项),请随意探索更多。