在自定义损失函数中使用 `switch`/`cond`

using `switch`/`cond` in a custom loss function

我需要在 keras 中实现一个自定义损失函数来计算标准分类交叉熵,除非 y_true 全为零。

这是我的尝试:

def masked_crossent(y_true, y_pred):
    return K.switch(K.any(y_true),
                    losses.categorical_crossentropy(y_true, y_pred),
                    losses.categorical_crossentropy(y_true, y_pred) * 0)

但是,训练开始后出现以下错误(编译工作正常):

~/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in init(self, graph, fetches, feeds) 419 self._ops.append(True) 420 else: --> 421 self._assert_fetchable(graph, fetch.op) 422 self._fetches.append(fetch_name) 423 self._ops.append(False)

~/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _assert_fetchable(self, graph, op) 432 if not graph.is_fetchable(op): 433 raise ValueError( --> 434 'Operation %r has been marked as not fetchable.' % op.name) 435 436 def fetches(self):

ValueError: Operation 'IsVariableInitialized_4547' has been marked as not fetchable.

代替losses.categorical_crossentropy(y_true, y_pred) * 0,我还尝试了以下各种其他错误(在编译期间或训练开始后):

                    K.zeros_like(losses.categorical_crossentropy(y_true, y_pred))

                    K.zeros((K.int_shape(y_true)[0]))

                    K.zeros((K.int_shape(y_true)[0], 1))

...虽然我认为有一种简单的方法可以做到这一点。

我只有一个解决方法:

def masked_crossent(y_true, y_pred):
    return K.max( y_true ) * K.categorical_crossentropy(y_true, y_pred)

如果是整批,则需要添加 axis = -1