如何摆脱clojure/lang/RT.set和clojure/lang/RT.intCast在clojure数组处理中



我尝试在Clojure中尽可能快地进行复数数组的乘法运算。

选择的数据结构是两个元素的映射,:re:im,每个都是原始double的Java原生数组,以降低内存开销。

根据http://clojure.org/reference/java_interop,我对原始类型的数组使用了精确的类型规范。

有了这些提示,aget被转换为本地数组dload op,但是有两个低效率,循环的计数器不是int而是long,所以每次索引数组时,计数器都通过调用clojure/lang/RT.intCast转换为int。并且aset不是转换成本地操作,而是转换成对clojure/lang/RT.aset的调用。

另一个低效率是checkcast。它在每个循环中检查数组是否为double类型的数组。

结果是Clojure代码的运行时间比同等Java代码多30%(不包括启动时间)。这个函数可以在Clojure中重写,以便它更快地工作吗?

Clojure代码中,要优化的函数是multiply-complex-arrays

(def size 65536)
(defn get-zero-complex-array
    []
    {:re (double-array size)
     :im (double-array size)})
(defn multiply-complex-arrays
    [a b]
    (let [
        a-re-array (doubles (get a :re))
        a-im-array (doubles (get a :im))
        b-re-array (doubles (get b :re))
        b-im-array (doubles (get b :im))
        res-re-array (double-array size)
        res-im-array (double-array size)
        ]
        (loop [i (int 0) size (int size)]
            (if (< i size)
                (let [
                    a-re (aget a-re-array i)
                    a-im (aget a-im-array i)
                    b-re (aget b-re-array i)
                    b-im (aget b-im-array i)
                    ]
                    (aset res-re-array i (- (* a-re b-re) (* a-im b-im)))
                    (aset res-im-array i (+ (* a-re b-im) (* b-re a-im)))
                    (recur (unchecked-inc i) size))
                {:re res-re-array :im res-im-array}))))
(let [
    res (loop [i (int 0) a (get-zero-complex-array)]
            (if (< i 30000)
                (recur (inc i) (multiply-complex-arrays a a))
                a))
    ]
    (println (aget (get res :re) 0)))

multiply-complex-arrays的主循环生成的java程序集为

  91: lload         8
  93: lload         10
  95: lcmp
  96: ifge          216
  99: aload_2
 100: checkcast     #51                 // class "[D"
 103: lload         8
 105: invokestatic  #46                 // Method clojure/lang/RT.intCast:(J)I
 108: daload
 109: dstore        12
 111: aload_3
 112: checkcast     #51                 // class "[D"
 115: lload         8
 117: invokestatic  #46                 // Method clojure/lang/RT.intCast:(J)I
 120: daload
 121: dstore        14
 123: aload         4
 125: checkcast     #51                 // class "[D"
 128: lload         8
 130: invokestatic  #46                 // Method clojure/lang/RT.intCast:(J)I
 133: daload
 134: dstore        16
 136: aload         5
 138: checkcast     #51                 // class "[D"
 141: lload         8
 143: invokestatic  #46                 // Method clojure/lang/RT.intCast:(J)I
 146: daload
 147: dstore        18
 149: aload         6
 151: checkcast     #51                 // class "[D"
 154: lload         8
 156: invokestatic  #46                 // Method clojure/lang/RT.intCast:(J)I
 159: dload         12
 161: dload         16
 163: dmul
 164: dload         14
 166: dload         18
 168: dmul
 169: dsub
 170: invokestatic  #55                 // Method clojure/lang/RT.aset:([DID)D
 173: pop2
 174: aload         7
 176: checkcast     #51                 // class "[D"
 179: lload         8
 181: invokestatic  #46                 // Method clojure/lang/RT.intCast:(J)I
 184: dload         12
 186: dload         18
 188: dmul
 189: dload         16
 191: dload         14
 193: dmul
 194: dadd
 195: invokestatic  #55                 // Method clojure/lang/RT.aset:([DID)D
 198: pop2
 199: lload         8
 201: lconst_1
 202: ladd
 203: lload         10
 205: lstore        10
 207: lstore        8
 209: goto          91
Java代码:

class ComplexArray {
    static final int SIZE = 1 << 16;
    double re[];
    double im[];
    ComplexArray(double re[], double im[]) {
        this.re = re;
        this.im = im;
    }
    static ComplexArray getZero() {
        return new ComplexArray(new double[SIZE], new double[SIZE]);
    }
    ComplexArray multiply(ComplexArray second) {
        double resultRe[] = new double[SIZE];
        double resultIm[] = new double[SIZE];
        for (int i = 0; i < SIZE; i++) {
            double aRe = this.re[i];
            double aIm = this.im[i];
            double bRe = second.re[i];
            double bIm = second.im[i];
            resultRe[i] = aRe * bRe - aIm * bIm;
            resultIm[i] = aRe * bIm + bRe * aIm;
        }
        return new ComplexArray(resultRe, resultIm);
    }
    public static void main(String args[]) {
        ComplexArray a = getZero();
        for (int i = 0; i < 30000; i++) {
            a = a.multiply(a);
        }
        System.out.println(a.re[0]);
    }
}

相同循环在Java代码中的汇编:

  13: iload         4
  15: ldc           #5                  // int 65536
  17: if_icmpge     92
  20: aload_0
  21: getfield      #2                  // Field re:[D
  24: iload         4
  26: daload
  27: dstore        5
  29: aload_0
  30: getfield      #3                  // Field im:[D
  33: iload         4
  35: daload
  36: dstore        7
  38: aload_1
  39: getfield      #2                  // Field re:[D
  42: iload         4
  44: daload
  45: dstore        9
  47: aload_1
  48: getfield      #3                  // Field im:[D
  51: iload         4
  53: daload
  54: dstore        11
  56: aload_2
  57: iload         4
  59: dload         5
  61: dload         9
  63: dmul
  64: dload         7
  66: dload         11
  68: dmul
  69: dsub
  70: dastore
  71: aload_3
  72: iload         4
  74: dload         5
  76: dload         11
  78: dmul
  79: dload         9
  81: dload         7
  83: dmul
  84: dadd
  85: dastore
  86: iinc          4, 1
  89: goto          13

您如何对这些代码进行基准测试?我建议在比较时间之前使用一些类似于criterium的方法,或者至少进行多次执行。像checkcast这样的事情应该在足够热的时候由JIT进行优化。我还建议使用最新的JVM、-server和-XX:+ aggressive options。

一般来说,我发现最好不要强迫Clojure在循环中使用int型,而是将long型作为循环计数器,使用(set! *unchecked-math* true),并让Clojure在索引数组时将long型向下转换为int型。虽然这看起来像是额外的工作,但我对现代硬件/JVM/JIT的印象是,差异比您期望的要小得多(因为您主要使用64位整数)。此外,它看起来像你携带size作为循环变量,但它永远不会改变-也许你这样做是为了避免与i的类型不匹配,但我只是让size(作为一个长)在循环之前,并对i进行长增量和比较。

有时可以通过在循环之前放置一些东西来减少检查强制转换。虽然很容易观察代码并指出什么时候不需要它们,但编译器并不真正对此进行任何分析,而是将其留给JIT来优化(JIT通常非常擅长,或者在99%的代码中实际上并不重要)。

(set! *unchecked-math* :warn-on-boxed)
(def ^long ^:const size 65536)
(defn get-zero-complex-array []
  {:re (double-array size)
   :im (double-array size)})
(defn multiply-complex-arrays [a b]
  (let [a-re-array (doubles (get a :re))
        a-im-array (doubles (get a :im))
        b-re-array (doubles (get b :re))
        b-im-array (doubles (get b :im))
        res-re-array (double-array size)
        res-im-array (double-array size)
        s (long size)]
    (loop [i 0]
      (if (< i s)
        (let [a-re (aget a-re-array i)
              a-im (aget a-im-array i)
              b-re (aget b-re-array i)
              b-im (aget b-im-array i)]
          (aset res-re-array i (- (* a-re b-re) (* a-im b-im)))
          (aset res-im-array i (+ (* a-re b-im) (* b-re a-im)))
          (recur (inc i)))
        {:re res-re-array :im res-im-array}))))
(defn compute []
  (let [res (loop [i 0 a (get-zero-complex-array)]
              (if (< i 30000)
                (recur (inc i) (multiply-complex-arrays a a))
                a))]
    (aget (get res :re) 0)))

最新更新