如何根据某些条件替换 Scala 中 breeze 矩阵的元素?

How to replace elements of a breeze matrix in Scala based on some condition?

我在 Scala 中使用二维 Breeze 矩阵。在某些时候,我必须对两个矩阵进行 element-wise 除法。分母矩阵中的某些元素可以为零,导致结果中出现 NaN。

我可以遍历矩阵维度并将 0.0 替换为 >0。

但是有更简单的或 Scala 惯用的解决方案吗?

一步一步:

  • 示例矩阵:

    val dm = DenseMatrix((1.0, 0.0, 3.0), (0.0, 5.0, 6.0))
    
  • 找出哪些元素等于0.0:

    dm :== 0.0
    
    breeze.linalg.DenseMatrix[Boolean] =
    false  true   false
    true   false  false
    
  • 对矩阵进行切片:

    dm(dm :== 0.0)
    
    breeze.linalg.SliceVector[(Int, Int),Double] = breeze.linalg.SliceVector@2b
    
  • 使用切片矩阵进行替换:

    dm(dm :== 0.0) := 42.0
    
    breeze.linalg.Vector[Double] = breeze.linalg.SliceVector@2b
    
  • 检查矩阵:

    dm
    
    breeze.linalg.DenseMatrix[Double] =
    1.0   42.0  3.0
    42.0  5.0   6.0
    

映射 NaN 比切片更快。

val matr = DenseMatrix((1.0, 0.0, 3.0), (0.0, 11.0, 12.0),
      (1.0, 2.0, 0.0))
val matr2 = DenseMatrix((3.0, 0.0, 1.0), (0.0, 12.0, 11.0),
      (2.0, 1.0, 0.0))

def time[R](block: => R): R = {
  val t0 = System.nanoTime()
  val result = block    // call-by-name
  val t1 = System.nanoTime()
  println("Elapsed time: " + (t1 - t0) + "ns")
  result
}

def replaceZeroes1(mat1: DenseMatrix[Double], mat2: DenseMatrix[Double], rep: Double) = {
   (mat1 /:/ mat2).map(x => if (x.isNaN()) rep else x)
}
    
def replaceZeroes2(mat1: DenseMatrix[Double], mat2: DenseMatrix[Double], rep: Double) = {
    mat1(mat1 :== 0.0) := rep
    mat2(mat2 :== 0.0) := 1
    mat1 /:/ mat2
}
time(println(replaceZeroes1(matr, matr2, 42.0)))
time(println(replaceZeroes2(matr, matr2, 42.0)))

产生:

0.3333333333333333  42.0                3.0                 
42.0                0.9166666666666666  1.0909090909090908  
0.5                 2.0                 42.0                
Elapsed time: 13087782ns
Replace Zero2
0.3333333333333333  42.0                3.0                 
42.0                0.9166666666666666  1.0909090909090908  
0.5                 2.0                 42.0                
Elapsed time: 16613179ns

映射出 NaN 既快捷又直接。即使您从 function2 中删除第二个切片,它也会更快。

注意:这没有在 Spark 中使用非常大的数据集进行测试,只是 breeze。在那种情况下,可能有不同的时间(尽管我对此表示怀疑)。

奖金:

如果您只是想从具有任何值集的矩阵生成 1s 和 0s 矩阵(例如从加权网络生成 non-weighted 网络),我会使用:

(mat /:/ mat).map(x => if (x.isNaN()) 0.0 else x)