Deep recursion with coroutines

Roman Elizarov
Apr 25, 2020

Photo by Riccardo Pelati on Unsplash

Kotlin Coroutines are typically used for asynchronous programming. However, the underlying design of coroutines and their implementation in the Kotlin compiler are quite universal, solving problems beyond asynchronous programming. Let's take a look at one such problem that can be elegantly solved with coroutines: writing deeply recursive functions.

Setup

Consider a tree data structure. For this example, let’s use this simple binary tree where each Tree node has a reference to its left and right children:

class Tree(val left: Tree?, val right: Tree?)

The depth of the tree is defined as the number of nodes on the longest path from its root down to a leaf. It can be computed by the following recursive function:

fun depth(t: Tree?): Int =
    if (t == null) 0 else maxOf(
        depth(t.left), // recursive call one
        depth(t.right) // recursive call two
    ) + 1

The logic here is straightforward. The depth is simply the maximum of the depth of the left and right children plus one, with the special case of zero when the tree is empty.

Recursion is a great tool for working with tree-like data structures, but there is a catch. Let's generate a deep tree containing 100K nodes. Start with a leaf node Tree(null, null) as a seed and repeatedly generate parent nodes that link to the previous node as their left child:

val n = 100_000
val deepTree = generateSequence(Tree(null, null)) { prev ->
    Tree(prev, null)
}.take(n).last()

This is not a particularly big data structure. It occupies only a couple of megabytes of memory, which is not much at all for a modern machine with gigabytes of available memory. Now, let's try to compute depth(deepTree) and print the result.

If you run it in Kotlin Playground, you'll get the “Your program produces too much output!” message. If you run the same code on your local machine, you'll see what kind of output that is:

Exception in thread "main" java.lang.StackOverflowError
at FileKt.depth(File.kt:5)
... // many more lines like that

Problem

Oops! Despite the relatively small size of this tree, the depth call runs out of resources. Because it is written as a recursive function, it exhausts a separate memory region that keeps track of the thread's call stack. The default thread stack size on a typical JVM is on the order of a megabyte, which is not enough here.

If the only thing we were doing in our software was to find the depth of this deepTree, then we could have reconfigured our runtime to use a larger stack size. But this kind of solution does not scale to real-life applications.
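On HotSpot, such reconfiguration can be done globally with the -Xss flag. Alternatively, you can request a bigger stack for a single thread through the java.lang.Thread constructor that takes a stackSize argument. Below is a minimal sketch of the per-thread variant. The helper name withBigStack and the 256 MiB figure are illustrative assumptions, and the Thread documentation notes that the JVM may treat stackSize as a hint. The sketch also shows why the approach does not scale: every nested call still needs its own stack frame, so the required stack grows with the input.

```kotlin
class Tree(val left: Tree?, val right: Tree?)

fun depth(t: Tree?): Int =
    if (t == null) 0 else maxOf(depth(t.left), depth(t.right)) + 1

// Hypothetical helper: runs block on a fresh thread that requests stackBytes of stack.
// The JVM is allowed to ignore the requested size on some platforms.
fun <T> withBigStack(stackBytes: Long, block: () -> T): T {
    var result: Result<T>? = null
    val thread = Thread(null, { result = runCatching(block) }, "big-stack", stackBytes)
    thread.start()
    thread.join() // join() makes the worker's write to result visible here
    return result!!.getOrThrow()
}
```

For example, withBigStack(256L * 1024 * 1024) { depth(deepTree) } computes the depth on a dedicated thread. A tree that is ten times deeper would need roughly ten times the stack.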

The other problem is that modern runtimes, like JVM HotSpot (which we are running this example on), are not really designed to work with large call stacks and run into all sorts of performance problems with deep recursion. Fortunately, there is another solution.

Solution

There is a lot of heap memory available in our system, so we should use it instead of the severely limited call stack. We can rewrite our depth algorithm by hand to avoid recursion and use the heap. It is a mechanical but tedious transformation. We write code that creates an activation frame in the heap holding the state of each ongoing call to depth, and we turn the sequential logic into a state machine so that no recursive calls are needed. The key component of this implementation is a while loop that processes the topmost call frame until there is no more work left to do.

Applying this approach, I got a monster of a function. It works, but it is ugly and completely obscures the original simple algorithm that computes tree depth.
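The original function is not reproduced here. Below is a minimal sketch of the same transformation, not the author's code. Frame, depthIterative, and the state numbering are hypothetical names, and the code assumes Kotlin 1.4+ for kotlin.collections.ArrayDeque. Each Frame is a heap-allocated activation record for one call to depth. Its state field says where that call resumes, and the while loop always works on the topmost frame.

```kotlin
class Tree(val left: Tree?, val right: Tree?)

// Heap-allocated activation frame for one pending call to depth(tree).
private class Frame(val tree: Tree) {
    var state = 0     // 0: before call one, 1: before call two, 2: ready to return
    var leftDepth = 0 // saved result of recursive call one
}

fun depthIterative(root: Tree?): Int {
    if (root == null) return 0
    val stack = ArrayDeque<Frame>()
    stack.addLast(Frame(root))
    var ret = 0 // value "returned" by the most recently completed call
    while (stack.isNotEmpty()) {
        val frame = stack.last() // always work on the topmost frame
        when (frame.state) {
            0 -> { // recursive call one: depth(t.left)
                frame.state = 1
                val child = frame.tree.left
                if (child != null) {
                    stack.addLast(Frame(child))
                } else {
                    ret = 0 // depth(null) returns 0 without a new frame
                }
            }
            1 -> { // call one returned in ret; start recursive call two: depth(t.right)
                frame.leftDepth = ret
                frame.state = 2
                val child = frame.tree.right
                if (child != null) {
                    stack.addLast(Frame(child))
                } else {
                    ret = 0
                }
            }
            else -> { // call two returned in ret; combine and pop this frame
                ret = maxOf(frame.leftDepth, ret) + 1
                stack.removeLast()
            }
        }
    }
    return ret
}
```

Even in this compact form, the simple shape of maxOf(depth(left), depth(right)) + 1 is hard to see.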

There is also another way to compute depth that avoids explicitly maintaining per-call state: a depth-first traversal of the tree. However, it yields a slightly different algorithm, whereas the approach above is a direct translation of our original depth function into a non-recursive form.
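Below is a minimal sketch of that alternative. It assumes an explicit ArrayDeque of (node, depth) pairs in place of the call stack, and depthDfs is a hypothetical name. Instead of combining results returned from subtrees, it records the depth at which each node is visited and keeps the maximum.

```kotlin
class Tree(val left: Tree?, val right: Tree?)

// Depth-first traversal with an explicit, heap-allocated stack.
// Each entry pairs a node with its depth (the root is at depth 1).
fun depthDfs(root: Tree?): Int {
    var maxDepth = 0
    val stack = ArrayDeque<Pair<Tree, Int>>()
    if (root != null) stack.addLast(root to 1)
    while (stack.isNotEmpty()) {
        val (node, d) = stack.removeLast()
        maxDepth = maxOf(maxDepth, d)
        // Children sit one level deeper than their parent.
        node.left?.let { stack.addLast(it to d + 1) }
        node.right?.let { stack.addLast(it to d + 1) }
    }
    return maxDepth
}
```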

A better solution

This kind of transformation of sequential logic into a state machine is exactly what happens behind the scenes in Kotlin to implement suspending functions. It can be exploited to solve the problem with deep recursion. To start, we define a wrapper class that takes a suspending block of code and turns it into a DeepRecursiveFunction<T, R> with a parameter of type T and a result of type R.

class DeepRecursiveFunction<T, R>(
val block: suspend DeepRecursiveScope<T, R>.(T) -> R
)

The block is an extension on the DeepRecursiveScope class, which defines a suspending function callRecursive to make recursive calls:

class DeepRecursiveScope<T, R> {
    suspend fun callRecursive(value: T): R
}

In the end, we want to rewrite the depth function like this, replacing recursive calls to depth with callRecursive, but otherwise keeping all the function’s code intact:

val depth = DeepRecursiveFunction<Tree?, Int> { t ->
    if (t == null) 0 else maxOf(
        callRecursive(t.left),
        callRecursive(t.right)
    ) + 1
}

The key idea here is that callRecursive should not actually perform a regular recursive call to the block in DeepRecursiveFunction. Instead, its implementation suspends the current coroutine and saves both the current continuation and the parameter value for the recursive call to make:

suspend fun callRecursive(value: T): R =
    suspendCoroutineUninterceptedOrReturn { cont ->
        this.cont = cont
        this.value = value
        COROUTINE_SUSPENDED
    }
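The suspension trick can be seen in isolation in the small sketch below. The Parker class and parkAndResume are hypothetical names. park() stores its caller's continuation and returns COROUTINE_SUSPENDED, the same way callRecursive does. startCoroutine returns as soon as the coroutine parks, and the coroutine finishes only when the saved continuation is resumed later.

```kotlin
import kotlin.coroutines.Continuation
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
import kotlin.coroutines.resume
import kotlin.coroutines.startCoroutine

// Hypothetical helper: saves the caller's continuation instead of returning a value.
class Parker {
    var saved: Continuation<Int>? = null

    suspend fun park(): Int = suspendCoroutineUninterceptedOrReturn { cont ->
        saved = cont
        COROUTINE_SUSPENDED // unwind the stack back to whoever started the coroutine
    }
}

fun parkAndResume(): List<String> {
    val log = mutableListOf<String>()
    val parker = Parker()
    val body: suspend () -> Int = { parker.park() + 1 }
    body.startCoroutine(Continuation<Int>(EmptyCoroutineContext) { result ->
        log += "completed with ${result.getOrThrow()}"
    })
    // startCoroutine has already returned, although the coroutine has not finished.
    log += "suspended"
    parker.saved!!.resume(41) // park() now returns 41 inside the coroutine
    return log
}
```

Calling parkAndResume() logs "suspended" first and "completed with 42" second.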

When the coroutine is suspended, the call stack unwinds and execution returns to the outer loop. The purpose of this outer loop is very similar to that of the loop in the previous, manually crafted implementation: it initiates a recursive call whenever callRecursive has stored a value. Essentially, it needs to call startCoroutine on a block, but the Kotlin standard library only defines startCoroutine for suspending functions with zero or one argument (a receiver), while our suspending block has two arguments: the scope and the value. To start it as a coroutine, we directly employ Kotlin's Continuation-Passing Style (CPS) conversion for suspending functions, which is explained in the coroutines design document. A suspending function with two arguments is a regular function with three arguments, the last one being a continuation:

val function = block as Function3<Any?, Any?, Continuation<R>, Any?>

Now, the recursive call can be started like this:

function(this, value, cont)

The last parameter is the continuation of the current frame that becomes the completion callback of the new call. There is no need to explicitly maintain a stack. This completion reference maintains the stack implicitly. When the newly initiated call completes, it automatically continues the caller’s execution at the state where it was suspended in callRecursive.
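The CPS encoding can be observed directly with the small sketch below. addLength and invokeCps are hypothetical names. The cast relies on how Kotlin/JVM compiles suspending function types, which is the same assumption the Function3 cast above makes. The sketch invokes a suspending lambda that has a receiver and one parameter as a plain Function3. Because this lambda never suspends, the call simply returns its result rather than COROUTINE_SUSPENDED.

```kotlin
import kotlin.coroutines.Continuation
import kotlin.coroutines.EmptyCoroutineContext

// A suspending lambda with a receiver (String) and one parameter (Int).
val addLength: suspend String.(Int) -> Int = { n -> length + n }

// Calls a `suspend R.(T) -> Int` through its CPS form: a regular function of
// (receiver, argument, continuation) returning the result or COROUTINE_SUSPENDED.
@Suppress("UNCHECKED_CAST")
fun <R, T> invokeCps(block: suspend R.(T) -> Int, receiver: R, arg: T): Any? {
    // The completion would only be used if the block suspended; addLength never does.
    val completion = Continuation<Int>(EmptyCoroutineContext) { }
    val function = block as Function3<R, T, Continuation<Int>, Any?>
    return function(receiver, arg, completion)
}
```

Here invokeCps(addLength, "abc", 2) returns 5 directly.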

The only piece left to handle is the case when the whole computation completes. For that we implement the Continuation<R> interface in the DeepRecursiveScope class and store the result of the whole computation:

override val context: CoroutineContext
    get() = EmptyCoroutineContext

override fun resumeWith(result: Result<R>) {
    this.cont = null
    this.result = result
}

The DeepRecursiveScope instance itself serves as the initial value of the completion continuation, and the whole implementation of the main loop becomes:

var cont: Continuation<R>? = this

fun runCallLoop(): R {
    while (true) {
        val result = this.result
        val cont = this.cont // null means done
            ?: return result.getOrThrow()
        // ~startCoroutineUninterceptedOrReturn
        val r = try {
            function(this, value, cont)
        } catch (e: Throwable) {
            cont.resumeWithException(e)
            continue
        }
        if (r !== COROUTINE_SUSPENDED)
            cont.resume(r as R)
    }
}

As a final touch, we can define operator fun invoke for DeepRecursiveFunction, so that we can start the recursive computation with a regular function invocation syntax:

operator fun <T, R> DeepRecursiveFunction<T, R>.invoke(t: T): R =
    DeepRecursiveScope<T, R>(block, t).runCallLoop()

What we’ve got

Having implemented a deep recursive function framework using Kotlin Coroutines, we can write the depth function in a direct style and run it on deepTree without a StackOverflowError.
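Putting the pieces together, the complete framework might look like the single file below. It is assembled from the snippets above and is not the author's gist. The UNDEFINED_RESULT sentinel and the private field declarations are assumptions needed to make it compile. The sketch relies on the JVM's CPS encoding of suspending lambdas. It also assumes the block only ever suspends through callRecursive, because nothing here (such as @RestrictsSuspension) enforces that.

```kotlin
import kotlin.coroutines.Continuation
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
import kotlin.coroutines.resume
import kotlin.coroutines.resumeWithException

class Tree(val left: Tree?, val right: Tree?)

class DeepRecursiveFunction<T, R>(
    val block: suspend DeepRecursiveScope<T, R>.(T) -> R
)

// Assumed sentinel meaning "the outermost call has not completed yet".
private val UNDEFINED_RESULT = Result.success(COROUTINE_SUSPENDED)

@Suppress("UNCHECKED_CAST")
class DeepRecursiveScope<T, R>(
    block: suspend DeepRecursiveScope<T, R>.(T) -> R,
    value: T
) : Continuation<R> {
    // CPS view of the block: (scope, value, continuation) -> result or COROUTINE_SUSPENDED
    private val function = block as Function3<Any?, Any?, Continuation<R>, Any?>
    private var value: Any? = value           // argument of the pending call
    private var cont: Continuation<R>? = this // completion of the pending call; null = done
    private var result: Result<R> = UNDEFINED_RESULT as Result<R>

    override val context: CoroutineContext
        get() = EmptyCoroutineContext

    // Invoked once, when the outermost call completes.
    override fun resumeWith(result: Result<R>) {
        this.cont = null
        this.result = result
    }

    suspend fun callRecursive(value: T): R =
        suspendCoroutineUninterceptedOrReturn { cont ->
            // Record the call instead of making it, then unwind to runCallLoop.
            this.cont = cont
            this.value = value
            COROUTINE_SUSPENDED
        }

    fun runCallLoop(): R {
        while (true) {
            val result = this.result
            val cont = this.cont ?: return result.getOrThrow() // null means done
            // Start the pending call with the caller's continuation as its completion.
            val r = try {
                function(this, value, cont)
            } catch (e: Throwable) {
                cont.resumeWithException(e)
                continue
            }
            if (r !== COROUTINE_SUSPENDED) cont.resume(r as R)
        }
    }
}

operator fun <T, R> DeepRecursiveFunction<T, R>.invoke(t: T): R =
    DeepRecursiveScope<T, R>(block, t).runCallLoop()

val depth = DeepRecursiveFunction<Tree?, Int> { t ->
    if (t == null) 0 else maxOf(
        callRecursive(t.left),
        callRecursive(t.right)
    ) + 1
}
```

With the 100K-node tree from the setup, depth(deepTree) returns 100000. The thread stack stays shallow because every callRecursive unwinds back to runCallLoop before the next call starts.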

Closing thoughts

While there's not that much code, there's a lot to understand to grasp the way it works. Reading the Kotlin coroutines design document should help. On the other hand, you don't have to understand how it is implemented in order to use it and benefit from it.

I have a gist (see here) with a fully productized implementation of DeepRecursiveFunction that has documentation and also supports mutually recursive functions. It might become an addition to the Kotlin standard library. Do you think it is useful or not? Leave your feedback in this Kotlin YouTrack issue, please: KT-31741.
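For reference, the Kotlin standard library now ships kotlin.DeepRecursiveFunction. It was introduced as experimental in Kotlin 1.4 and is stable as of 1.7. Assuming Kotlin 1.7 or newer, the depth example can be written against it directly, with no helper code and no imports:

```kotlin
// Uses the standard library's kotlin.DeepRecursiveFunction (assumes Kotlin 1.7+).
class Tree(val left: Tree?, val right: Tree?)

val depth = DeepRecursiveFunction<Tree?, Int> { t ->
    if (t == null) 0 else maxOf(
        callRecursive(t.left),
        callRecursive(t.right)
    ) + 1
}
```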

Roman Elizarov

Project Lead for the Kotlin Programming Language @JetBrains