PyTorch's autograd.detect_anomaly equivalent in TensorFlow

I am trying to debug my TensorFlow code, which suddenly produces a NaN loss after about 30 epochs. You can find my specific problem and the things I have tried in .

I monitored the weights of all layers for every mini-batch during training and found that the weights suddenly jump to NaN, even though all weight values were below 1 in the preceding iterations (I have set a kernel_constraint of max_norm 1). This makes it hard to figure out which operation is the culprit.
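The per-batch weight check described above can be automated with a Keras callback. A minimal sketch, assuming a tf.keras model (the callback name is my own, not a library API):

```python
import numpy as np
import tensorflow as tf

class NanWeightMonitor(tf.keras.callbacks.Callback):
    """After each training batch, scan every weight tensor and stop
    as soon as any weight has become NaN."""

    def on_train_batch_end(self, batch, logs=None):
        for w in self.model.weights:
            if np.isnan(w.numpy()).any():
                raise RuntimeError(
                    f"NaN found in weight {w.name} after batch {batch}"
                )
```

Pass it via `model.fit(..., callbacks=[NanWeightMonitor()])`. This tells you *when* the weights blow up, but (as noted above) not *which* operation caused it.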

PyTorch has a neat debugging utility, torch.autograd.detect_anomaly, which raises an error for any backward computation that produces a NaN value and shows the traceback of the forward operation responsible. This makes the code easy to debug.
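For reference, a minimal sketch of that PyTorch behaviour (the toy expression is contrived purely to force a NaN gradient):

```python
import torch

# With detect_anomaly enabled, autograd checks the output of every
# backward function for NaN and raises a RuntimeError whose traceback
# points at the forward op that caused it.
x = torch.zeros(1, requires_grad=True)

with torch.autograd.detect_anomaly():
    # forward: 1/0 = inf, inf * 0 = nan
    y = ((1.0 / x) * 0.0).sum()
    try:
        y.backward()  # the backward of the division returns NaN
    except RuntimeError as err:
        print("anomaly detected:", err)
```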

Is there anything similar in TensorFlow? If not, can you suggest a debugging approach?

There is indeed a similar debugging tool in TensorFlow. See tf.debugging.check_numerics.

It can be used to track down tensors that produce inf or NaN values during training. As soon as such a value is found, TensorFlow raises an InvalidArgumentError.

tf.debugging.check_numerics(LayerN, "LayerN is producing nans!")

If the tensor LayerN has NaNs, you will get an error like this:

Traceback (most recent call last):
  File "trainer.py", line 506, in <module>
    worker.train_model()
  File "trainer.py", line 211, in train_model
    l, tmae = train_step(*batch)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 855, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 2943, in __call__
    filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1919, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 560, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError:  LayerN is producing nans! : Tensor had NaN values