// Functional Programming in Java, Second Edition: all the code for Chapter 8, "Using Tail-Call Optimization", in one class.

// All the code for Chapter 8, "Using Tail-Call Optimization", in one class, but without the part
// where we later need to "fix the arithmetic overflow". That is an ancillary problem caused solely
// by using the factorial (whose values quickly exceed the range of long) as the example recursive
// function. Here we use a simple recursive sum instead.

package chapter8.tailcalls;

import org.junit.jupiter.api.Test;

import java.util.function.Function;
import java.util.stream.Stream;

public class Chapter8_TailCallOptimization {

    // Exclusive upper bound for the argument n that runThem() feeds to the sum functions.
    private static final int LIMIT = 70000;

    // ---
    // The "very sly stack simulator" (VSSS) for optimizable (not necessarily recursive) tail calls.
    // This is the code "recur/fpij/TailCall.java" on p.142,
    // with one fix: the .get() on the Optional returned by findFirst() has been replaced by .orElseThrow().
    // With get(), the linter warns: "Optional.get() without 'isPresent()' check".
    // ---

    /**
     * One step of a computation that {@link #invoke()} runs iteratively instead of recursively.
     * A pending step produces the next step via {@link #apply()}; a completed step reports
     * {@code isComplete() == true} and carries the final value in {@link #result()}.
     *
     * @param <T> type of the final result
     */
    @FunctionalInterface
    public interface TailCall<T> {

        /**
         * Performs this step and returns the next one. Implementations must return the next
         * step rather than invoke it, otherwise the call stack grows again.
         */
        TailCall<T> apply();

        /** Returns whether this step holds the final result; pending steps return {@code false}. */
        default boolean isComplete() {
            return false;
        }

        /**
         * Returns the final result of a completed step.
         *
         * @throws IllegalStateException if this step is still pending
         */
        default T result() {
            throw new IllegalStateException("result() called on a TailCall that is not complete");
        }

        /**
         * Runs the chain of steps starting at this one until a completed step is reached,
         * and returns that step's result. Loops forever if no step in the chain ever completes.
         */
        default T invoke() {
            // Stream.iterate() lazily produces this, this.apply(), this.apply().apply(), ...
            // Each apply() is called by the stream machinery and returns before the next one
            // is made, so the call stack depth does not grow with the number of steps.
            // The stream is infinite: findFirst() only returns once a completed step is found,
            // hence orElseThrow() never actually throws.
            return Stream.iterate(this, TailCall::apply)
                    .filter(TailCall::isComplete)
                    .findFirst()
                    .orElseThrow()
                    .result();
        }
    }

    // "recur/fpij/TailCalls.java" on p.143

    /** Factory methods that create {@link TailCall} steps. */
    public static class TailCalls {

        /**
         * Returns {@code nextCall} unchanged. It exists only so that call sites read
         * symmetrically with {@link #done(Object)} (...but it is debatable whether this
         * improves understanding).
         */
        public static <T> TailCall<T> call(final TailCall<T> nextCall) {
            return nextCall;
        }

        /** Returns a completed step carrying {@code value}; {@link TailCall#invoke()} stops at it. */
        public static <T> TailCall<T> done(final T value) {
            return new TailCall<T>() {
                @Override
                public boolean isComplete() {
                    return true;
                }

                @Override
                public T result() {
                    return value;
                }

                @Override
                public TailCall<T> apply() {
                    // invoke() stops at the first completed step, so it never calls this.
                    throw new IllegalStateException("apply() called on a completed TailCall");
                }
            };
        }
    }

    // ---
    // Summing using a recursive call that is not a tail call (avoid if possible, though
    // it is not always possible).
    // Replaces "recur/fpij/Factorial.java" on p.140.
    // ---

    private static class SumRecursivelyNaively {

        /** Returns 1 + 2 + ... + number, or 0 (the empty sum) when number <= 0. */
        public static long sum(final int number) {
            // Not a tail call: the addition happens after the recursive call returns,
            // so every level has to keep its own stack frame alive.
            return (number <= 0) ? 0 : (number + sum(number - 1));
        }
    }

    // ---
    // Summing using a recursive call that is a proper tail call.
    // This approach uses an accumulator while going "down into" the recursion.
    // On the way "back up out of" the recursion there are only "returns".
    // In principle this could run in a single stack frame. However, neither javac nor the
    // HotSpot JVM performs tail-call elimination, so every call still pushes a new frame
    // and we still get a stack overflow eventually. (That this version gets deeper than the
    // naive one in the example output below is presumably due to smaller or differently
    // compiled frames, not to any tail-call optimization.)
    // ---

    private static class SumRecursivelyUsingTailCalls {

        /** Returns 1 + 2 + ... + number, or 0 (the empty sum) when number <= 0. */
        public static long sum(final int number) {
            return sumInner(0, number);
        }

        /** Adds number, number - 1, ..., 1 to accumulator; the recursive call is the last action. */
        private static long sumInner(final long accumulator, final int number) {
            if (number <= 0) {
                return accumulator;
            }
            return sumInner(accumulator + number, number - 1);
        }
    }

    // ---
    // An application of the VSSS ("very sly stack simulator").
    // Replaces "recur/fpij/Factorial.java" on p.141.
    // ---

    private static class SumRecursivelyWithVSSS {

        /** Returns 1 + 2 + ... + number, or 0 (the empty sum) when number <= 0. */
        public static long sum(final int number) {
            return sumInner(0, number).invoke();
        }

        /**
         * Returns the step that adds number, number - 1, ..., 1 to accumulator.
         * Note that this method never calls itself directly: it returns a closure that
         * TailCall.invoke() evaluates later via Stream.iterate().
         */
        public static TailCall<Long> sumInner(final long accumulator, final int number) {
            if (number <= 0) {
                // The completed step that ends the stream in TailCall.invoke().
                return TailCalls.done(accumulator);
            }
            // call() does nothing; returning the closure directly would work too, but this reads nicely.
            // When Stream.iterate() needs the next element, it runs this closure, which
            // returns the next TailCall instance.
            return TailCalls.call(() -> sumInner(accumulator + number, number - 1));
        }
    }

    // === TESTING SUPPORT BEGINS ===

    /**
     * Calls {@code sum} with {@code n} unless {@code skip} is set, checks the result against
     * the closed form n(n+1)/2, and reports a stack overflow if one occurs.
     *
     * @param skip        if true, do nothing (this version already overflowed for a smaller n)
     * @param name        name used in the printed messages
     * @param n           argument passed to {@code sum}; expected to be at least 1
     * @param isLastRound if true, print the result because the caller has reached its last n
     * @param sum         the summing function under test
     * @return whether this version should be skipped from now on
     * @throws AssertionError if {@code sum} returns a wrong value
     */
    private static boolean callSum(final boolean skip, final String name, final int n,
                                   final boolean isLastRound, final Function<Integer, Long> sum) {
        if (skip) {
            return true;
        }
        try {
            final long res = sum.apply(n);
            final long expected = (long) n * (n + 1) / 2;
            if (res != expected) {
                throw new AssertionError(name + "(" + n + ") = " + res + ", expected " + expected);
            }
            if (isLastRound) {
                System.out.println("Reached the end: " + name + "(" + n + ") = " + res);
            }
            return false;
        } catch (StackOverflowError err) {
            System.out.println("Stack overflow for " + name + "(" + n + ")");
            return true;
        }
    }

    // Runs the three approaches to summing recursively for n = 1, 2, ..., until each one
    // overflows the stack or n reaches LIMIT - 1.
    // Works best if one reduces the maximum stack size of the JVM,
    // options "-Xss1m" or "-XX:ThreadStackSize=1024" (the latter in KiB).
    // See https://docs.oracle.com/en/java/javase/17/docs/specs/man/java.html#advanced-runtime-options-for-java
    // Every function is re-run from scratch for every n, so the total work grows quadratically
    // with LIMIT; expect the test to take a while.
    //
    // Example output (the exact overflow points depend on the JVM and its stack size):
    //
    // Stack overflow for SumRecursivelyNaively.sum(38919)
    // Stack overflow for SumRecursivelyUsingTailCalls.sum(58375)
    // Reached the end: SumRecursivelyWithVSSS.sum(69999) = 2449965000

    @Test
    public void runThem() {
        boolean skipNaiveVersion = false;
        boolean skipTailCallVersion = false;
        boolean skipVsssVersion = false;
        for (int n = 1; n < LIMIT && !(skipNaiveVersion && skipTailCallVersion && skipVsssVersion); n++) {
            final boolean isLastRound = (n == LIMIT - 1);
            skipNaiveVersion = callSum(skipNaiveVersion, "SumRecursivelyNaively.sum", n,
                    isLastRound, SumRecursivelyNaively::sum);
            skipTailCallVersion = callSum(skipTailCallVersion, "SumRecursivelyUsingTailCalls.sum", n,
                    isLastRound, SumRecursivelyUsingTailCalls::sum);
            skipVsssVersion = callSum(skipVsssVersion, "SumRecursivelyWithVSSS.sum", n,
                    isLastRound, SumRecursivelyWithVSSS::sum);
        }
    }

}