使用 tf.cond() 检查自定义层调用方法中的条件

Checking condition in call method of custom layer using tf.cond()

我正在 tensorflow 2.x 中实现自定义层。我的要求是,程序应该在返回输出之前检查条件。

class SimpleRNN_cell(tf.keras.layers.Layer):
    def __init__(self, M1, M2, fi=tf.nn.tanh, disp_name=True):
        super(SimpleRNN_cell, self).__init__()
        pass        
    def call(self, X, hidden_state, return_state=True):
        y = tf.constant(5)
        if return_state == True:
            return y, self.h
        else:
            return y

我的问题是:我应该继续使用当前代码(假设 tape.gradient(Loss, self.trainable_weights) 可以正常工作)还是应该使用 tf.cond()。 此外,如果可能,请解释在何处使用 tf.cond() 以及在何处不使用。我没有找到太多关于这个主题的内容。

tf.cond仅在根据可微计算图中的数据进行条件评估时相关。 (https://www.tensorflow.org/api_docs/python/tf/cond) This was especially necessary in TF 1.0 with graph mode being the default. For eager mode the GradientTape system allows to also do conditional data flow with the python constructs such as if ...: (https://www.tensorflow.org/guide/autodiff#control_flow)

然而,仅仅根据配置参数提供不同的行为,不依赖于来自计算图的数据,并且在模型运行时是固定的,使用简单的 python if 语句是正确的.