Tensorflow 加权交叉熵损失函数在 DNN 分类器估计器函数中的位置?

Where does Tensorflow Weighted Cross Entropy loss function goes in the DNN Classifier Estimator function?

我目前正在研究具有高度偏斜数据(90% negative/10% 正)的二项式分类算法,使用 tf.estimator.DNNClassifier。 由于我训练的所有模型都会将所有样本标记为负样本,因此我需要实施加权损失函数。

我在这里看了很多不同的问题,其中很多很有启发性。但是,对于如何实际实现这些功能,我无法得到实用的端到端答案。 and 线程是最好的。

我的问题是:我想使用 tf.nn.weighted_cross_entropy_with_logits(),但我不知道应该在我的代码中的什么地方插入它。

我有一个构建特征列的函数:

def construct_feature_columns(input_features):
  return set([tf.feature_column.numeric_column(my_feature)
              for my_feature in input_features])

定义 tf.estimator.DNNClassifier 以及其他参数(如优化器和输入函数)的函数:

def train_nn_classifier_model(
    learning_rate,
    steps,
    batch_size,
    hidden_units,
    training_examples,
    training_targets,
    validation_examples,
    validation_targets):

    dnn_classifier = tf.estimator.DNNClassifier(
        feature_columns=construct_feature_columns(training_examples),
        hidden_units=hidden_units,
        optimizer=my_optimizer)

训练函数:

dnn_classifier.train(input_fn=training_input_fn, steps=steps_per_period)

预测函数,在训练时计算误差:

training_probabilities = dnn_classifier.predict(input_fn=predict_training_input_fn)

优化器:

  my_optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
  my_optimizer = tf.contrib.estimator.clip_gradients_by_norm(my_optimizer, 5.0)

输入函数(用于训练输入、预测训练输入和验证输入):

  training_input_fn = lambda: my_input_fn(
      training_examples, 
      training_targets['True/False'], 
      batch_size=batch_size)

我应该在哪里插入 tf.nn.weighted_cross_entropy_with_logits,以便我的模型使用此函数计算损失?

另外,交叉熵函数里面的targets (A Tensor of the same type and shape as logits)怎么调用呢? training_targets DataFrame 是 input function 的输出,输入是 training_targets 吗?

logits具体是什么?因为对我来说,它们应该是来自函数的预测:

training_probabilities = dnn_classifier.predict(input_fn=predict_training_input_fn)

但这对我来说没有意义。我尝试了很多不同的方法来实现它,但是 none 的方法奏效了。

我讨厌成为坏消息的传递者,但是DNN Classifier不支持自定义损失函数:

Loss is calculated by using softmax cross entropy.

这是文档中唯一提到的损失(函数),我找不到任何 post 讨论通过直接更改 DNNClassifier 来解决这个问题的有效方法。相反,您似乎必须构建自己的 custom Estimator.