NotImplementedError when calling predict on a Keras VAE

So I've been using this convolutional VAE example for MNIST: https://keras.io/examples/generative/vae/

When I run

vae.predict(mnist_digits)


NotImplementedError                       Traceback (most recent call last)

<ipython-input-8-2e6bf7edcacc> in <module>()
----> 1 vae.predict(mnist_digits)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in autograph_handler(*args, **kwargs)
   1145           except Exception as e:  # pylint:disable=broad-except
   1146             if hasattr(e, "ag_error_metadata"):
-> 1147               raise e.ag_error_metadata.to_exception(e)
   1148             else:
   1149               raise

NotImplementedError: in user code:

    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1801, in predict_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1790, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1783, in run_step  **
        outputs = model.predict_step(data)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1751, in predict_step
        return self(x, training=False)
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 475, in call
        raise NotImplementedError('Unimplemented `tf.keras.Model.call()`: if you '

    NotImplementedError: Exception encountered when calling layer "vae" (type VAE).
    
    Unimplemented `tf.keras.Model.call()`: if you intend to create a `Model` with the Functional API, please provide `inputs` and `outputs` arguments. Otherwise, subclass `Model` with an overridden `call()` method.
    
    Call arguments received:
      • inputs=tf.Tensor(shape=(None, 28, 28, 1), dtype=float32)
      • training=False
      • mask=None

Similarly, when I call

vae(mnist_digits)

I get the following:

---------------------------------------------------------------------------

NotImplementedError                       Traceback (most recent call last)

<ipython-input-8-aa5f4fb52e20> in <module>()
----> 1 vae(mnist_digits)

/usr/local/lib/python3.7/dist-packages/keras/engine/training.py in call(self, inputs, training, mask)
    473         a list of tensors if there are more than one outputs.
    474     """
--> 475     raise NotImplementedError('Unimplemented `tf.keras.Model.call()`: if you '
    476                               'intend to create a `Model` with the Functional '
    477                               'API, please provide `inputs` and `outputs` '

NotImplementedError: Exception encountered when calling layer "vae" (type VAE).

Unimplemented `tf.keras.Model.call()`: if you intend to create a `Model` with the Functional API, please provide `inputs` and `outputs` arguments. Otherwise, subclass `Model` with an overridden `call()` method.

Call arguments received:
  • inputs=tf.Tensor(shape=(70000, 28, 28, 1), dtype=float32)
  • training=None
  • mask=None

Do I have to write a custom predict function to fix this?

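Your VAE class subclasses keras.Model but never overrides call(), so both vae(...) and vae.predict(...) fail: predict() simply calls the model on each batch (see predict_step in the traceback), and there is no call() to run. The models you can call predict on are the encoder and the decoder, because those are functional Keras models with defined inputs and outputs.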
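To answer the question directly: a custom predict function should not be necessary. As the error message says, a subclassed Model needs an overridden call(). The sketch below assumes the structure of the keras.io example: an encoder that returns [z_mean, z_log_var, z] and a decoder that maps z back to 28x28x1 images. The small Dense encoder and decoder, latent_dim = 2, and the random input array are placeholder assumptions so the snippet runs on its own. The example's train_step and loss tracking are left out.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

latent_dim = 2  # assumption: same 2-D latent space as the keras.io example


class Sampling(layers.Layer):
    """Reparameterization trick: z = z_mean + exp(0.5 * z_log_var) * epsilon."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = tf.random.normal(tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


# Hypothetical tiny stand-in encoder (the keras.io example uses Conv2D layers here).
encoder_inputs = keras.Input(shape=(28, 28, 1))
x = layers.Flatten()(encoder_inputs)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")

# Hypothetical tiny stand-in decoder (the example uses Conv2DTranspose layers).
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(28 * 28, activation="sigmoid")(latent_inputs)
decoder_outputs = layers.Reshape((28, 28, 1))(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")


class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    # The missing piece: once call() exists, vae(x) and vae.predict(x) work.
    def call(self, inputs):
        _, _, z = self.encoder(inputs)
        return self.decoder(z)


vae = VAE(encoder, decoder)
mnist_digits = np.random.rand(8, 28, 28, 1).astype("float32")  # placeholder, not real MNIST
reconstructions = vae.predict(mnist_digits, verbose=0)
print(reconstructions.shape)  # (8, 28, 28, 1)
```

Decoding the sampled z gives a slightly different reconstruction on every call; returning self.decoder(z_mean) instead would make predict deterministic. Which one you want is a design choice, not something the keras.io example prescribes.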

  1. If you want to see how the different digit classes cluster:

    z_mean, z_log_var, z = vae.encoder.predict(data)  # encoder outputs [z_mean, z_log_var, z]; plot z_mean

  2. If you want to sample digits from your latent space:

    decoded = vae.decoder.predict(z)
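Put together, the two steps above might look roughly like this. It is a sketch with untrained, hypothetical Dense stand-ins for vae.encoder and vae.decoder (same inputs and outputs as in the keras.io example, with a 2-D latent space) and random arrays in place of MNIST, so only the shapes are meaningful. The 5x5 grid over [-1, 1] is an arbitrary choice.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

latent_dim = 2


class Sampling(layers.Layer):
    """z = z_mean + exp(0.5 * z_log_var) * epsilon, as in the keras.io example."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        return z_mean + tf.exp(0.5 * z_log_var) * tf.random.normal(tf.shape(z_mean))


# Hypothetical untrained stand-ins with the example's input/output signatures.
enc_in = keras.Input(shape=(28, 28, 1))
h = layers.Dense(16, activation="relu")(layers.Flatten()(enc_in))
z_mean = layers.Dense(latent_dim)(h)
z_log_var = layers.Dense(latent_dim)(h)
encoder = keras.Model(enc_in, [z_mean, z_log_var, Sampling()([z_mean, z_log_var])])

dec_in = keras.Input(shape=(latent_dim,))
dec_out = layers.Reshape((28, 28, 1))(layers.Dense(28 * 28, activation="sigmoid")(dec_in))
decoder = keras.Model(dec_in, dec_out)

data = np.random.rand(100, 28, 28, 1).astype("float32")  # placeholder for MNIST images
labels = np.random.randint(0, 10, size=100)               # placeholder digit labels

# 1. Cluster view: a multi-output model's predict returns one array per output.
z_mean_val, _, _ = encoder.predict(data, verbose=0)
# e.g. with matplotlib: plt.scatter(z_mean_val[:, 0], z_mean_val[:, 1], c=labels)

# 2. Sampling: decode a 5x5 grid of points from the 2-D latent space.
grid = np.linspace(-1.0, 1.0, 5)
z_grid = np.array([[xi, yi] for yi in grid for xi in grid], dtype="float32")
decoded = decoder.predict(z_grid, verbose=0)
print(z_mean_val.shape, decoded.shape)  # (100, 2) (25, 28, 28, 1)
```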