如何避免大内存消耗keras中的自定义损失函数
How to avoid large memory consume custom loss function in keras
我在keras中定义了一个自定义损失函数。在这个自定义损失函数中,我从 y_pred
中提取非连续值,如下所示:
sel_row = tf.constant([[2],[5],[8]])
row_tmp = y_pred
selected = tf.transpose(tf.gather_nd(tf.transpose(row_tmp), sel_row))
有了这个,我只是来自张量的 select 列。现在,如果我做同样的事情,但对于连续列,即 row_tmp[:, 2:5]
,我没有问题,但如果没有连续列,我会得到:
/tensorflow/lib/python3.7/site-packages/tensorflow_core/python/framework/indexed_slices.py:424:
UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape.
This may consume a large amount of memory.
"Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
一切正常,但最好有一个更好的方法来避免消耗太多内存。
我试图将 tf.constant
更改为 tf.Variable
但出现此错误:
ValueError: tf.function-decorated function tried to create variables on non-first call.
有什么建议吗?
你可以这样做:
selected = tf.gather(row_tmp, tf.squeeze(sel_row, axis=1), axis=1)
我在keras中定义了一个自定义损失函数。在这个自定义损失函数中,我从 y_pred
中提取非连续值,如下所示:
sel_row = tf.constant([[2],[5],[8]])
row_tmp = y_pred
selected = tf.transpose(tf.gather_nd(tf.transpose(row_tmp), sel_row))
有了这个,我只是来自张量的 select 列。现在,如果我做同样的事情,但对于连续列,即 row_tmp[:, 2:5]
,我没有问题,但如果没有连续列,我会得到:
/tensorflow/lib/python3.7/site-packages/tensorflow_core/python/framework/indexed_slices.py:424:
UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape.
This may consume a large amount of memory.
"Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
一切正常,但最好有一个更好的方法来避免消耗太多内存。
我试图将 tf.constant
更改为 tf.Variable
但出现此错误:
ValueError: tf.function-decorated function tried to create variables on non-first call.
有什么建议吗?
你可以这样做:
selected = tf.gather(row_tmp, tf.squeeze(sel_row, axis=1), axis=1)