在自定义损失函数中使用 `switch`/`cond`
using `switch`/`cond` in a custom loss function
我需要在 keras 中实现一个自定义损失函数来计算标准分类交叉熵,除非 y_true
全为零。
这是我的尝试:
def masked_crossent(y_true, y_pred):
return K.switch(K.any(y_true),
losses.categorical_crossentropy(y_true, y_pred),
losses.categorical_crossentropy(y_true, y_pred) * 0)
但是,训练开始后出现以下错误(编译工作正常):
~/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py
in init(self, graph, fetches, feeds)
419 self._ops.append(True)
420 else:
--> 421 self._assert_fetchable(graph, fetch.op)
422 self._fetches.append(fetch_name)
423 self._ops.append(False)
~/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py
in _assert_fetchable(self, graph, op)
432 if not graph.is_fetchable(op):
433 raise ValueError(
--> 434 'Operation %r has been marked as not fetchable.' % op.name)
435
436 def fetches(self):
ValueError: Operation 'IsVariableInitialized_4547' has been marked as
not fetchable.
代替losses.categorical_crossentropy(y_true, y_pred) * 0
,我还尝试了以下各种其他错误(在编译期间或训练开始后):
K.zeros_like(losses.categorical_crossentropy(y_true, y_pred))
K.zeros((K.int_shape(y_true)[0]))
K.zeros((K.int_shape(y_true)[0], 1))
...虽然我认为有一种简单的方法可以做到这一点。
我只有一个解决方法:
def masked_crossent(y_true, y_pred):
return K.max( y_true ) * K.categorical_crossentropy(y_true, y_pred)
如果是整批,则需要添加 axis = -1
。
我需要在 keras 中实现一个自定义损失函数来计算标准分类交叉熵,除非 y_true
全为零。
这是我的尝试:
def masked_crossent(y_true, y_pred):
return K.switch(K.any(y_true),
losses.categorical_crossentropy(y_true, y_pred),
losses.categorical_crossentropy(y_true, y_pred) * 0)
但是,训练开始后出现以下错误(编译工作正常):
~/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in init(self, graph, fetches, feeds) 419 self._ops.append(True) 420 else: --> 421 self._assert_fetchable(graph, fetch.op) 422 self._fetches.append(fetch_name) 423 self._ops.append(False)
~/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _assert_fetchable(self, graph, op) 432 if not graph.is_fetchable(op): 433 raise ValueError( --> 434 'Operation %r has been marked as not fetchable.' % op.name) 435 436 def fetches(self):
ValueError: Operation 'IsVariableInitialized_4547' has been marked as not fetchable.
代替losses.categorical_crossentropy(y_true, y_pred) * 0
,我还尝试了以下各种其他错误(在编译期间或训练开始后):
K.zeros_like(losses.categorical_crossentropy(y_true, y_pred))
K.zeros((K.int_shape(y_true)[0]))
K.zeros((K.int_shape(y_true)[0], 1))
...虽然我认为有一种简单的方法可以做到这一点。
我只有一个解决方法:
def masked_crossent(y_true, y_pred):
return K.max( y_true ) * K.categorical_crossentropy(y_true, y_pred)
如果是整批,则需要添加 axis = -1
。