如何在 tf.keras 自定义损失函数中触发 python 函数?
How to trigger a python function inside a tf.keras custom loss function?
在我的自定义损失函数中,我需要调用一个纯 python 函数,传入计算出的 TD 误差和一些索引。该函数不需要 return 任何东西或被区分。这是我要调用的函数:
def update_priorities(self, traces_idxs, td_errors):
"""Updates the priorities of the traces with specified indexes."""
self.priorities[traces_idxs] = td_errors + eps
我试过使用 tf.py_function
来调用包装函数,但它只有在嵌入到图形中时才会被调用,即如果它有输入和输出并且使用了输出。因此,我试图通过一些张量而不对它们执行任何操作,现在调用了该函数。这是我的整个自定义损失函数:
def masked_q_loss(data, y_pred):
"""Computes the MSE between the Q-values of the actions that were taken and the cumulative
discounted rewards obtained after taking those actions. Updates trace priorities.
"""
action_batch, target_qvals, traces_idxs = data[:,0], data[:,1], data[:,2]
seq = tf.cast(tf.range(0, tf.shape(action_batch)[0]), tf.int32)
action_idxs = tf.transpose(tf.stack([seq, tf.cast(action_batch, tf.int32)]))
qvals = tf.gather_nd(y_pred, action_idxs)
def update_priorities(_qvals, _target_qvals, _traces_idxs):
"""Computes the TD error and updates memory priorities."""
td_error = _target_qvals - _qvals
_traces_idxs = tf.cast(_traces_idxs, tf.int32)
mem.update_priorities(_traces_idxs, td_error)
return _qvals
qvals = tf.py_function(func=update_priorities, inp=[qvals, target_qvals, traces_idxs], Tout=[tf.float32])
return tf.keras.losses.mse(qvals, target_qvals)
但是由于调用 mem.update_priorities(_traces_idxs, td_error)
,我收到以下错误
ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.
我不需要为 update_priorities
计算梯度,我只想在图形计算的特定点调用它,然后忘记它。我该怎么做?
在包装函数内的张量上使用 .numpy()
解决了问题:
def update_priorities(_qvals, _target_qvals, _traces_idxs):
"""Computes the TD error and updates memory priorities."""
td_error = np.abs((_target_qvals - _qvals).numpy())
_traces_idxs = (tf.cast(_traces_idxs, tf.int32)).numpy()
mem.update_priorities(_traces_idxs, td_error)
return _qvals
在我的自定义损失函数中,我需要调用一个纯 python 函数,传入计算出的 TD 误差和一些索引。该函数不需要 return 任何东西或被区分。这是我要调用的函数:
def update_priorities(self, traces_idxs, td_errors):
"""Updates the priorities of the traces with specified indexes."""
self.priorities[traces_idxs] = td_errors + eps
我试过使用 tf.py_function
来调用包装函数,但它只有在嵌入到图形中时才会被调用,即如果它有输入和输出并且使用了输出。因此,我试图通过一些张量而不对它们执行任何操作,现在调用了该函数。这是我的整个自定义损失函数:
def masked_q_loss(data, y_pred):
"""Computes the MSE between the Q-values of the actions that were taken and the cumulative
discounted rewards obtained after taking those actions. Updates trace priorities.
"""
action_batch, target_qvals, traces_idxs = data[:,0], data[:,1], data[:,2]
seq = tf.cast(tf.range(0, tf.shape(action_batch)[0]), tf.int32)
action_idxs = tf.transpose(tf.stack([seq, tf.cast(action_batch, tf.int32)]))
qvals = tf.gather_nd(y_pred, action_idxs)
def update_priorities(_qvals, _target_qvals, _traces_idxs):
"""Computes the TD error and updates memory priorities."""
td_error = _target_qvals - _qvals
_traces_idxs = tf.cast(_traces_idxs, tf.int32)
mem.update_priorities(_traces_idxs, td_error)
return _qvals
qvals = tf.py_function(func=update_priorities, inp=[qvals, target_qvals, traces_idxs], Tout=[tf.float32])
return tf.keras.losses.mse(qvals, target_qvals)
但是由于调用 mem.update_priorities(_traces_idxs, td_error)
ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.
我不需要为 update_priorities
计算梯度,我只想在图形计算的特定点调用它,然后忘记它。我该怎么做?
在包装函数内的张量上使用 .numpy()
解决了问题:
def update_priorities(_qvals, _target_qvals, _traces_idxs):
"""Computes the TD error and updates memory priorities."""
td_error = np.abs((_target_qvals - _qvals).numpy())
_traces_idxs = (tf.cast(_traces_idxs, tf.int32)).numpy()
mem.update_priorities(_traces_idxs, td_error)
return _qvals