Clojure code using Java primitive arrays is 70x slower than the Scala version

I wrote an edit distance algorithm in both Clojure and Scala.

The Scala version is 70x faster than the Clojure version.

Clojure:

(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;;initialize distances
    (doseq [i (range 1 (inc n0))] (aset-long distances i 0 i))
    (doseq [j (range 1 (inc n1))] (aset-long distances 0 j j))

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget distances i (dec j))
            del (aget distances (dec i) j)
            match (aget distances (dec i) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset-long distances i j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset-long distances i j (inc min-dist))
          :else (aset-long distances i j min-dist))))
    (aget distances n0 n1)))

Scala:

import scala.collection.mutable.ArrayBuffer // needed for ArrayBuffer.fill below

def editDistance(s0: Array[Char], s1: Array[Char]): Int = {
  val n0 = s0.length
  val n1 = s1.length
  val distances = Array.fill(n0 + 1)(ArrayBuffer.fill(n1 + 1)(0))
  // initialize the first row and the first column
  for (j <- 0 to n1) { distances(0)(j) = j }
  for (i <- 0 to n0) { distances(i)(0) = i }
  for (i <- 1 to n0; j <- 1 to n1) {
    val ins = distances(i)(j - 1)
    val del = distances(i - 1)(j)
    val matches = distances(i - 1)(j - 1)
    val minDist = (ins :: del :: matches :: Nil).reduceLeft(_ min _)
    if (matches != minDist)
      distances(i)(j) = minDist + 1
    else if (s0(i - 1) == s1(j - 1))
      distances(i)(j) = minDist
    else
      distances(i)(j) = minDist + 1
  }
  distances(n0)(n1)
}

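I'm using Java arrays in Clojure to get the best performance. I considered type hinting the calls to aget, but my code ran even slower (which is perhaps expected, since make-array already creates a typed array). I also overrode Clojure's :jvm-opts in project.clj. Even so, the smallest performance gap I got is 70x.

What am I doing wrong with Java arrays in Clojure?

Thanks for your insight.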

I think I've found the problem.

As you mentioned in the comments, reflective calls take most of the time. Here is why.

Before profiling the code, I set *warn-on-reflection* to true:

(set! *warn-on-reflection* true)
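
To illustrate what this flag reports, here is a minimal sketch on a one-dimensional array. The function names first-cell-unhinted and first-cell-hinted are hypothetical and exist only for this example. Assuming the code is loaded from a file or a REPL (where the var can be set!), the first definition should trigger a reflection warning at compile time, while the hinted one should not.

```clojure
(set! *warn-on-reflection* true)

;; No hint: the compiler does not know which array type xs is,
;; so it warns and resolves the call to aget reflectively at runtime.
(defn first-cell-unhinted [xs]
  (aget xs 0))

;; ^longs tells the compiler that xs is a long[], so the inlined aget
;; call is resolved at compile time and returns a primitive long.
(defn first-cell-hinted ^long [^longs xs]
  (aget xs 0))
```

Both functions return the same value; only the way the call is dispatched differs.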

Then, if you look at aset, or at the macro that generates the aset-long function, you'll see that for arities of 4 or more it uses apply to invoke the function. The same goes for aget at arities of 3 or more. I'm not 100% sure, but I believe the information about argument types is lost when a function is applied. Also, if you look closely at the source code of aget and aset, you may notice that these functions can be inlined during compilation. We definitely want that:

(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;; I've unrolled all aget/aset calls so the compiler can inline them.
    ;; I'm also type hinting the first argument of the top-level aget/aset calls.
    ;; The reason is explained next.
    (doseq [^long i (range 1 (inc n0))] (aset ^longs (aget distances i) 0 i))
    (doseq [^long j (range 1 (inc n1))] (aset ^longs (aget distances 0) j j))

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget ^longs (aget distances i) (dec j))
            del (aget ^longs (aget distances (dec i))  j)
            match (aget ^longs (aget distances (dec i)) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset ^longs (aget distances i) j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset ^longs (aget distances i) j (inc min-dist))
          :else (aset ^longs (aget distances i) j min-dist))))
    ;; we can leave this as is, since it is not inside a loop
    (aget distances n0 n1)))

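Let's compile our new function. Remember the global var we set at the beginning? With it set to true, the compiler prints a bunch of warnings during compilation: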

Reflection warning, core.clj:75:23 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
Reflection warning, core.clj:76:23 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
Reflection warning, core.clj:77:25 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
...

The problem is that Clojure cannot infer the type of (make-array Long/TYPE (inc n0) (inc n1)) and marks it as unknown. We need to type hint it:

(let [...
      ;; type hint for 2d array of primitive longs
      ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))
      ...]
   ...)
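
The string in that hint is just the JVM's internal class name for long[][]. A small sketch that checks this at the REPL and uses the hint on a function parameter follows. The names one-d-name, two-d-name and corner are hypothetical, introduced only for illustration.

```clojure
;; JVM class names of primitive long arrays.
(def one-d-name (.getName (class (long-array 0))))             ; name of long[]
(def two-d-name (.getName (class (make-array Long/TYPE 1 1)))) ; name of long[][]

(defn corner
  "Hypothetical helper: reads the last cell of the last row of a 2D long
  array. The ^\"[[J\" hint lets the outer aget resolve at compile time, and
  ^longs does the same for the row."
  ^long [^"[[J" grid]
  (let [^longs last-row (aget grid (dec (alength grid)))]
    (aget last-row (dec (alength last-row)))))
```

For one-dimensional arrays Clojure offers the shorthand ^longs. There is no such shorthand for two dimensions, which is why the class-name string is used.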

At this point we seem to be all set. The final version:

(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;;initialize distances
    (doseq [^long i (range 1 (inc n0))] (aset ^longs (aget distances i) 0 i))
    (doseq [^long j (range 1 (inc n1))] (aset ^longs (aget distances 0) j j))

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget ^longs (aget distances i) (dec j))
            del (aget ^longs (aget distances (dec i))  j)
            match (aget ^longs (aget distances (dec i)) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset ^longs (aget distances i) j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset ^longs (aget distances i) j (inc min-dist))
          :else (aset ^longs (aget distances i) j min-dist))))
    (aget distances n0 n1)))
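
The same idea can probably be pushed a little further. The sketch below is not part of the original answer and has not been benchmarked here. It swaps doseq over range for dotimes, which binds its counter as a primitive long, and looks up each row once per outer iteration. The name edit-distance-dotimes is hypothetical. The timings reported below refer to the versions above, not to this variant.

```clojure
(defn edit-distance-dotimes
  "Hypothetical variant of edit-distance. Same table and same update rule,
  but the loop counters are primitive longs and rows are fetched once."
  ^long [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;; initialize the first column and the first row
    (dotimes [i (inc n0)] (aset ^longs (aget distances i) 0 i))
    (dotimes [j (inc n1)] (aset ^longs (aget distances 0) j j))
    ;; i0 and j0 are zero-based; the table cell being filled is (i0+1, j0+1)
    (dotimes [i0 n0]
      (let [^longs prev (aget distances i0)
            ^longs curr (aget distances (inc i0))]
        (dotimes [j0 n1]
          (let [j (inc j0)
                ins (aget curr j0)
                del (aget prev j)
                match (aget prev j0)
                min-dist (min ins del match)]
            ;; keep min-dist only when the diagonal is the minimum
            ;; and the two characters are equal; otherwise add one
            (aset curr j (if (and (== match min-dist)
                                  (= (nth s0 i0) (nth s1 j0)))
                           min-dist
                           (inc min-dist)))))))
    (aget ^longs (aget distances n0) n1)))
```

The character comparison still goes through nth, so some boxing remains. That keeps the function usable with any seq of chars, as in the original.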

Here are the benchmarks:

Before:

> (time (edit-distance i1 i2))
"Elapsed time: 4601.025555 msecs"
291

After:

> (time (edit-distance i1 i2))
"Elapsed time: 27.782828 msecs"
291
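
A single call to time also includes JIT warm-up, so the numbers above are best read as a rough comparison. A minimal sketch of a slightly steadier measurement is shown here. The helper name elapsed-ms-after-warmup and the choice of warm-up count are assumptions for illustration, not something from the original answer.

```clojure
(defn elapsed-ms-after-warmup
  "Hypothetical helper: calls the zero-argument function f warmup-runs times
  to give the JIT a chance to compile it, then times one more call.
  Returns a map with the result of that call and the elapsed milliseconds."
  [f warmup-runs]
  (dotimes [_ warmup-runs] (f))
  (let [start (System/nanoTime)
        result (f)
        elapsed-ms (/ (double (- (System/nanoTime) start)) 1e6)]
    {:result result :elapsed-ms elapsed-ms}))

;; Stand-in workload so the sketch runs on its own.
(def demo (elapsed-ms-after-warmup #(reduce + (range 1000)) 5))
```

With the inputs from the benchmark above, the call would look like (elapsed-ms-after-warmup #(edit-distance i1 i2) 20).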