在 tf2.1.0-keras 中使用条件

Using condition in tf2.1.0-keras

我正在尝试在我自己的 tf2.1.0-keras 模型中使用 bool 条件,下面是简单的示例:

import tensorflow as tf

class TestKeras:
    def __init__(self):
        pass

    def build_graph(self):
        x = tf.keras.Input(shape=(2),batch_size=1)
        x_value = x[0,0]
        y = tf.cond(x_value > 0, lambda :tf.add(x_value,0), lambda :tf.add(x_value,0))
        return tf.keras.models.Model(inputs=[x], outputs=[y])

if __name__ == "__main__":
    tk = TestKeras()
    model = tk.build_graph()
    model.summary(line_length=100)

但它似乎不起作用并抛出异常:

using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

我已经尝试用 tf.keras.backend.switch 替换 tf.cond,但仍然出现同样的错误。

我还尝试将代码 y = tf.cond(xxx) 拆分为一个函数并添加 @tf.funcion 装饰器:

@tf.function
def compute_y(self,x):
    return tf.cond(x > 0, lambda :tf.add(x,0), lambda :tf.add(x,0))

但是又出现了另一个错误:

Inputs to eager execution function cannot be Keras symbolic tensors, but found [<tf.Tensor 'strided_slice:0' shape=() dtype=float32>]

有人知道条件如何在 tf2.1.0-keras 中工作吗?

tf.keras.Input 是一个符号张量,用于定义 keras 模型的输入。每当你想在 keras 模型中应用自定义逻辑时,你应该子 class Layer class,或者使用 Lambda 层。

例如,Lambda层:

class TestKeras:
    def __init__(self):
        pass
    
    def build_graph(self):
        x = tf.keras.Input(shape=(2),batch_size=1)
        def custom_fct(x):
            x_value = x[0,0]
            return tf.cond(x_value > 0, lambda :tf.add(x_value,0), lambda :tf.add(x_value,0))
        y = tf.keras.layers.Lambda(custom_fct)(x)
        return tf.keras.models.Model(inputs=[x], outputs=[y])