使用 java 8 设计尾递归

Designing tail recursion using java 8

我正在尝试 talk 中提供的以下示例以了解 java8 中的尾递归。

@FunctionalInterface
public interface TailCall<T> {
    TailCall<T> apply();

    default boolean isComplete() {
        return false;
    }

    default T result() {
        throw new Error("not implemented");
    }

    default T get() {
        return Stream.iterate(this, TailCall::apply).filter(TailCall::isComplete)
                                                .findFirst().get().result();
    }
}

实用程序 class 使用 TailCall

public class TailCalls {
    public static <T> TailCall<T> call(final TailCall<T> nextcall) {
        return nextcall;
    }

    public static <T> TailCall<T> done(final T value) {
        return new TailCall<T>() {
            @Override
            public boolean isComplete() {
                return true;
            }

            @Override
            public T result() {
                return value;
            }

            @Override
            public TailCall<T> apply() {
                throw new Error("not implemented.");
            }
        };
    }
}

这里是尾递归的使用:

public class Main {

    public static TailCall<Integer> factorial(int fact, int n) {
        if (n == 1) {
            return TailCalls.done(fact);
        } else {
            return TailCalls.call(factorial(fact * n, n-1));
        }
    }

    public static void main(String[] args) {
        System.out.println(factorial(1, 5).get());
        }
}

它工作正常,但我觉得我们不需要 TailCall::get 来计算结果。根据我的理解,我们可以使用以下方法直接计算结果:

System.out.println(factorial(1, 5).result());

而不是:

System.out.println(factorial(1, 5).get());

如果我遗漏了 TailCall::get 的要点,请告诉我。

示例中有错误。它只会执行普通递归,没有尾调用优化。您可以通过将 Thread.dumpStack 添加到基本情况来看到这一点:

if (n == 1) {
    Thread.dumpStack();
    return TailCalls.done(fact);
}

堆栈跟踪类似于:

java.lang.Exception: Stack trace
    at java.lang.Thread.dumpStack(Thread.java:1333)
    at test.Main.factorial(Main.java:14)
    at test.Main.factorial(Main.java:18)
    at test.Main.factorial(Main.java:18)
    at test.Main.factorial(Main.java:18)
    at test.Main.factorial(Main.java:18)
    at test.Main.main(Main.java:8)

如您所见,对 factorial 进行了多次调用。这意味着发生普通递归,没有尾调用优化。在那种情况下,调用 get 确实没有意义,因为您从 factorial 返回的 TailCall 对象中已经有了结果。


正确的实现方法是 return 一个新的 TailCall 延迟实际调用的对象:

public static TailCall<Integer> factorial(int fact, int n) {
    if (n == 1) {
        return TailCalls.done(fact);
    }

    return () -> factorial(fact * n, n-1);
}

如果您还添加了 Thread.dumpStack,那么将只有 1 次调用 factorial:

java.lang.Exception: Stack trace
    at java.lang.Thread.dumpStack(Thread.java:1333)
    at test.Main.factorial(Main.java:14)
    at test.Main.lambda[=13=](Main.java:18)
    at java.util.stream.Stream.next(Stream.java:1033)
    at java.util.Spliterators$IteratorSpliterator.tryAdvance(Spliterators.java:1812)
    at java.util.stream.ReferencePipeline.forEachWithCancel(ReferencePipeline.java:126)
    at java.util.stream.AbstractPipeline.copyIntoWithCancel(AbstractPipeline.java:498)
    at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:485)
    at java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:471)
    at java.util.stream.FindOps$FindOp.evaluateSequential(FindOps.java:152)
    at java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:234)
    at java.util.stream.ReferencePipeline.findFirst(ReferencePipeline.java:464)
    at test.TailCall.get(Main.java:36)
    at test.Main.main(Main.java:9)