Clojure code using Java primitive arrays is 70x slower than the Scala version
I wrote an edit distance algorithm in both Clojure and Scala.
The Scala version is 70 times faster than the Clojure version.
Clojure:
(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;;initialize distances
    (doseq [i (range 1 (inc n0))] (aset-long distances i 0 i))
    (doseq [j (range 1 (inc n1))] (aset-long distances 0 j j))
    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget distances i (dec j))
            del (aget distances (dec i) j)
            match (aget distances (dec i) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset-long distances i j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset-long distances i j (inc min-dist))
          :else (aset-long distances i j min-dist))))
    (aget distances n0 n1)))
Scala:
import scala.collection.mutable.ArrayBuffer

def editDistance(s0: Array[Char], s1: Array[Char]): Int = {
  val n0 = s0.length
  val n1 = s1.length
  val distances = Array.fill(n0+1)(ArrayBuffer.fill(n1+1)(0))
  for(j <- 0 to n1){distances(0)(j) = j}
  for(i <- 0 to n0){distances(i)(0) = i}
  for(i <- 1 to n0; j <- 1 to n1){
    val ins = distances(i)(j-1)
    val del = distances(i-1)(j)
    val matches = distances(i-1)(j-1)
    val minDist = (ins::del::matches::Nil).reduceLeft(_ min _)
    if (matches != minDist)
      distances(i)(j) = minDist + 1
    else if (s0(i-1) == s1(j-1))
      distances(i)(j) = minDist
    else
      distances(i)(j) = minDist + 1
  }
  distances(n0)(n1)
}
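I'm using Java arrays in Clojure for best performance. I considered adding type hints to the aget calls, but my code then performed worse (which is perhaps expected, since make-array already creates a typed array). I've also overridden the Clojure :jvm-opts in project.clj. Even so, the smallest performance gap I get is 70x.
What is wrong with the way I use Java arrays in Clojure?
Thanks for your insights.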
I think I've found the problem.
As you mentioned in a comment, reflection calls consume most of the time. Here is why.
Before profiling the code, I set *warn-on-reflection* to true:
(set! *warn-on-reflection* true)
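Then, if you look at the source of aset, or of the macro that generates the aset-long function, you'll see that for arities of 4 or more it uses apply to invoke the function. The same holds for aget with arities of 3 or more. I'm not 100% sure, but I believe that information about argument types is lost when a function is invoked through apply. Also, if you look closely at the source of aget and aset, you may notice that both functions can be inlined during compilation, but only in their single-index arities. We definitely want that: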
(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;; I've unrolled all multi-index aget/aset calls into nested single-index
    ;; calls, so they can be inlined by the compiler.
    ;; I'm also type hinting the first argument of the outer aget/aset calls.
    ;; The reason is explained next.
    (doseq [^long i (range 1 (inc n0))] (aset ^longs (aget distances i) 0 i))
    (doseq [^long j (range 1 (inc n1))] (aset ^longs (aget distances 0) j j))
    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget ^longs (aget distances i) (dec j))
            del (aget ^longs (aget distances (dec i)) j)
            match (aget ^longs (aget distances (dec i)) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset ^longs (aget distances i) j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset ^longs (aget distances i) j (inc min-dist))
          :else (aset ^longs (aget distances i) j min-dist))))
    ;; we can leave this one as is, since it is not inside a loop
    (aget distances n0 n1)))
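As a small check of the unrolling step, here is a minimal sketch showing that the multi-index forms and the nested single-index forms read and write the same cell. The 2x3 array `grid` and the var names are made up for illustration and are not part of the original code.

```clojure
;; Hypothetical 2x3 array of primitive longs, used only for illustration.
(def grid (make-array Long/TYPE 2 3))

;; Multi-index write: goes through the variadic arity of aset-long.
(aset-long grid 1 2 42)

;; Read the same cell back in both styles.
(def via-variadic (aget grid 1 2))                 ; variadic arity of aget
(def via-nested   (aget ^longs (aget grid 1) 2))   ; fetch the row, then index it

;; Nested write, then a multi-index read of that cell.
(aset ^longs (aget grid 0) 1 7)
(def written (aget grid 0 1))
```

Both styles address the same storage, so the rewrite should only change how the calls are compiled, not what they compute.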
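Let's compile our new function. Remember the global variable we set at the beginning? With it set to true, the compiler emits a bunch of warnings during compilation: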
Reflection warning, core.clj:75:23 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
Reflection warning, core.clj:76:23 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
Reflection warning, core.clj:77:25 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
...
The problem is that Clojure cannot determine the type of (make-array Long/TYPE (inc n0) (inc n1)) and marks it as unknown. We need a type hint:
(let [...
      ;; type hint for 2d array of primitive longs
      ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))
      ...]
  ...)
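The string in that hint is the JVM's class name for `long[][]`. The sketch below is one way to confirm this at the REPL and to see the same hint on a function parameter. `cell`, `array-class-name`, and `demo` are hypothetical names introduced here and do not appear in the original answer.

```clojure
;; The runtime class name of a 2D primitive long array is "[[J".
(def array-class-name
  (.getName (class (make-array Long/TYPE 2 3))))

;; Hypothetical helper: hinting the parameter lets the outer aget resolve
;; at compile time, and ^longs does the same for the row it returns.
(defn cell
  [^"[[J" grid i j]
  (aget ^longs (aget grid i) j))

;; Write one cell through the nested form, then read it back via the helper.
(def demo
  (let [^"[[J" grid (make-array Long/TYPE 2 3)]
    (aset ^longs (aget grid 1) 2 5)
    (cell grid 1 2)))
```

With *warn-on-reflection* set to true, the hinted forms in this sketch should compile without reflection warnings.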
At this point, we seem to be all set. The final version looks like this:
(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;;initialize distances
    (doseq [^long i (range 1 (inc n0))] (aset ^longs (aget distances i) 0 i))
    (doseq [^long j (range 1 (inc n1))] (aset ^longs (aget distances 0) j j))
    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget ^longs (aget distances i) (dec j))
            del (aget ^longs (aget distances (dec i)) j)
            match (aget ^longs (aget distances (dec i)) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset ^longs (aget distances i) j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset ^longs (aget distances i) j (inc min-dist))
          :else (aset ^longs (aget distances i) j min-dist))))
    (aget distances n0 n1)))
Here are the benchmarks:
Before:
> (time (edit-distance i1 i2))
"Elapsed time: 4601.025555 msecs"
291
After:
> (time (edit-distance i1 i2))
"Elapsed time: 27.782828 msecs"
291
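A single `time` call also includes JIT warm-up, so figures like these are best read as rough indications. Below is a minimal sketch of a warmed-up measurement. `timed-ms` is a hypothetical helper, and the workload is a stand-in rather than the edit-distance inputs, which are not shown in this thread.

```clojure
(defn timed-ms
  "Hypothetical helper: calls f `warmup` times so the JIT can compile it,
  then returns [result elapsed-ms] for one more call."
  [f warmup]
  (dotimes [_ warmup] (f))
  (let [start  (System/nanoTime)
        result (f)
        end    (System/nanoTime)]
    [result (/ (- end start) 1e6)]))

;; Stand-in workload: sum of 0..999, warmed up 10 times before timing.
(def sample (timed-ms #(reduce + (range 1000)) 10))
```

Applied to the function above, this would be called as (timed-ms #(edit-distance i1 i2) 5), assuming i1 and i2 are the same inputs used in the benchmark.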