Trampoline and safe tail recursion in Go

In the previous article, "Go for Functional Programming?", we briefly touched on the topic of regular recursion versus tail recursion. We also already know that Go lacks the tail call optimization present in some other languages, which leaves us vulnerable to stack overflow even when all of our recursive calls are tail calls. In this article we will implement a simple Trampoline pattern that can help with that.

But before we proceed, let's recall what tail recursion is:

A tail recursive function is a recursive function in which the recursive call is the last operation the function performs: the result of that call is returned directly, with no further computation applied to it.

This function is tail recursive:

func summation(n, current int) int {
    if n < 1 {
        return current
    }
    return summation(n-1, n+current)
}

In contrast, this function is not tail recursive:

func summation(n int) int {
    if n < 1 {
        return 0
    }
    return n + summation(n-1)
}

In regular (non-tail) recursion, we make the recursive call first, and then use its output to calculate the result. This means we don't get the result of our calculation until every nested recursive call has returned.

In tail recursion, we perform our calculation first (n + current in this case) and only then make the recursive call, passing the result of the current step to the next call. Nothing is left to do after that call returns.
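To make the difference concrete, here is a worked expansion of both versions for n = 3. The regular version accumulates pending additions that can only be resolved on the way back out, while the tail recursive version carries the running total forward in current, so each call simply hands over to the next one.

```latex
\begin{aligned}
\mathrm{summation}(3) &= 3 + \mathrm{summation}(2) \\
&= 3 + (2 + \mathrm{summation}(1)) \\
&= 3 + (2 + (1 + \mathrm{summation}(0))) \\
&= 3 + (2 + (1 + 0)) = 6 \\[1ex]
\mathrm{summation}(3, 0) &= \mathrm{summation}(2, 3) \\
&= \mathrm{summation}(1, 5) \\
&= \mathrm{summation}(0, 6) = 6
\end{aligned}
```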

Each recursive call (both regular and tail) requires stack space to store its parameters and other information associated with the call. Many programming languages can automatically optimize tail recursive calls to avoid growing the stack: for example, a compiler can rewrite such a function, replacing the tail recursive call with a traditional loop. However, that is not the case in Go. On the other hand, Go has a clever stack implementation: instead of a fixed amount of stack memory, each goroutine starts with a small stack that grows and shrinks on demand. But even that has a limit (roughly 1 GB by default on 64-bit platforms, configurable via runtime/debug.SetMaxStack), so a deep enough recursion will eventually lead to a stack overflow anyway:

package main

import "fmt"

func summation(n, current uint64) uint64 {
    if n < 1 {
        return current
    }
    return summation(n-1, n+current)
}

func main() {
    fmt.Println(summation(100000000, 0))
    // fatal error: stack overflow
}
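To illustrate what a tail call optimizing compiler effectively does, here is a sketch of the loop equivalent of our tail recursive summation: the tail call becomes a reassignment of the parameters followed by another loop iteration. The name summationLoop is hypothetical, chosen for illustration only. Writing the loop by hand avoids the overflow, but it gives up the recursive shape of the code.

```go
package main

import "fmt"

// summationLoop is a hypothetical hand-written equivalent of the tail
// recursive summation(n, current): instead of calling itself, it updates
// its parameters and goes around the loop again.
func summationLoop(n, current uint64) uint64 {
	// Loop while the recursive version would still recurse (n >= 1).
	for n >= 1 {
		// Mirrors summation(n-1, n+current). Go evaluates both right-hand
		// sides with the old values before assigning, so n+current uses
		// the old n, exactly as the recursive call does.
		n, current = n-1, n+current
	}
	// Base case: the recursive version returns current when n < 1.
	return current
}

func main() {
	fmt.Println(summationLoop(100000000, 0)) // 5000000050000000
}
```

Stack usage stays constant here no matter how large n is, because there is only ever one call frame.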

Since the Go compiler can't optimize this tail recursion for us, is there a way to do it manually? Yes: this technique is called trampolining. The basic idea is simple: instead of actually calling our function recursively, we return a deferred description of the next call. When run, that call returns a description of the next call, and so on, until we get a description of the final result. It can be visualized like this:

More(func() { return More(func() { return Done(result) }) })

Essentially, our recursive function becomes a description of the recursive computation, which we then need to run to actually produce a result. Using the Go4Fun functional programming library (which includes a simple Trampoline implementation), it could look like this:

func summationT(n, current uint64) Trampoline[uint64] {
    if n < 1 {
        return DoneTrampolining(current)
    }
    return MoreTrampolining(func() Trampoline[uint64] {
        return summationT(n-1, n+current)
    })
}
summationT(100000000, 0).Run()
// Output: 5000000050000000

We created our deferred recursive computation by calling summationT(100000000, 0), and then actually ran it by calling Run(). No stack overflow happens this time. It works because we have essentially traded stack for heap: instead of nested recursive calls piling up frames on the stack, each step is a small structure (holding a closure) allocated on the heap, and Run invokes these steps one after another from a loop, so the stack never grows beyond a few frames.

Let's see how this simple Trampoline is implemented in the library:

type Trampoline[A any] struct {
    call   func() Trampoline[A]
    done   bool
    result A
}

func DoneTrampolining[A any](a A) Trampoline[A] {
    return Trampoline[A]{done: true, result: a}
}

func MoreTrampolining[A any](more func() Trampoline[A]) Trampoline[A] {
    return Trampoline[A]{call: more, done: false}
}

func (t Trampoline[A]) Run() A {
    next := t
    for {
        if next.done {
            return next.result
        }
        next = next.call()
    }
}

As we can see, the implementation is very straightforward: a simple Trampoline structure represents either the next deferred call OR a final result. The function MoreTrampolining wraps a deferred call into this structure, and DoneTrampolining wraps a final result. Finally, the Run method invokes the deferred calls one by one in a loop until it reaches a final result. Because each call returns before the next one starts, the stack depth stays constant.
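The pieces above fit in a single file. The following sketch simply combines the Trampoline implementation shown above with summationT from earlier, so the example can be run with a plain Go toolchain (1.18 or later, for generics) without importing go4fun. It is a convenience copy for experimenting, not the library's packaged code.

```go
package main

import "fmt"

// Trampoline, DoneTrampolining, MoreTrampolining and Run are copied from
// the implementation shown above so this file runs without the library.
type Trampoline[A any] struct {
	call   func() Trampoline[A]
	done   bool
	result A
}

// DoneTrampolining wraps a final result.
func DoneTrampolining[A any](a A) Trampoline[A] {
	return Trampoline[A]{done: true, result: a}
}

// MoreTrampolining wraps a deferred call that produces the next step.
func MoreTrampolining[A any](more func() Trampoline[A]) Trampoline[A] {
	return Trampoline[A]{call: more, done: false}
}

// Run invokes one deferred call per loop iteration. Each call returns
// before the next one starts, so the stack depth stays constant.
func (t Trampoline[A]) Run() A {
	next := t
	for {
		if next.done {
			return next.result
		}
		next = next.call()
	}
}

// summationT is the trampolined version of the tail recursive summation.
func summationT(n, current uint64) Trampoline[uint64] {
	if n < 1 {
		return DoneTrampolining(current)
	}
	// Instead of recursing directly, return a description of the next step.
	return MoreTrampolining(func() Trampoline[uint64] {
		return summationT(n-1, n+current)
	})
}

func main() {
	fmt.Println(summationT(100000000, 0).Run()) // 5000000050000000
}
```

The price of avoiding stack growth is roughly one small heap allocation per step (the closure captured by MoreTrampolining), which becomes garbage as soon as Run has called it.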

Conclusion

In this article we covered a simple trampolining technique that lets us keep the recursive structure of the code, which is natural for some algorithms, while avoiding potential stack overflow issues. This Trampoline implementation is by no means comprehensive: for example, it doesn't let us describe recursive functions with multiple recursive calls (such as the canonical implementation of the Fibonacci sequence). That would require a more powerful Trampoline that can combine multiple Trampolines together (e.g. using Map and FlatMap combinators). This is something we might explore in a future article. In the meantime, if you are interested in diving deeper into hands-on functional programming in Go, I would be happy to see you join me on GitHub (github.com/ialekseev/go4fun). Thanks for reading!