Tail Call Optimisation in Rust
Published: 2019-10-26
Tags: rust, recursion
When developing a recursive algorithm, regardless of the language, one of the issues you have to consider as a programmer is avoiding a stack overflow. In some cases you might know that the recursion won't be that deep, but there are perfectly good algorithms that can be expressed in recursive form that can and will consume all of the stack memory and then crash.
There is a way to solve this: if you rewrite the algorithm so that the recursive call is the last thing the function does and its result is returned unchanged, the call can be executed without adding a new stack frame. This is called tail call optimisation, and if the language or compiler you're using supports it, you're in luck. Unfortunately, the Rust compiler doesn't guarantee tail call optimisation in the general case (the optimiser may apply it in some builds, but you cannot rely on it). In this post I will be looking at how you can restructure your recursive functions using loops to avoid causing a stack overflow, and why this is effectively the same as performing tail call optimisation.
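To make the distinction concrete, here is a minimal sketch of the same computation written both ways. The names sum_to and sum_to_acc are hypothetical and chosen only for this illustration; only the second function makes its recursive call in tail position.

```rust
/// Not a tail call: the result of the recursive call is modified
/// (`n` is added to it) before being returned, so the current frame
/// must stay alive until the inner call finishes.
fn sum_to(n: u64) -> u64 {
    if n == 0 {
        0
    } else {
        n + sum_to(n - 1)
    }
}

/// Tail call: the recursive call is the last thing the function does and
/// its result is returned unchanged. The running total travels in `acc`.
fn sum_to_acc(n: u64, acc: u64) -> u64 {
    if n == 0 {
        acc
    } else {
        sum_to_acc(n - 1, acc + n)
    }
}

#[test]
fn both_forms_agree() {
    assert_eq!(sum_to(10), 55);
    assert_eq!(sum_to_acc(10, 0), 55);
}
```

Only the second shape can be turned into a loop mechanically, which is what the rest of this post relies on.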
A simple recursive algorithm
First we'll need a simple and common recursive algorithm to use as an example.
One of the easiest things that make sense to implement recursively is the all
function that takes an iterator and a predicate function and checks that all
the items in the iterator match the predicate.
    fn all<I, P>(mut items: I, predicate: P) -> bool
    where
        I: Iterator,
        P: Fn(I::Item) -> bool,
    {
        unimplemented!()
    }
The recursive implementation of all
is simple - if the predicate is false for
the next item of the iterator then all
should return false. If it is true
then we need to repeat the same action on the rest of the iterator until the
iterator is empty, at which point we can say that the predicate is true for
all the items.
    fn all<I, P>(mut items: I, predicate: P) -> bool
    where
        I: Iterator,
        P: Fn(I::Item) -> bool,
    {
        match items.next() {
            Some(item) => {
                if predicate(item) {
                    all(items, predicate)
                } else {
                    false
                }
            }
            None => true,
        }
    }
It is clear that in this case all
could take advantage of tail-call
optimisation as the result of the recursive call is returned as-is.
So what's the issue here? Let's say we want to test a very long (potentially
infinite) sequence. Every time all
calls itself it creates a stack frame that
consumes memory. Eventually, that leads to a stack overflow. A quick test to
show it failing:
    #[test]
    fn long_sequence() {
        assert!(!all(0.., |x| x < 1_000_000));
    }
This fails on my machine, but if it runs fine on yours then it's always possible to make it crash just by increasing the cut-off number (or removing it altogether and making the closure always return true).
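As a rough illustration of why the recursion depth matters, the sketch below runs the recursive all on a thread whose stack size is chosen explicitly through std::thread::Builder::stack_size. The helper name all_below_with_stack and the sizes used are assumptions made for this example, not part of the original code.

```rust
use std::thread;

/// The recursive version from above, repeated so the sketch is self-contained.
fn all<I, P>(mut items: I, predicate: P) -> bool
where
    I: Iterator,
    P: Fn(I::Item) -> bool,
{
    match items.next() {
        Some(item) => {
            if predicate(item) {
                all(items, predicate)
            } else {
                false
            }
        }
        None => true,
    }
}

/// Hypothetical helper: checks `x < limit` over `0..` using the recursive
/// `all`, on a thread that has `stack_bytes` of stack available.
fn all_below_with_stack(limit: u64, stack_bytes: usize) -> bool {
    thread::Builder::new()
        .stack_size(stack_bytes)
        .spawn(move || all(0u64.., move |x| x < limit))
        .expect("failed to spawn thread")
        .join()
        .expect("thread panicked")
}

#[test]
fn modest_depth_fits_in_a_large_stack() {
    // Roughly 10_000 nested calls on a 64 MiB stack (sizes are assumptions).
    assert!(!all_below_with_stack(10_000, 64 * 1024 * 1024));
}
```

A bigger stack only moves the limit further away, it does not remove it. Also note that a real stack overflow typically aborts the whole process rather than producing a panic that join could report.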
Tail call optimisation using loop
Let's see what we can do to fix this. In the case of all the solution is quite
obvious: we can rewrite it using a loop. We can simply loop until the iterator
is exhausted or we find an item where the predicate is false. To make it clear
how this maps to the recursive style, let's change the match in the first
version to a break match and wrap it inside a loop. That way, if the selected
match branch returns a value, we break out of the loop and return from our
function. If, on the other hand, it would make a recursive call, we can instead
simply modify the state and use continue to start a new iteration of the loop.
When rewriting a recursive function to use a loop this is key: instead of
passing modified arguments to a recursive call, we change the state and start
a new iteration of the loop.
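The same recipe can be sketched on a function where the state does have to be updated by hand. Euclid's algorithm is used here purely as an illustration; gcd_recursive and gcd_loop are hypothetical names that do not appear elsewhere in this post.

```rust
/// Recursive form: the tail call passes modified arguments.
fn gcd_recursive(a: u64, b: u64) -> u64 {
    if b == 0 {
        a
    } else {
        gcd_recursive(b, a % b)
    }
}

/// Loop form: the arguments become mutable state, and the tail call becomes
/// "update the state, then `continue`".
fn gcd_loop(mut a: u64, mut b: u64) -> u64 {
    loop {
        break if b == 0 {
            a
        } else {
            // This replaces the call `gcd_recursive(b, a % b)`.
            let remainder = a % b;
            a = b;
            b = remainder;
            continue;
        };
    }
}

#[test]
fn loop_matches_recursion() {
    assert_eq!(gcd_recursive(48, 18), 6);
    assert_eq!(gcd_loop(48, 18), 6);
}
```

Each branch that returned a plain value still does so through break, while the branch that recursed now rewrites a and b and jumps back to the top of the loop.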
In the case of our all function we don't need to modify the state explicitly;
this is done for us by the next method. If the predicate is true we move on to
the next iteration using continue, and if it's false we simply return false,
which breaks the loop and is used as the return value of the function.
    pub fn all<I, P>(mut items: I, predicate: P) -> bool
    where
        I: Iterator,
        P: Fn(I::Item) -> bool,
    {
        loop {
            break match items.next() {
                Some(item) => {
                    if predicate(item) {
                        continue;
                    } else {
                        false
                    }
                }
                None => true,
            };
        }
    }
If you try running the long_sequence test again you'll see that it no longer
fails. However, continue is not often used and I feel that many people are
more used to break statements. We can rewrite the same tail-call optimised
function by switching back from break match to a regular match. This way,
simply returning a value in a match branch does nothing and the loop
continues. Instead we need to use a break statement to return a value from
the loop and from our function. The continue statement is no longer needed,
as the default behavior is to continue with the loop.
    pub fn all<I, P>(mut items: I, predicate: P) -> bool
    where
        I: Iterator,
        P: Fn(I::Item) -> bool,
    {
        loop {
            match items.next() {
                Some(item) => {
                    if !predicate(item) {
                        break false;
                    }
                }
                None => break true,
            };
        }
    }
Note that since the match as a whole is now of unit type (its value is no
longer passed to break), the if doesn't need to produce a value either. This
way we can get rid of one of its branches and only keep the one with the break
statement.
Short-circuiting logic operators
While this works, we can do away with the if statement: in Rust both the ||
and && operators are short-circuiting, so we can use them as simple forms of
flow control. The right-hand side of || is only evaluated when the left-hand
side is false. In this case we want to break false when the predicate is false
for the current item, so we can instead write:
    pub fn all<I, P>(mut items: I, predicate: P) -> bool
    where
        I: Iterator,
        P: Fn(I::Item) -> bool,
    {
        loop {
            match items.next() {
                Some(item) => predicate(item) || break false,
                None => break true,
            };
        }
    }
This looks much nicer and is also more compact. But remember that we started by
using continue instead of break? We can use it in a similar fashion, but since
we only want to continue if the predicate is true we have to use && instead.
    pub fn all<I, P>(mut items: I, predicate: P) -> bool
    where
        I: Iterator,
        P: Fn(I::Item) -> bool,
    {
        loop {
            break match items.next() {
                Some(item) => predicate(item) && continue,
                None => true,
            };
        }
    }
This works just fine and passes the test, but I find the && continue part
a lot less obvious than the corresponding || break false. Overall I think that
while the continue style is helpful for understanding, as the continue maps
directly to a recursive call in code, the break style should normally be
preferred.
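As a sanity check, the sketch below compares the break-style version with the standard library's Iterator::all and counts how often the predicate is called, to confirm that the loop stops at the first failing item. The test name and the call counter are assumptions added for this illustration.

```rust
use std::cell::Cell;

/// The break-style version from above, repeated so the sketch is self-contained.
fn all<I, P>(mut items: I, predicate: P) -> bool
where
    I: Iterator,
    P: Fn(I::Item) -> bool,
{
    loop {
        match items.next() {
            Some(item) => predicate(item) || break false,
            None => break true,
        };
    }
}

#[test]
fn matches_std_and_stops_early() {
    // Same answers as the standard library's `Iterator::all`.
    assert_eq!(all(1..10, |x| x > 0), (1..10).all(|x| x > 0));
    assert_eq!(all(1..10, |x| x < 5), (1..10).all(|x| x < 5));

    // An empty iterator satisfies any predicate.
    assert!(all(std::iter::empty::<i32>(), |_| false));

    // The predicate runs for 0, 1, 2 and 3, then the loop breaks.
    let calls = Cell::new(0);
    let result = all(0.., |x| {
        calls.set(calls.get() + 1);
        x < 3
    });
    assert!(!result);
    assert_eq!(calls.get(), 4);
}
```

A Cell is used for the counter because the predicate is bound by Fn, so it cannot mutate captured state directly.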
Hopefully I've convinced you that, while having built-in tail call optimisation is nice, it's really not that difficult to rewrite your recursive code using loops, thus making it a little more robust. There is also a performance benefit, and I'm planning to write another post about that, complete with some benchmarks!