如何摆脱clojure数组处理中的clojure/lang/RT.aset和clojure/lang/RT.intCast?

How to get rid of clojure/lang/RT.aset and clojure/lang/RT.intCast in clojure array processing?

我尝试在 Clojure 中尽可能快地进行复数数组的乘法运算。

选择的数据结构是两个元素的映射,:re:im,每个元素都是 Java 原语 double 的原生数组,以降低内存开销。

根据http://clojure.org/reference/java_interop,我对原始类型数组使用了精确类型规范。

通过这些提示aget被转换为原生数组dload op,但是有两个低效的地方,确切地说循环的计数器不是int而是long ,因此每次对数组进行索引时,都会使用对 clojure/lang/RT.intCast 的调用将计数器转换为 int。而且 aset 没有转换为本机操作,而是转换为对 clojure/lang/RT.aset.

的调用

另一个低效率是 checkcast。它检查数组实际上是双精度数组的每个循环。

结果是 运行 此 Clojure 代码的时间比等效的 Java 代码多 30%(不包括启动时间)。能否用 Clojure 重写此函数以使其运行更快?

Clojure代码,优化函数为multiply-complex-arrays.

(def size 65536)

(defn get-zero-complex-array
    []
    {:re (double-array size)
     :im (double-array size)})

(defn multiply-complex-arrays
    [a b]
    (let [
        a-re-array (doubles (get a :re))
        a-im-array (doubles (get a :im))
        b-re-array (doubles (get b :re))
        b-im-array (doubles (get b :im))
        res-re-array (double-array size)
        res-im-array (double-array size)
        ]
        (loop [i (int 0) size (int size)]
            (if (< i size)
                (let [
                    a-re (aget a-re-array i)
                    a-im (aget a-im-array i)
                    b-re (aget b-re-array i)
                    b-im (aget b-im-array i)
                    ]
                    (aset res-re-array i (- (* a-re b-re) (* a-im b-im)))
                    (aset res-im-array i (+ (* a-re b-im) (* b-re a-im)))
                    (recur (unchecked-inc i) size))
                {:re res-re-array :im res-im-array}))))

(let [
    res (loop [i (int 0) a (get-zero-complex-array)]
            (if (< i 30000)
                (recur (inc i) (multiply-complex-arrays a a))
                a))
    ]
    (println (aget (get res :re) 0)))

multiply-complex-arrays的主循环生成的java程序集是

  91: lload         8
  93: lload         10
  95: lcmp
  96: ifge          216
  99: aload_2
 100: checkcast     #51                 // class "[D"
 103: lload         8
 105: invokestatic  #46                 // Method clojure/lang/RT.intCast:(J)I
 108: daload
 109: dstore        12
 111: aload_3
 112: checkcast     #51                 // class "[D"
 115: lload         8
 117: invokestatic  #46                 // Method clojure/lang/RT.intCast:(J)I
 120: daload
 121: dstore        14
 123: aload         4
 125: checkcast     #51                 // class "[D"
 128: lload         8
 130: invokestatic  #46                 // Method clojure/lang/RT.intCast:(J)I
 133: daload
 134: dstore        16
 136: aload         5
 138: checkcast     #51                 // class "[D"
 141: lload         8
 143: invokestatic  #46                 // Method clojure/lang/RT.intCast:(J)I
 146: daload
 147: dstore        18
 149: aload         6
 151: checkcast     #51                 // class "[D"
 154: lload         8
 156: invokestatic  #46                 // Method clojure/lang/RT.intCast:(J)I
 159: dload         12
 161: dload         16
 163: dmul
 164: dload         14
 166: dload         18
 168: dmul
 169: dsub
 170: invokestatic  #55                 // Method clojure/lang/RT.aset:([DID)D
 173: pop2
 174: aload         7
 176: checkcast     #51                 // class "[D"
 179: lload         8
 181: invokestatic  #46                 // Method clojure/lang/RT.intCast:(J)I
 184: dload         12
 186: dload         18
 188: dmul
 189: dload         16
 191: dload         14
 193: dmul
 194: dadd
 195: invokestatic  #55                 // Method clojure/lang/RT.aset:([DID)D
 198: pop2
 199: lload         8
 201: lconst_1
 202: ladd
 203: lload         10
 205: lstore        10
 207: lstore        8
 209: goto          91

Java代码:

class ComplexArray {

    static final int SIZE = 1 << 16;

    double re[];

    double im[];

    ComplexArray(double re[], double im[]) {
        this.re = re;
        this.im = im;
    }

    static ComplexArray getZero() {
        return new ComplexArray(new double[SIZE], new double[SIZE]);
    }

    ComplexArray multiply(ComplexArray second) {
        double resultRe[] = new double[SIZE];
        double resultIm[] = new double[SIZE];
        for (int i = 0; i < SIZE; i++) {
            double aRe = this.re[i];
            double aIm = this.im[i];
            double bRe = second.re[i];
            double bIm = second.im[i];
            resultRe[i] = aRe * bRe - aIm * bIm;
            resultIm[i] = aRe * bIm + bRe * aIm;
        }
        return new ComplexArray(resultRe, resultIm);
    }

    public static void main(String args[]) {
        ComplexArray a = getZero();
        for (int i = 0; i < 30000; i++) {
            a = a.multiply(a);
        }
        System.out.println(a.re[0]);
    }
}

Java 代码中相同循环的汇编:

  13: iload         4
  15: ldc           #5                  // int 65536
  17: if_icmpge     92
  20: aload_0
  21: getfield      #2                  // Field re:[D
  24: iload         4
  26: daload
  27: dstore        5
  29: aload_0
  30: getfield      #3                  // Field im:[D
  33: iload         4
  35: daload
  36: dstore        7
  38: aload_1
  39: getfield      #2                  // Field re:[D
  42: iload         4
  44: daload
  45: dstore        9
  47: aload_1
  48: getfield      #3                  // Field im:[D
  51: iload         4
  53: daload
  54: dstore        11
  56: aload_2
  57: iload         4
  59: dload         5
  61: dload         9
  63: dmul
  64: dload         7
  66: dload         11
  68: dmul
  69: dsub
  70: dastore
  71: aload_3
  72: iload         4
  74: dload         5
  76: dload         11
  78: dmul
  79: dload         9
  81: dload         7
  83: dmul
  84: dadd
  85: dastore
  86: iinc          4, 1
  89: goto          13

您如何对这段代码进行基准测试?我建议使用 criterium 之类的东西,或者至少在比较时间之前执行多次执行。当它足够热时,诸如检查广播之类的东西应该由 JIT 优化。我还建议使用最新的 JVM、-server 和 -XX:+AggressiveOpts。

一般来说,我发现最好不要试图强制 Clojure 在循环中使用整数 - 而是将长整数作为循环计数器,使用 (set! *unchecked-math* true),并让 Clojure 在对数组进行索引时将长整数向下转换为整数.虽然这看起来像是额外的工作,但我对现代 hardware/JVM/JIT 的印象是差异比您预期的要小得多(因为无论如何您主要使用 64 位整数)。此外,看起来您将 size 作为循环变量携带,但它永远不会改变——也许您这样做是为了避免与 i 的类型不匹配,但我只会在循环之前让 size(作为 long)并做长增量和而是在 i 上进行比较。

有时您可以通过在循环之前设置一些东西来减少检查广播。虽然很容易观察代码并说出何时不需要它们,但编译器实际上并没有对此进行任何分析,而是将其留给 JIT 来优化(它通常非常擅长,或者不擅长)这在 99% 的代码中实际上并不重要)。

(set! *unchecked-math* :warn-on-boxed)

(def ^long ^:const size 65536)

(defn get-zero-complex-array []
  {:re (double-array size)
   :im (double-array size)})

(defn multiply-complex-arrays [a b]
  (let [a-re-array (doubles (get a :re))
        a-im-array (doubles (get a :im))
        b-re-array (doubles (get b :re))
        b-im-array (doubles (get b :im))
        res-re-array (double-array size)
        res-im-array (double-array size)
        s (long size)]
    (loop [i 0]
      (if (< i s)
        (let [a-re (aget a-re-array i)
              a-im (aget a-im-array i)
              b-re (aget b-re-array i)
              b-im (aget b-im-array i)]
          (aset res-re-array i (- (* a-re b-re) (* a-im b-im)))
          (aset res-im-array i (+ (* a-re b-im) (* b-re a-im)))
          (recur (inc i)))
        {:re res-re-array :im res-im-array}))))

(defn compute []
  (let [res (loop [i 0 a (get-zero-complex-array)]
              (if (< i 30000)
                (recur (inc i) (multiply-complex-arrays a a))
                a))]
    (aget (get res :re) 0)))