Tail Recursion

There is a special form of recursion called tail recursion that can be compiled to a loop in the machine code. This can save a significant amount of memory if the code makes many recursive calls, and in a functional language like Haskell this optimization can make the difference between the code working and running out of memory.

To understand how this works, let’s review how a function keeps track of its information.

1. Function calls and frames

When a program calls a function, the computer allocates memory for the particular instance of that call. This memory is called a stack frame or sometimes an activation record.

For example, suppose we have the following Haskell code:

foo x =
  let a = 10
      b = 20
  in a + b + bar x
bar x =
  let a = 30
      b = 40
  in a + b + baz x
baz x =
  let a = 50
      b = 60
  in a + b + x

If we call foo 10, the program will allocate a frame for foo. When foo calls bar, the program allocates a frame for bar, and when bar calls baz, a frame gets allocated for baz too.

The important thing to note is that while baz is running, the frames for foo and bar are still in memory, since they are waiting for the results of baz.

When a function is recursive, we can end up with a lot of frames. Take this factorial code for example:

fact 0 = 1
fact n = n * fact (n-1)

Computing the factorial of 100 will require 101 frames: one for each call from fact 100 down to fact 0.
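
To make the frame count concrete, here is a small sketch that returns the number of calls made alongside the result. The name factCalls is hypothetical and not part of the original notes, and the sketch assumes that every call, including the base case, needs one frame.

```haskell
-- Hypothetical helper: returns (n!, number of calls made to compute it).
-- Assumes a non-negative argument, like fact itself.
factCalls :: Integer -> (Integer, Int)
factCalls 0 = (1, 1)                    -- the base case is itself one call
factCalls n =
  let (r, calls) = factCalls (n - 1)    -- not a tail call: work remains afterwards
  in (n * r, calls + 1)
```

Under that assumption, snd (factCalls 100) counts one call for each value from 100 down to 0.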

2. Tail calls

Now consider this code.

foo x =
  let a = 10
      b = 20
  in bar (a + b + x)
bar x =
  let a = 30
      b = 40
  in baz (a + b + x)
baz x =
  let a = 50
      b = 60
  in a + b + x

When foo calls bar, note that whatever bar returns will be the value for foo as well. Similarly, when bar calls baz, whatever baz returns will be the return value for bar. The function foo will not need to keep track of its variables after calling bar since it will not be using them again, and likewise when bar calls baz.

So, instead of keeping those stack frames around, the compiler will allow bar to use the same memory that foo was using, and in turn allow baz to use that same memory as well once bar calls baz. The function calls get eliminated and replaced by straight-line code.

The question is: how do we know when it’s safe to do this? The answer is that we can eliminate the function call if the call is in tail position.

3. Tail Code

There are certain kinds of expressions that have subexpressions that will become the value of the whole expression if evaluated. These subexpressions are said to be in tail position. Consider this code:

if a > b then c else d

let e = 10 + 20
 in f

foo 0 = 1
foo n = foo (n-1)

The variables c, d, and f are in tail position. The expressions a>b, 10+20, and n-1 are not in tail position. If a>b is true, and c has value 99, then the value of the whole if expression will be 99.

Note that not every expression has subexpressions that are in tail position.

Here is a non-exhaustive list of tail subexpressions:

  • The then and else branches of an if expression
  • The result expression of each branch of a case expression
  • The body of a let expression
  • The return value of a function

If the expression that is in tail position happens to be a function call, then we call it a tail call. The call foo (n-1) is a tail call.
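
As a rough illustration of the distinction, the sketch below puts a tail call next to a call that is not in tail position. The function names are hypothetical and chosen only for this example.

```haskell
-- Hypothetical examples; the names are not from the original notes.

-- Tail call: the result of the recursive call is returned unchanged.
countdown :: Int -> Int
countdown 0 = 0
countdown n = countdown (n - 1)

-- Not a tail call: 1 still has to be added after the recursive call returns.
len :: [a] -> Int
len []     = 0
len (_:xs) = 1 + len xs

-- Tail call inside an if: the else branch is in tail position,
-- so the recursive call there is a tail call.
member :: Eq a => a -> [a] -> Bool
member _ []     = False
member y (x:xs) = if x == y then True else member y xs
```

In len, the addition is the expression in tail position, not the recursive call, which is why its frame has to stay around.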

4. Accumulators

In direct-style recursion, the function makes all the calls and traverses the data structure. Once it hits the end of the data structure (or base case), it returns to the caller and performs the computations “on the way out”.

Consider this function that sums up the elements of a list.

sum [] = 0
sum (x:xs) = x + sum xs

If we apply sum to the list [2,3,4,5,6], sum will traverse to the end of the list, and then do all the adding as things get returned. We can trace the execution this way, where each line is the next call or return:

 1: sum [2,3,4,5,6]
 2: = 2 + sum [3,4,5,6]
 3: = 2 + 3 + sum [4,5,6]
 4: = 2 + 3 + 4 + sum [5,6]
 5: = 2 + 3 + 4 + 5 + sum [6]
 6: = 2 + 3 + 4 + 5 + 6 + sum []
 7: = 2 + 3 + 4 + 5 + 6 + 0
 8: = 2 + 3 + 4 + 5 + 6
 9: = 2 + 3 + 4 + 11
10: = 2 + 3 + 15
11: = 2 + 18
12: = 20

If the function is in tail form, then we need to do the computation as we make the tail calls. This usually requires an extra parameter to keep track of the computation as we go. For example:

tsum [] a = a
tsum (x:xs) a = tsum xs (x+a)

If we call tsum [2,3,4,5,6] 0 it will look like this:

1: tsum [2,3,4,5,6] 0
2: = tsum [3,4,5,6] 2
3: = tsum [4,5,6] 5
4: = tsum [5,6] 9
5: = tsum [6] 14
6: = tsum [] 20
7: = 20

Notice that the computation happens “on the way in”, and when we reach the base case the computation is complete.
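
One way to see the accumulator values from the trace is to record them as the recursion proceeds. The helper below is a hypothetical sketch for inspection only. It is not itself tail recursive, because it conses onto the result after the recursive call.

```haskell
-- Hypothetical helper: follows the same recursion as tsum, but returns the
-- accumulator value seen at each call instead of only the final one.
tsumTrace :: [Int] -> Int -> [Int]
tsumTrace []     a = [a]
tsumTrace (x:xs) a = a : tsumTrace xs (x + a)

-- tsumTrace [2,3,4,5,6] 0 evaluates to [0,2,5,9,14,20],
-- the second argument on lines 1 through 6 of the trace.
```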

5. How to write a tail recursive function

There are four steps:

  1. Add an accumulator parameter to your function.
  2. The base case returns the accumulator. It is okay if you do some computation on it.
  3. The recursive case calls with an updated value for the accumulator. Once the code makes the recursive call, it cannot do anything else with the result when the call returns, or else the call is not in tail position.
  4. When calling the tail function initially, pass in the result that should get returned by the base case. (E.g., for tsum we called it with a 0. If it were a product, we would have passed in 1.)

We often like to create a helper function that does the recursion, and have the main function be responsible for passing in the initial value. Here is what tsum would look like:

tsum xx = aux xx 0
  where aux []     a = a
        aux (x:xs) a = aux xs (x+a)
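
The four steps can be sketched on another function, here list length. The names len and tsum' are hypothetical. The second definition is an aside showing that the same accumulator pattern matches the standard strict left fold, foldl' from Data.List, which also forces the accumulator at each step.

```haskell
import Data.List (foldl')

-- Hypothetical example: list length written by following the four steps.
len :: [a] -> Int
len xx = aux xx 0                     -- step 4: start from the base-case result
  where
    aux []     a = a                  -- step 2: the base case returns the accumulator
    aux (_:xs) a = aux xs (a + 1)     -- steps 1 and 3: extra parameter, updated in a tail call

-- The same accumulator pattern as tsum, expressed as a strict left fold.
tsum' :: [Int] -> Int
tsum' = foldl' (+) 0
```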

6. Problems

Here are some functions written in direct style. Try converting them to tail recursion. The answers are below.

prod [] = 1
prod (x:xs) = x * prod xs

fact 0 = 1
fact n = n * fact (n-1)

filter f [] = []
filter f (x:xs) = if f x then x : rest
                         else rest
    where rest = filter f xs

Solution for prod:

prod xx = aux xx 1
  where aux []     a = a
        aux (x:xs) a = aux xs (x*a)  -- recurse on aux, not prod, so the accumulator is carried along

Solution for fact:

fact n = aux n 1
  where aux 0 a = a
        aux m a = aux (m-1) (m*a)

Solution for filter:

filter f xx = aux xx []
  where aux []     a = reverse a  -- We need to reverse the list to preserve the original order
        aux (x:xs) a = if f x
                         then aux xs (x:a)
                         else aux xs a

                         -- or --

filter f xx = aux (reverse xx) [] -- We can reverse the input first if we like.
  where aux []     a = a
        aux (x:xs) a = if f x
                         then aux xs (x:a)
                         else aux xs a
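
A quick way to check answers like these is to compare them against the Prelude on small inputs. The sketch below does that. The primed names and the checks list are hypothetical, and the primes are there only to avoid clashing with the Prelude's own product and filter.

```haskell
-- Hypothetical primed copies of the tail recursive solutions.
prod' :: [Integer] -> Integer
prod' xx = aux xx 1
  where aux []     a = a
        aux (x:xs) a = aux xs (x * a)

fact' :: Integer -> Integer
fact' n = aux n 1
  where aux 0 a = a
        aux m a = aux (m - 1) (m * a)

filter' :: (a -> Bool) -> [a] -> [a]
filter' f xx = aux xx []
  where aux []     a = reverse a   -- restore the original order
        aux (x:xs) a = if f x then aux xs (x : a) else aux xs a

-- Each entry compares one solution against a Prelude reference.
checks :: [Bool]
checks =
  [ prod' [2, 3, 4] == product [2, 3, 4]
  , fact' 5 == product [1 .. 5]
  , filter' even [1 .. 10 :: Int] == filter even [1 .. 10]
  ]
```

Evaluating and checks in GHCi gives a single Bool that summarizes all three comparisons.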

6.1. Thanks for Suggestions and Finding Typos

  • Ruben Serrano
  • Enguang Fan

Author: Mattox Beckman

Created: 2022-02-19 Sat 16:05