Tensorflow probability: InvalidArgumentError: required broadcastable shapes

Tensorflow probability: InvalidArgumentError: required broadcastable shapes

我的数据采用 channels_first 格式。当我使用 tensorflow 概率层时,出现以下错误:

这是一个示例,其中输入形状为 [1,28,28] 且可重现代码:Gist(请确保您是 运行 GPU 上的代码。)

InvalidArgumentError:  required broadcastable shapes
     [[node gradient_tape/lambda/model_3_mixture_same_family_4_MixtureSameFamily_MixtureSameFamily/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Independentmodel_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/truediv/RealDiv_1 (defined at <ipython-input-22-243a182981d9>:9) ]] [Op:__inference_train_function_7663]

Errors may have originated from an input operation.
Input Source operations connected to node gradient_tape/lambda/model_3_mixture_same_family_4_MixtureSameFamily_MixtureSameFamily/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Independentmodel_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/model_3_mixture_same_family_4_MixtureSameFamily_independent_normal_4_IndependentNormal_Normal/log_prob/truediv/RealDiv_1:
 model_3/mixture_same_family_4/MixtureSameFamily/independent_normal_4/IndependentNormal/Softplus (defined at /usr/local/lib/python3.7/dist-packages/tensorflow_probability/python/layers/distribution_layer.py:988)

Function call stack:
train_function

我不确定如何更改源代码以使其与通道第一个输入形状一起使用。有人可以帮我解决这个问题吗?

您的 preprocess 函数返回 image, image 而不是 image, sample['label']。如果您更改它,它应该可以工作!

我认为您也可以在损失中删除 K.cast。

更新:实际上当我 运行 这个我得到了 nan 的损失。可能还有其他问题。但至少它克服了形状错误! ‍♂️