涉及巨大矩阵乘法的keras中的自定义损失函数

Custom loss function in keras involving huge matrix multiplication

我在 Keras 中编写自定义损失函数时遇到困难。我有层权重 "W" 和矩阵 "M"。我想执行以下操作 trace((W * M) * W') 来计算我的损失函数。 Trace 是对角线元素的总和。在 numpy 中,我会做以下事情:

np.trace(np.dot(np.dot(W,M),W.T))) or 

def custom_regularizer(W,M):
    sum_reg = 0
    for i in range(W.shape[1]):
        for j in range(i,W.shape[1]):
            vector = W[:,i] - W[:,j]
            sum_reg = sum_reg + M[i,j] * (LA.norm(vector)**2)
    return sum_reg

对于keras,我写了下面的损失函数

def custom_loss(W):

  def lossFunction(y_true,y_pred):    
    loss = tf.trace(K.dot(K.dot(W,K.constant(M)),K.transpose(W)))
    return loss

return lossFunction

问题是keras正在计算维度为200000 * 200000的整个外矩阵,内存错误。有什么方法可以让我在不进行整个矩阵计算的情况下只得到对角线元素的总和。

如何做到与keras损失函数相同?

如果您遵循一些巧妙的技巧来计算轨迹,您应该不会 运行 内存不足。比如可以参考this.