SessionRunHook 在 运行 后返回空的 SessionRunValues

SessionRunHook returning empty SessionRunValues after run

我正在尝试编写一个钩子来计算一些全局指标(而不是批量指标)。为了制作原型,我想我会得到一个简单的连接和 运行 来捕捉并记住真正的积极因素。它看起来像这样:

class TPHook(tf.train.SessionRunHook):

    def after_create_session(self, session, coord):
        print("Starting Hook")

        tp_name = 'metrics/f1_macro/TP'
        self.tp = []
        self.args = session.graph.get_operation_by_name(tp_name)
        print(f"Got Args: {self.args}")

    def before_run(self, run_context):
        print("Starting Before Run")
        return tf.train.SessionRunArgs(self.args)

    def after_run(self, run_context, run_values):
        print("After Run")
        print(f"Got Values: {run_values.results}")

但是,挂钩的 "after_run" 部分返回的值始终是 None。我在训练和评估阶段都对此进行了测试。我对 SessionRunHooks 应该如何工作有什么误解吗?


也许相关信息: 该模型是在 keras 中构建的,并使用 keras.estimator.model_to_estimator() 函数转换为估算器。该模型已经过测试并且工作正常,我试图在钩子中检索的操作在此代码块中定义:

def _f1_macro_vector(y_true, y_pred):
    """Computes the F1-score with Macro averaging.

    Arguments:
        y_true {tf.Tensor} -- Ground-truth labels
        y_pred {tf.Tensor} -- Predicted labels

    Returns:
        tf.Tensor -- The computed F1-Score
    """
    y_true = K.cast(y_true, tf.float64)
    y_pred = K.cast(y_pred, tf.float64)

    TP = tf.reduce_sum(y_true * K.round(y_pred), axis=0, name='TP')
    FN = tf.reduce_sum(y_true * (1 - K.round(y_pred)), axis=0, name='FN')
    FP = tf.reduce_sum((1 - y_true) * K.round(y_pred), axis=0, name='FP')

    prec = TP / (TP + FP)
    rec = TP / (TP + FN)

    # Convert NaNs to Zero
    prec = tf.where(tf.is_nan(prec), tf.zeros_like(prec), prec)
    rec = tf.where(tf.is_nan(rec), tf.zeros_like(rec), rec)

    f1 = 2 * (prec * rec) / (prec + rec)

    # Convert NaN to Zero
    f1 = tf.where(tf.is_nan(f1), tf.zeros_like(f1), f1)

    return f1

以防万一有人遇到同样的问题,我发现了如何重组程序以使其正常工作。尽管文档听起来好像我可以将原始操作传递给 SessionRunArgs,但它似乎需要实际的张量(也许这是我的误读)。 这很容易实现 - 我只是将 after_create_session 代码更改为如下所示。

def after_create_session(self, session, coord):

    tp_name = 'metrics/f1_macro/TP'
    self.tp = []
    tp_tensor = session.graph.get_tensor_by_name(tp_name+':0')

    self.args = [tp_tensor]

并且成功运行。