如何避免大内存消耗keras中的自定义损失函数

How to avoid large memory consume custom loss function in keras

我在keras中定义了一个自定义损失函数。在这个自定义损失函数中,我从 y_pred 中提取非连续值,如下所示:

sel_row = tf.constant([[2],[5],[8]])
row_tmp = y_pred
selected = tf.transpose(tf.gather_nd(tf.transpose(row_tmp), sel_row))

有了这个,我只是来自张量的 select 列。现在,如果我做同样的事情,但对于连续列,即 row_tmp[:, 2:5],我没有问题,但如果没有连续列,我会得到:

/tensorflow/lib/python3.7/site-packages/tensorflow_core/python/framework/indexed_slices.py:424: 
UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. 
This may consume a large amount of memory.
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "

一切正常,但最好有一个更好的方法来避免消耗太多内存。

我试图将 tf.constant 更改为 tf.Variable 但出现此错误:

ValueError: tf.function-decorated function tried to create variables on non-first call.

有什么建议吗?

你可以这样做:

selected = tf.gather(row_tmp, tf.squeeze(sel_row, axis=1), axis=1)