如何在 TensorFlow 图中正确引发异常

How to properly raise exception in TensorFlow graph

我想在图形模式(在 TensorFlow 服务中)中根据输入张量的值引发 tf.errors.InvalidArgumentError 异常。

目前我使用 tf.debugging.assert_all_finite,这很好用。由于我不是对错误检查进行断言,而是根据输入引发异常,因此引发显式异常会更好。

我的问题归结为:

这样做的正确方法是什么?

编辑: 一些更多的细节。我想在不使用 tf.debugging 的情况下重新创建以下逻辑(除非这实际上是正确的方法)。

目前我正在检查是否没有像这样的 NaN 值:

assert_op = tf.debugging.assert_all_finite(
    input_data,
    'Cant have nans at beginning or end'
)

您通过邮件给我写信,这可能与 this TF issue about catching exceptions within the graph execution, and 有关。但是,我不确定这是否真的与您相关。这个 TF 问题和 SO 问题是关于如何动态 捕获 异常,所以基本上在 TF 图中实现 try: ... except: ...

引入控制结构的其他 TF 功能是:

  • tf.while_loop
  • tf.cond

tf.cond 是您如何有条件地执行代码的问题的答案。取决于条件,即 bool 标量。但也许这不是您真正的问题,而是如何制定条件?

tf.check_numerics 检查 inf/nan 的张量,如果找到这样的张量则抛出异常。

如果您想将其作为条件,可以使用此代码:

is_finite = tf.reduce_all(tf.is_finite(x))

如果你想在某些条件不成立时抛出异常,你可以这样做:

check_op = tf.Assert(is_finite, ["Tensor had inf or nan values:", x])

您可能想使用 tf.control_dependencies 来确保此操作 check_op 得到执行。