U net Multiclass segmentation图像输入数据集错误

U net Multiclass segmentation image input dataset error

我正在尝试使用 U-net 进行多类分割。在之前的试验中,我尝试了二进制分割并且它有效。但是当我尝试做多类时,我遇到了这个错误。

**ValueError: 'generator yielded an element of shape (128,192,1) where an element of shape (128,192,5) was expected**   

这个5表示类的个数。这就是我定义输出层的方式。 output:Tensor("output/sigmoid:0",shape(?,128,192,5),dtype=float32).

由于灰度图像,我保留了 input_shape:(128,192,1) 的裁剪尺寸 和 label_shape:(128,192,5)

数据加载到 tensorflow 数据集中并使用 tf.iterator。 生成器从 tf.dataset.

生成数据
def get_datapoint_generator(self):
  def generator():
   for i in itertools.count(1):
    datapoint_dict=self._get_next_datapoint()
    yield datapoint_dict['image'],datapoint_dict['mask']

_get_next_datapoint_ 函数从 ram 获取下一个数据点,并处理裁剪和增强。

现在,如果它与输出形状不匹配,哪里会出错?

你可以尝试使用这个实现吗?我正在使用这个,但它在 Keras

def sparse_crossentropy(y_true, y_pred):
    nb_classes = K.int_shape(y_pred)[-1]
    y_true = K.one_hot(tf.cast(y_true[:, :, 0], dtype=tf.int32), nb_classes + 1)
    return K.categorical_crossentropy(y_true, y_pred)