Java pseudo tail call recursion yields better performance

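I was just writing some simple utilities to compute the length of a linkedList, where the linkedList does not maintain an "internal" counter of its size/length. With that in mind, I had 3 simple approaches: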

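  1. Iterate over the linkedList until you hit the "end"
  2. Compute the length recursively
  3. Compute the length recursively without needing to return control to the calling function (using a kind of tail-call recursion)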

Here is some code that covers these 3 cases:

// 1. iterative approach
public static <T> int getLengthIteratively(LinkedList<T> ll) {

    int length = 0;
    for (Node<T> ptr = ll.getHead(); ptr != null; ptr = ptr.getNext()) {
        length++;
    }

    return length;
}

// 2. recursive approach
public static <T> int getLengthRecursively(LinkedList<T> ll) {
    return getLengthRecursively(ll.getHead());
}

private static <T> int getLengthRecursively(Node<T> ptr) {

    if (ptr == null) {
        return 0;
    } else {
        return 1 + getLengthRecursively(ptr.getNext());
    }
}

// 3. Pseudo tail-recursive approach
public static <T> int getLengthWithFakeTailRecursion(LinkedList<T> ll) {
    return getLengthWithFakeTailRecursion(ll.getHead());
}

private static <T> int getLengthWithFakeTailRecursion(Node<T> ptr) {
    return getLengthWithFakeTailRecursion(ptr, 0);
}

private static <T> int getLengthWithFakeTailRecursion(Node<T> ptr, int result) {
    if (ptr == null) {
        return result;
    } else {
        return getLengthWithFakeTailRecursion(ptr.getNext(), result + 1);
    }
}

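Now, I know the JVM does not support tail recursion out of the box, but when I ran some simple tests on a linked list of Strings with ~10k nodes, I noticed that getLengthWithFakeTailRecursion consistently outperforms getLengthRecursively (by ~40%). Is the difference solely due to the fact that in case #2 control is passed back node by node, so we are forced to walk back through all the stack frames?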
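Because HotSpot does not perform tail-call elimination, the accumulator version still pushes one stack frame per node, exactly like the plain recursive one. The sketch below shows what the elimination would amount to if the JVM did it: the tail call turns into a loop that overwrites the parameters. The `Node` class and both method names are hypothetical stand-ins, not the question's actual code.

```java
public class ManualTailCallElimination {

    // Hypothetical minimal node: a value plus a pointer to the next node.
    static final class Node<T> {
        final T value;
        Node<T> next;

        Node(T value, Node<T> next) {
            this.value = value;
            this.next = next;
        }

        Node<T> getNext() {
            return next;
        }
    }

    // Same shape as the question's accumulator version: the recursive call is in
    // tail position, but every call still gets its own stack frame on the JVM.
    static <T> int lengthTailRecursive(Node<T> ptr, int result) {
        if (ptr == null) {
            return result;
        }
        return lengthTailRecursive(ptr.getNext(), result + 1);
    }

    // What tail-call elimination would produce by hand: instead of calling itself,
    // the method rebinds its parameters and jumps back to the top, i.e. a loop.
    static <T> int lengthTailCallEliminated(Node<T> ptr, int result) {
        while (ptr != null) {
            result++;
            ptr = ptr.getNext();
        }
        return result;
    }

    public static void main(String[] args) {
        Node<String> list = null;
        for (int i = 0; i < 5; i++) {
            list = new Node<>("n" + i, list);
        }
        System.out.println(lengthTailRecursive(list, 0));      // 5
        System.out.println(lengthTailCallEliminated(list, 0)); // 5
    }
}
```

The loop version is essentially the iterative approach #1, which is why it is the natural point of comparison.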

Edit: here is the simple test I used to check the performance numbers:

import org.junit.Assert;
import org.junit.Test;

public class LengthCheckerTest {

    @Test
    public void testLengthChecking() {

        LinkedList<String> ll = new LinkedList<String>();
        int sizeOfList = 12000;
        // int sizeOfList = 100000; // Danger: this causes a StackOverflowError in the recursive methods!
        for (int i = 1; i <= sizeOfList; i++) {
            ll.addNode(String.valueOf(i));
        }

        // System.nanoTime() returns nanoseconds, so dividing by 1000 gives microseconds (us), not ms
        long currTime = System.nanoTime();
        Assert.assertEquals(sizeOfList, LengthChecker.getLengthIteratively(ll));
        long totalTime = System.nanoTime() - currTime;
        System.out.println("totalTime taken with iterative approach: " + (totalTime / 1000) + "us");

        currTime = System.nanoTime();
        Assert.assertEquals(sizeOfList, LengthChecker.getLengthRecursively(ll));
        totalTime = System.nanoTime() - currTime;
        System.out.println("totalTime taken with recursive approach: " + (totalTime / 1000) + "us");

        // Interestingly, the fakeTailRecursion always runs faster than the vanillaRecursion
        // TODO: Look into whether stack-frame collapsing has anything to do with this
        currTime = System.nanoTime();
        Assert.assertEquals(sizeOfList, LengthChecker.getLengthWithFakeTailRecursion(ll));
        totalTime = System.nanoTime() - currTime;
        System.out.println("totalTime taken with fake TCR approach: " + (totalTime / 1000) + "us");
    }
}
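The `// Danger` comment in the test hints at the answer to the stack-frame question: the "fake" tail-recursive variant overflows just like the plain one, because each call still occupies its own frame. A sketch under stated assumptions: a hypothetical minimal `Node`, and a 1,000,000-node list, which should be far deeper than a typical default thread stack (on the order of 1 MB) can hold.

```java
import java.util.function.Supplier;

public class RecursionDepthDemo {

    // Hypothetical minimal node type; only the link matters for counting.
    static final class Node {
        Node next;

        Node(Node next) {
            this.next = next;
        }
    }

    // Approach #2: one frame per node, and work remains after each call returns.
    static int plainRecursive(Node ptr) {
        return ptr == null ? 0 : 1 + plainRecursive(ptr.next);
    }

    // Approach #3: call in tail position, but still one frame per node on the JVM.
    static int fakeTailRecursive(Node ptr, int result) {
        return ptr == null ? result : fakeTailRecursive(ptr.next, result + 1);
    }

    // Approach #1: constant stack usage regardless of list length.
    static int iterative(Node ptr) {
        int length = 0;
        for (; ptr != null; ptr = ptr.next) {
            length++;
        }
        return length;
    }

    static Node buildList(int n) {
        Node head = null;
        for (int i = 0; i < n; i++) {
            head = new Node(head);
        }
        return head;
    }

    // Returns true if computing the length blew the thread's stack.
    static boolean overflows(Supplier<Integer> lengthFn) {
        try {
            lengthFn.get();
            return false;
        } catch (StackOverflowError e) {
            return true;
        }
    }

    public static void main(String[] args) {
        Node deep = buildList(1_000_000);
        System.out.println("iterative length:          " + iterative(deep));
        System.out.println("plain recursion overflows: " + overflows(() -> plainRecursive(deep)));
        System.out.println("fake tail overflows:       " + overflows(() -> fakeTailRecursive(deep, 0)));
    }
}
```

If the JVM collapsed tail calls, the third line would report `false`; on HotSpot both recursive variants should overflow.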

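Your benchmarking methodology is flawed. You run all three tests in the same JVM, so they are not on an equal footing. By the time the fake-tail test runs, the code in the LinkedList and Node classes has already been JIT-compiled, so it runs faster. If you change the order of the tests, you will see different numbers. Each test should be run in a separate JVM.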
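One quick way to see the warm-up effect without JMH is to time the same method several times in a row within one JVM. This is only an illustrative sketch (the `OrderEffectSketch` class and `timeRounds` helper are hypothetical); the actual numbers depend on the machine and the JIT, and it is not a substitute for a proper benchmark harness.

```java
import java.util.function.ToIntFunction;

public class OrderEffectSketch {

    // Hypothetical minimal node type.
    static final class Node {
        Node next;

        Node(Node next) {
            this.next = next;
        }
    }

    static int recursive(Node ptr) {
        return ptr == null ? 0 : 1 + recursive(ptr.next);
    }

    static Node buildList(int n) {
        Node head = null;
        for (int i = 0; i < n; i++) {
            head = new Node(head);
        }
        return head;
    }

    // Times consecutive calls of the same function in the same JVM, in microseconds.
    // Early rounds typically include interpretation and JIT compilation work,
    // so later rounds are usually faster even though the code is identical.
    static long[] timeRounds(ToIntFunction<Node> fn, Node list, int rounds) {
        long[] micros = new long[rounds];
        for (int r = 0; r < rounds; r++) {
            long start = System.nanoTime();
            int len = fn.applyAsInt(list);
            micros[r] = (System.nanoTime() - start) / 1_000;
            if (len < 0) {
                throw new IllegalStateException("length cannot be negative"); // consumes the result
            }
        }
        return micros;
    }

    public static void main(String[] args) {
        // Kept modest so the recursion also fits in the stack while still interpreted.
        Node list = buildList(5_000);
        long[] micros = timeRounds(OrderEffectSketch::recursive, list, 5);
        for (int r = 0; r < micros.length; r++) {
            System.out.println("round " + r + ": " + micros[r] + " us");
        }
    }
}
```

Whichever of the three approaches happens to run later in such a sequence benefits from the same effect, which is the ordering bias described above.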

Let's write a simple JMH microbenchmark for your case:

import java.util.concurrent.TimeUnit;

import org.openjdk.jmh.infra.Blackhole;
import org.openjdk.jmh.annotations.*;

// 5 warm-up iterations, 500 ms each, then 10 measurement iterations 500 ms each
// repeat everything three times (with JVM restart)
// output average time in microseconds
@Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@Fork(3)
@State(Scope.Benchmark)
public class ListTest {
    // You did not supply your Node and LinkedList implementations,
    // but I assume they look like this
    static class Node<T> {
        final T value;
        Node<T> next;

        public Node(T val) {value = val;}
        public void add(Node<T> n) {next = n;}

        public Node<T> getNext() {return next;}
    }

    static class LinkedList<T> {
        Node<T> head;

        public void setHead(Node<T> h) {head = h;}
        public Node<T> getHead() {return head;}
    }

    // Code from your question follows

    // 1. iterative approach
    public static <T> int getLengthIteratively(LinkedList<T> ll) {

        int length = 0;
        for (Node<T> ptr = ll.getHead(); ptr != null; ptr = ptr.getNext()) {
            length++;
        }

        return length;
    }

    // 2. recursive approach
    public static <T> int getLengthRecursively(LinkedList<T> ll) {
        return getLengthRecursively(ll.getHead());
    }

    private static <T> int getLengthRecursively(Node<T> ptr) {

        if (ptr == null) {
            return 0;
        } else {
            return 1 + getLengthRecursively(ptr.getNext());
        }
    }

    // 3. Pseudo tail-recursive approach
    public static <T> int getLengthWithFakeTailRecursion(LinkedList<T> ll) {
        return getLengthWithFakeTailRecursion(ll.getHead());
    }

    private static <T> int getLengthWithFakeTailRecursion(Node<T> ptr) {
        return getLengthWithFakeTailRecursion(ptr, 0);
    }

    private static <T> int getLengthWithFakeTailRecursion(Node<T> ptr, int result) {
        if (ptr == null) {
            return result;
        } else {
            return getLengthWithFakeTailRecursion(ptr.getNext(), result + 1);
        }
    }

    // Benchmarking code

    // Measure for different list lengths
    @Param({"10", "100", "1000", "10000"})
    int n;
    LinkedList<Integer> list;

    @Setup    
    public void setup() {
        list = new LinkedList<>();
        Node<Integer> cur = new Node<>(0);
        list.setHead(cur);
        for(int i=1; i<n; i++) {
            Node<Integer> next = new Node<>(i);
            cur.add(next);
            cur = next;
        }
    }

    // Do not forget to return the result to the caller, so the computation is not optimized away
    @Benchmark    
    public int testIteratively() {
        return getLengthIteratively(list);
    }

    @Benchmark    
    public int testRecursively() {
        return getLengthRecursively(list);
    }

    @Benchmark    
    public int testRecursivelyFakeTail() {
        return getLengthWithFakeTailRecursion(list);
    }
}

Here are the results on my machine (x64 Win7, Java 8u71); JMH printed the scores with a comma as the decimal separator:

Benchmark                           (n)  Mode  Cnt   Score    Error  Units
ListTest.testIteratively             10  avgt   30   0,009 ±  0,001  us/op
ListTest.testIteratively            100  avgt   30   0,156 ±  0,001  us/op
ListTest.testIteratively           1000  avgt   30   2,248 ±  0,036  us/op
ListTest.testIteratively          10000  avgt   30  26,416 ±  0,590  us/op
ListTest.testRecursively             10  avgt   30   0,014 ±  0,001  us/op
ListTest.testRecursively            100  avgt   30   0,191 ±  0,003  us/op
ListTest.testRecursively           1000  avgt   30   3,599 ±  0,031  us/op
ListTest.testRecursively          10000  avgt   30  40,071 ±  0,328  us/op
ListTest.testRecursivelyFakeTail     10  avgt   30   0,015 ±  0,001  us/op
ListTest.testRecursivelyFakeTail    100  avgt   30   0,190 ±  0,002  us/op
ListTest.testRecursivelyFakeTail   1000  avgt   30   3,609 ±  0,044  us/op
ListTest.testRecursivelyFakeTail  10000  avgt   30  41,534 ±  1,186  us/op

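As you can see, the fake-tail version runs at the same speed as plain recursion (within the error margin), and both are 20-60% slower than the iterative approach. So your result is not reproduced.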
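The 20-60% range can be recomputed directly from the steady-state averages in the table. The class and method names below are illustrative; the numbers are the reported scores. The n = 10 scores carry only one or two significant digits, so the ratios for that row are coarse.

```java
public class SlowdownFromTable {

    // Average time per operation (us/op) from the steady-state JMH table, indexed like SIZES.
    static final int[] SIZES = {10, 100, 1000, 10000};
    static final double[] ITERATIVE = {0.009, 0.156, 2.248, 26.416};
    static final double[] RECURSIVE = {0.014, 0.191, 3.599, 40.071};
    static final double[] FAKE_TAIL = {0.015, 0.190, 3.609, 41.534};

    // Percentage by which `slower` exceeds `baseline` (negative if it is actually faster).
    static double slowdownPercent(double slower, double baseline) {
        return (slower / baseline - 1.0) * 100.0;
    }

    public static void main(String[] args) {
        for (int i = 0; i < SIZES.length; i++) {
            System.out.printf("n=%-6d recursive vs iterative: %+.0f%%   fake tail vs recursive: %+.0f%%%n",
                    SIZES[i],
                    slowdownPercent(RECURSIVE[i], ITERATIVE[i]),
                    slowdownPercent(FAKE_TAIL[i], RECURSIVE[i]));
        }
    }
}
```

The recursive-vs-iterative column lands between roughly +22% and +60%, while the fake-tail-vs-recursive column stays within a few percent of zero, consistent with the error bars.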

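If you actually want single-shot results (without warm-up) rather than steady-state measurements, you can launch the same benchmark with the following options: -ss -wi 0 -i 1 -f 10. The results will look like this: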

Benchmark                           (n)  Mode  Cnt    Score     Error  Units
ListTest.testIteratively             10    ss   10   16,095 ±   0,831  us/op
ListTest.testIteratively            100    ss   10   19,780 ±   6,440  us/op
ListTest.testIteratively           1000    ss   10   74,316 ±  26,434  us/op
ListTest.testIteratively          10000    ss   10  366,496 ±  42,299  us/op
ListTest.testRecursively             10    ss   10   19,594 ±   7,084  us/op
ListTest.testRecursively            100    ss   10   21,973 ±   0,701  us/op
ListTest.testRecursively           1000    ss   10  165,007 ±  54,915  us/op
ListTest.testRecursively          10000    ss   10  563,739 ±  74,908  us/op
ListTest.testRecursivelyFakeTail     10    ss   10   19,454 ±   4,523  us/op
ListTest.testRecursivelyFakeTail    100    ss   10   25,518 ±  11,802  us/op
ListTest.testRecursivelyFakeTail   1000    ss   10  158,336 ±  43,646  us/op
ListTest.testRecursivelyFakeTail  10000    ss   10  755,384 ± 232,940  us/op

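As you can see, the first launch is many times slower than subsequent ones, and your result is still not reproduced. I observe that testRecursivelyFakeTail is actually slower for n = 10000 (but after warm-up it reaches the same peak speed as testRecursively).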
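To quantify how much slower a cold run is, the single-shot scores can be divided by the warmed-up averages from the steady-state table. This sketch uses only the n = 10000 rows, and its names are illustrative; note that the single-shot error bars (especially ±232,940 for the fake-tail case) are wide.

```java
public class ColdVsWarm {

    // n = 10000 scores (us/op) from the two JMH tables: single-shot (cold) and average time (warm).
    static final String[] NAMES = {"testIteratively", "testRecursively", "testRecursivelyFakeTail"};
    static final double[] SINGLE_SHOT = {366.496, 563.739, 755.384};
    static final double[] STEADY_STATE = {26.416, 40.071, 41.534};

    // How many times slower the cold invocation was than the warmed-up average.
    static double coldToWarmRatio(double singleShot, double steadyState) {
        return singleShot / steadyState;
    }

    public static void main(String[] args) {
        for (int i = 0; i < NAMES.length; i++) {
            System.out.printf("%-24s cold/warm = %.1fx%n",
                    NAMES[i], coldToWarmRatio(SINGLE_SHOT[i], STEADY_STATE[i]));
        }
    }
}
```

All three ratios come out well above 10x, which is why a single unwarmed timing in a shared JVM says little about steady-state performance.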